From 1aa49b729ac5140a526dab250715c39c4f29c0e5 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Mon, 17 May 2021 13:10:41 -0400
Subject: [PATCH 01/28] ARROW-11930: [C++][Dataset][Compute] Use an ExecPlan
 for dataset scans

---
 cpp/src/arrow/compute/exec.cc                 |  18 +-
 cpp/src/arrow/compute/exec.h                  |   2 +
 cpp/src/arrow/compute/exec/exec_plan.cc       | 316 ++++++++++++++++--
 cpp/src/arrow/compute/exec/exec_plan.h        |  52 +--
 cpp/src/arrow/compute/exec/expression.cc      | 190 ++++++-----
 cpp/src/arrow/compute/exec/expression.h       |  50 ++-
 cpp/src/arrow/compute/exec/expression_test.cc | 147 +++++---
 cpp/src/arrow/compute/exec/plan_test.cc       | 207 +++++------
 cpp/src/arrow/compute/exec/test_util.cc       | 169 +++------
 cpp/src/arrow/compute/exec/test_util.h        |   9 +-
 .../arrow/compute/kernels/codegen_internal.cc |   2 +-
 cpp/src/arrow/dataset/dataset.cc              |   7 +-
 cpp/src/arrow/dataset/dataset.h               |   1 -
 cpp/src/arrow/dataset/file_ipc_test.cc        |   9 +-
 cpp/src/arrow/dataset/file_test.cc            |   8 +-
 cpp/src/arrow/dataset/scanner.cc              |  43 +--
 cpp/src/arrow/dataset/scanner.h               |  16 +-
 cpp/src/arrow/dataset/scanner_internal.h      |  52 +--
 cpp/src/arrow/dataset/test_util.h             |  15 +-
 cpp/src/arrow/result.h                        |   6 +-
 cpp/src/arrow/util/async_generator.h          | 173 +++++-----
 cpp/src/arrow/util/async_generator_test.cc    |   2 +-
 cpp/src/arrow/util/future.h                   |  18 +
 cpp/src/arrow/util/iterator.h                 |   6 +
 24 files changed, 892 insertions(+), 626 deletions(-)

diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index 73cb82ef026..add6188ab48 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -69,6 +69,16 @@ ExecBatch::ExecBatch(const RecordBatch& batch)
   std::move(columns.begin(), columns.end(), values.begin());
 }
 
+ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
+  ExecBatch out = *this;
+  for (auto& value : out.values) {
+    if (value.is_scalar()) continue;
+    value = value.array()->Slice(offset, length);
+  }
+  out.length = length;
+  return out;
+}
+
 Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
   if (values.empty()) {
     return Status::Invalid("Cannot infer ExecBatch length without at least one value");
   }
@@ -77,9 +87,6 @@ Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
   int64_t length = -1;
   for (const auto& value : values) {
     if (value.is_scalar()) {
-      if (length == -1) {
-        length = 1;
-      }
       continue;
     }
@@ -94,8 +101,13 @@ Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
     }
   }
 
+  if (length == -1) {
+    length = 1;
+  }
+
   return ExecBatch(std::move(values), length);
 }
+
 namespace {
 
 Result<std::shared_ptr<Buffer>> AllocateDataBuffer(KernelContext* ctx, int64_t length,
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
index cd95db2fd8c..d188d2d246d 100644
--- a/cpp/src/arrow/compute/exec.h
+++ b/cpp/src/arrow/compute/exec.h
@@ -206,6 +206,8 @@ struct ARROW_EXPORT ExecBatch {
   /// \brief A convenience for the number of values / arguments.
   int num_values() const { return static_cast<int>(values.size()); }
 
+  ExecBatch Slice(int64_t offset, int64_t length) const;
+
   /// \brief A convenience for returning the ValueDescr objects (types and
   /// shapes) from the batch.
std::vector GetDescriptors() const { diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index f765ceccf0c..cd14bcc2d82 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -17,8 +17,12 @@ #include "arrow/compute/exec/exec_plan.h" +#include #include +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/expression.h" #include "arrow/datum.h" #include "arrow/result.h" #include "arrow/util/checked_cast.h" @@ -81,7 +85,6 @@ struct ExecPlanImpl : public ExecPlan { struct TopoSort { const std::vector>& nodes; std::unordered_set visited; - std::unordered_set visiting; NodeVector sorted; explicit TopoSort(const std::vector>& nodes) @@ -96,7 +99,6 @@ struct ExecPlanImpl : public ExecPlan { } DCHECK_EQ(sorted.size(), nodes.size()); DCHECK_EQ(visited.size(), nodes.size()); - DCHECK_EQ(visiting.size(), 0); return Status::OK(); } @@ -105,18 +107,11 @@ struct ExecPlanImpl : public ExecPlan { return Status::OK(); } - auto it_success = visiting.insert(node); - if (!it_success.second) { - // Insertion failed => node is already being visited - return Status::Invalid("Cycle detected in execution plan"); - } - for (auto input : node->inputs()) { // Ensure that producers are inserted before this consumer RETURN_NOT_OK(Visit(input)); } - visiting.erase(it_success.first); visited.insert(node); sorted.push_back(node); return Status::OK(); @@ -170,21 +165,24 @@ Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } -ExecNode::ExecNode(ExecPlan* plan, std::string label, - std::vector input_descrs, +ExecNode::ExecNode(ExecPlan* plan, std::string label, NodeVector inputs, std::vector input_labels, BatchDescr output_descr, int num_outputs) : plan_(plan), label_(std::move(label)), - input_descrs_(std::move(input_descrs)), + inputs_(std::move(inputs)), input_labels_(std::move(input_labels)), output_descr_(std::move(output_descr)), - num_outputs_(num_outputs) {} + num_outputs_(num_outputs) { + for (auto input : inputs_) { + input->outputs_.push_back(this); + } +} Status ExecNode::Validate() const { - if (inputs_.size() != input_descrs_.size()) { + if (inputs_.size() != input_labels_.size()) { return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ", - num_inputs(), ", actual ", inputs_.size(), ")"); + num_inputs(), ", actual ", input_labels_.size(), ")"); } if (static_cast(outputs_.size()) != num_outputs_) { @@ -192,27 +190,295 @@ Status ExecNode::Validate() const { num_outputs(), ", actual ", outputs_.size(), ")"); } - DCHECK_EQ(input_descrs_.size(), input_labels_.size()); - for (auto out : outputs_) { auto input_index = GetNodeIndex(out->inputs(), this); if (!input_index) { return Status::Invalid("Node '", label(), "' outputs to node '", out->label(), "' but is not listed as an input."); } + } + + return Status::OK(); +} + +struct GeneratorNode : ExecNode { + GeneratorNode(ExecPlan* plan, std::string label, ExecNode::BatchDescr output_descr, + AsyncGenerator> generator) + : ExecNode(plan, std::move(label), {}, {}, std::move(output_descr), + /*num_outputs=*/1), + generator_(std::move(generator)) {} + + const char* kind_name() override { return "GeneratorNode"; } + + void InputReceived(ExecNode*, int, compute::ExecBatch) override { DCHECK(false); } + void ErrorReceived(ExecNode*, Status) override { DCHECK(false); } + void InputFinished(ExecNode*, int) override { 
DCHECK(false); } - const auto& in_descr = out->input_descrs_[*input_index]; - if (in_descr != output_descr_) { - return Status::Invalid( - "Node '", label(), "' (bound to input ", input_labels_[*input_index], - ") produces batches with type '", ValueDescr::ToString(output_descr_), - "' inconsistent with consumer '", out->label(), "' which accepts '", - ValueDescr::ToString(in_descr), "'"); + Status StartProducing() override { + if (!generator_) { + return Status::Invalid("Restarted GeneratorNode '", label(), "'"); } + GenerateOne(std::unique_lock{mutex_}); + return Status::OK(); } - return Status::OK(); + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + std::unique_lock lock(mutex_); + generator_ = nullptr; // null function + } + + void StopProducing() override { StopProducing(outputs_[0]); } + + private: + void GenerateOne(std::unique_lock&& lock) { + if (!generator_) { + // Stopped + return; + } + + auto fut = generator_(); + const auto batch_index = next_batch_index_++; + + lock.unlock(); + fut.AddCallback([batch_index, this](const Result>& res) { + std::unique_lock lock(mutex_); + if (!res.ok()) { + for (auto out : outputs_) { + out->ErrorReceived(this, res.status()); + } + return; + } + + const auto& batch = *res; + if (IsIterationEnd(batch)) { + lock.unlock(); + for (auto out : outputs_) { + out->InputFinished(this, batch_index); + } + return; + } + + lock.unlock(); + for (auto out : outputs_) { + out->InputReceived(this, batch_index, compute::ExecBatch(*batch)); + } + lock.lock(); + + GenerateOne(std::move(lock)); + }); + } + + std::mutex mutex_; + AsyncGenerator> generator_; + int next_batch_index_ = 0; +}; + +ExecNode* MakeSourceNode(ExecPlan* plan, std::string label, + ExecNode::BatchDescr output_descr, + AsyncGenerator> generator) { + return plan->EmplaceNode(plan, std::move(label), std::move(output_descr), + std::move(generator)); } +struct FilterNode : ExecNode { + FilterNode(ExecNode* input, std::string label, Expression filter) + : ExecNode(input->plan(), std::move(label), {input}, {"target"}, + /*output_descr=*/{input->output_descr()}, + /*num_outputs=*/1), + filter_(std::move(filter)) {} + + const char* kind_name() override { return "FilterNode"; } + + Result DoFilter(const ExecBatch& target) { + // XXX get a non-default exec context + ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(filter_, target)); + + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + return target; + } + + return target.Slice(0, 0); + } + + auto values = target.values; + for (auto& value : values) { + if (value.is_scalar()) continue; + ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults())); + } + return ExecBatch::Make(std::move(values)); + } + + void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + + auto maybe_filtered = DoFilter(std::move(batch)); + if (!maybe_filtered.ok()) { + outputs_[0]->ErrorReceived(this, maybe_filtered.status()); + inputs_[0]->StopProducing(this); + return; + } + + outputs_[0]->InputReceived(this, seq, maybe_filtered.MoveValueUnsafe()); + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->ErrorReceived(this, std::move(error)); + inputs_[0]->StopProducing(this); + } + + void InputFinished(ExecNode* input, int seq) 
override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->InputFinished(this, seq); + inputs_[0]->StopProducing(this); + } + + Status StartProducing() override { return Status::OK(); } + + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + inputs_[0]->StopProducing(this); + } + + void StopProducing() override { StopProducing(outputs_[0]); } + + private: + Expression filter_; +}; + +ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter) { + return input->plan()->EmplaceNode(input, std::move(label), + std::move(filter)); +} + +struct SinkNode : ExecNode { + SinkNode(ExecNode* input, std::string label, + AsyncGenerator>* generator) + : ExecNode(input->plan(), std::move(label), {input}, {"collected"}, {}, + /*num_outputs=*/0), + producer_(MakeProducer(generator)) {} + + static PushGenerator>::Producer MakeProducer( + AsyncGenerator>* out_gen) { + PushGenerator> gen; + auto out = gen.producer(); + *out_gen = std::move(gen); + return out; + } + + const char* kind_name() override { return "SinkNode"; } + + Status StartProducing() override { return Status::OK(); } + + // sink nodes have no outputs from which to feel backpressure + static void NoOutputs() { DCHECK(false) << "no outputs; this should never be called"; } + void ResumeProducing(ExecNode* output) override { NoOutputs(); } + void PauseProducing(ExecNode* output) override { NoOutputs(); } + void StopProducing(ExecNode* output) override { NoOutputs(); } + + void StopProducing() override { + std::unique_lock lock(mutex_); + StopProducingUnlocked(); + } + + void InputReceived(ExecNode* input, int seq_num, ExecBatch exec_batch) override { + std::unique_lock lock(mutex_); + if (stopped_) return; + + // TODO would be nice to factor this out in a ReorderQueue + if (seq_num <= static_cast(received_batches_.size())) { + received_batches_.resize(seq_num + 1); + emitted_.resize(seq_num + 1, false); + } + received_batches_[seq_num] = std::move(exec_batch); + ++num_received_; + + if (seq_num != num_emitted_) { + // Cannot emit yet as there is a hole at `num_emitted_` + DCHECK_GT(seq_num, num_emitted_); + return; + } + + if (num_received_ == emit_stop_) { + StopProducingUnlocked(); + } + + // Emit batches in order as far as possible + // First collect these batches, then unlock before producing. + const auto seq_start = seq_num; + while (seq_num < static_cast(emitted_.size()) && !emitted_[seq_num]) { + emitted_[seq_num] = true; + ++seq_num; + } + DCHECK_GT(seq_num, seq_start); + // By moving the values now, we make sure another thread won't emit the same values + // below + std::vector to_emit( + std::make_move_iterator(received_batches_.begin() + seq_start), + std::make_move_iterator(received_batches_.begin() + seq_num)); + + lock.unlock(); + for (auto&& batch : to_emit) { + producer_.Push(std::move(batch)); + } + lock.lock(); + + DCHECK_EQ(seq_start, num_emitted_); // num_emitted_ wasn't bumped in the meantime + num_emitted_ = seq_num; + } + + void ErrorReceived(ExecNode* input, Status error) override { + // XXX do we care about properly sequencing the error? 
+ producer_.Push(std::move(error)); + std::unique_lock lock(mutex_); + StopProducingUnlocked(); + } + + void InputFinished(ExecNode* input, int seq_stop) override { + std::unique_lock lock(mutex_); + DCHECK_GE(seq_stop, static_cast(received_batches_.size())); + received_batches_.reserve(seq_stop); + emit_stop_ = seq_stop; + if (emit_stop_ == num_received_) { + DCHECK_EQ(emit_stop_, num_emitted_); + StopProducingUnlocked(); + } + } + + private: + void StopProducingUnlocked() { + if (!stopped_) { + stopped_ = true; + producer_.Close(); + inputs_[0]->StopProducing(this); + } + } + + std::mutex mutex_; + std::vector received_batches_; + std::vector emitted_; + + int num_received_ = 0; + int num_emitted_ = 0; + int emit_stop_ = -1; + bool stopped_ = false; + + PushGenerator>::Producer producer_; +}; + +AsyncGenerator> MakeSinkNode(ExecNode* input, + std::string label); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 0d2faea0ddc..eabc34d6d04 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -23,6 +23,7 @@ #include "arrow/compute/type_fwd.h" #include "arrow/type_fwd.h" +#include "arrow/util/async_generator.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -48,8 +49,11 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { ExecNode* AddNode(std::unique_ptr node); template - ExecNode* EmplaceNode(Args&&... args) { - return AddNode(std::unique_ptr(new Node{std::forward(args)...})); + Node* EmplaceNode(Args&&... args) { + auto node = std::unique_ptr(new Node{std::forward(args)...}); + auto out = node.get(); + AddNode(std::move(node)); + return out; } /// The initial inputs @@ -58,15 +62,6 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { /// The final outputs const NodeVector& sinks() const; - // XXX API question: - // There are clearly two phases in the ExecPlan lifecycle: - // - one construction phase where AddNode() and ExecNode::AddInput() is called - // (with optional validation at the end) - // - one execution phase where the nodes are topo-sorted and then started - // - // => Should we separate out those APIs? e.g. have a ExecPlanBuilder - // for the first phase. - Status Validate(); /// Start producing on all nodes @@ -91,15 +86,12 @@ class ARROW_EXPORT ExecNode { virtual const char* kind_name() = 0; // The number of inputs/outputs expected by this node - int num_inputs() const { return static_cast(input_descrs_.size()); } + int num_inputs() const { return static_cast(inputs_.size()); } int num_outputs() const { return num_outputs_; } /// This node's predecessors in the exec plan const NodeVector& inputs() const { return inputs_; } - /// The datatypes accepted by this node for each input - const std::vector& input_descrs() const { return input_descrs_; } - /// \brief Labels identifying the function of each input. /// /// For example, FilterNode accepts "target" and "filter" inputs. @@ -119,11 +111,6 @@ class ARROW_EXPORT ExecNode { /// There is no guarantee that this value is non-empty or unique. 
const std::string& label() const { return label_; } - void AddInput(ExecNode* input) { - inputs_.push_back(input); - input->outputs_.push_back(this); - } - Status Validate() const; /// Upstream API: @@ -139,7 +126,7 @@ class ARROW_EXPORT ExecNode { /// and StopProducing() /// Transfer input batch to ExecNode - virtual void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) = 0; + virtual void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) = 0; /// Signal error to ExecNode virtual void ErrorReceived(ExecNode* input, Status error) = 0; @@ -225,22 +212,37 @@ class ARROW_EXPORT ExecNode { virtual void StopProducing() = 0; protected: - ExecNode(ExecPlan* plan, std::string label, std::vector input_descrs, + ExecNode(ExecPlan* plan, std::string label, NodeVector inputs, std::vector input_labels, BatchDescr output_descr, int num_outputs); ExecPlan* plan_; - std::string label_; - std::vector input_descrs_; - std::vector input_labels_; NodeVector inputs_; + std::vector input_labels_; BatchDescr output_descr_; int num_outputs_; NodeVector outputs_; }; +/// \brief Adapt an AsyncGenerator as a source node +ARROW_EXPORT +ExecNode* MakeSourceNode(ExecPlan* plan, std::string label, + ExecNode::BatchDescr output_descr, + AsyncGenerator>); + +/// \brief Add a sink node which forwards to an AsyncGenerator +ARROW_EXPORT +AsyncGenerator> MakeSinkNode(ExecNode* input, + std::string label); + +/// \brief Make a node which excludes some rows from batches passed through it +/// +/// filter Expression must be bound; no field references will be looked up by name +ARROW_EXPORT +ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index aeabbf7bc5b..19a311a5d39 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -53,7 +53,7 @@ Expression::Expression(Parameter parameter) Expression literal(Datum lit) { return Expression(std::move(lit)); } Expression field_ref(FieldRef ref) { - return Expression(Expression::Parameter{std::move(ref), {}}); + return Expression(Expression::Parameter{std::move(ref), {}, {}}); } Expression call(std::string function, std::vector arguments, @@ -62,13 +62,22 @@ Expression call(std::string function, std::vector arguments, call.function_name = std::move(function); call.arguments = std::move(arguments); call.options = std::move(options); + + call.hash = std::hash{}(call.function_name); + for (const auto& arg : call.arguments) { + call.hash ^= arg.hash(); + } return Expression(std::move(call)); } const Datum* Expression::literal() const { return util::get_if(impl_.get()); } +const Expression::Parameter* Expression::parameter() const { + return util::get_if(impl_.get()); +} + const FieldRef* Expression::field_ref() const { - if (auto parameter = util::get_if(impl_.get())) { + if (auto parameter = this->parameter()) { return ¶meter->ref; } return nullptr; @@ -85,7 +94,7 @@ ValueDescr Expression::descr() const { return lit->descr(); } - if (auto parameter = util::get_if(impl_.get())) { + if (auto parameter = this->parameter()) { return parameter->descr; } @@ -235,21 +244,7 @@ size_t Expression::hash() const { return ref->hash(); } - auto call = CallNotNull(*this); - if (call->hash != nullptr) { - return call->hash->load(); - } - - size_t out = std::hash{}(call->function_name); - for (const auto& arg : call->arguments) { - out ^= arg.hash(); - } - - 
std::shared_ptr> expected = nullptr; - ::arrow::internal::atomic_compare_exchange_strong( - &const_cast(call)->hash, &expected, - std::make_shared>(out)); - return out; + return CallNotNull(*this)->hash; } bool Expression::IsBound() const { @@ -383,45 +378,6 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ return Expression(std::move(call)); } -struct FieldPathGetDatumImpl { - template ()))> - Result operator()(const std::shared_ptr& ptr) { - return path_.Get(*ptr).template As(); - } - - template - Result operator()(const T&) { - return Status::NotImplemented("FieldPath::Get() into Datum ", datum_.ToString()); - } - - const Datum& datum_; - const FieldPath& path_; -}; - -inline Result GetDatumField(const FieldRef& ref, const Datum& input) { - Datum field; - - FieldPath match; - if (auto type = input.type()) { - ARROW_ASSIGN_OR_RAISE(match, ref.FindOneOrNone(*type)); - } else if (auto schema = input.schema()) { - ARROW_ASSIGN_OR_RAISE(match, ref.FindOneOrNone(*schema)); - } else { - return Status::NotImplemented("retrieving fields from datum ", input.ToString()); - } - - if (!match.empty()) { - ARROW_ASSIGN_OR_RAISE(field, - util::visit(FieldPathGetDatumImpl{input, match}, input.value)); - } - - if (field == Datum{}) { - return Datum(std::make_shared()); - } - - return field; -} - } // namespace Result Expression::Bind(ValueDescr in, @@ -434,9 +390,18 @@ Result Expression::Bind(ValueDescr in, if (literal()) return *this; if (auto ref = field_ref()) { - ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); - auto descr = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return Expression{Parameter{*ref, std::move(descr)}}; + if (ref->IsNested()) { + return Status::NotImplemented("nested field references"); + } + + ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(*in.type)); + + auto bound = *parameter(); + bound.index = path[0]; + ARROW_ASSIGN_OR_RAISE(auto field, path.Get(*in.type)); + bound.descr.type = field->type(); + bound.descr.shape = in.shape; + return Expression{std::move(bound)}; } auto call = *CallNotNull(*this); @@ -452,7 +417,67 @@ Result Expression::Bind(const Schema& in_schema, return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } -Result ExecuteScalarExpression(const Expression& expr, const Datum& input, +Result MakeExecBatch(const Schema& full_schema, const Datum& partial) { + ExecBatch out; + + if (partial.kind() == Datum::RECORD_BATCH) { + const auto& partial_batch = *partial.record_batch(); + out.length = partial_batch.num_rows(); + + for (const auto& field : full_schema.fields()) { + ARROW_ASSIGN_OR_RAISE(auto column, + FieldRef(field->name()).GetOneOrNone(partial_batch)); + + if (column) { + if (!column->type()->Equals(field->type())) { + // Referenced field was present but didn't have the expected type. + // This *should* be handled by readers, and will just be an error in the future. 
+ ARROW_ASSIGN_OR_RAISE( + auto converted, + compute::Cast(column, field->type(), compute::CastOptions::Safe())); + column = converted.make_array(); + } + out.values.emplace_back(std::move(column)); + } else { + out.values.emplace_back(MakeNullScalar(field->type())); + } + } + return out; + } + + // wasteful but useful for testing: + if (partial.type()->id() == Type::STRUCT) { + if (partial.is_array()) { + ARROW_ASSIGN_OR_RAISE(auto partial_batch, + RecordBatch::FromStructArray(partial.make_array())); + + return MakeExecBatch(full_schema, partial_batch); + } + + if (partial.is_scalar()) { + ARROW_ASSIGN_OR_RAISE(auto partial_array, + MakeArrayFromScalar(*partial.scalar(), 1)); + ARROW_ASSIGN_OR_RAISE(auto out, MakeExecBatch(full_schema, partial_array)); + + for (Datum& value : out.values) { + if (value.is_scalar()) continue; + ARROW_ASSIGN_OR_RAISE(value, value.make_array()->GetScalar(0)); + } + return out; + } + } + + return Status::NotImplemented("MakeExecBatch from ", PrintDatum(partial)); +} + +Result ExecuteScalarExpression(const Expression& expr, const Schema& full_schema, + const Datum& partial_input, + compute::ExecContext* exec_context) { + ARROW_ASSIGN_OR_RAISE(auto input, MakeExecBatch(full_schema, partial_input)); + return ExecuteScalarExpression(expr, input, exec_context); +} + +Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& input, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; @@ -470,15 +495,16 @@ Result ExecuteScalarExpression(const Expression& expr, const Datum& input if (auto lit = expr.literal()) return *lit; - if (auto ref = expr.field_ref()) { - ARROW_ASSIGN_OR_RAISE(Datum field, GetDatumField(*ref, input)); + if (auto param = expr.parameter()) { + if (param->descr.type->id() == Type::NA) { + return MakeNullScalar(null()); + } - if (field.descr() != expr.descr()) { - // Refernced field was present but didn't have the expected type. - // Should we just error here? For now, pay dispatch cost and just cast. 
- ARROW_ASSIGN_OR_RAISE( - field, - compute::Cast(field, expr.type(), compute::CastOptions::Safe(), exec_context)); + const Datum& field = input[param->index]; + if (!field.type()->Equals(param->descr.type)) { + return Status::Invalid("Referenced field ", expr.ToString(), " was ", + field.type()->ToString(), " but should have been ", + param->descr.type->ToString()); } return field; @@ -555,6 +581,22 @@ std::vector FieldsInExpression(const Expression& expr) { return fields; } +std::vector ParametersInExpression(const Expression& expr) { + if (expr.literal()) return {}; + + if (auto parameter = expr.parameter()) { + return {parameter->index}; + } + + std::vector indices; + for (const Expression& arg : CallNotNull(expr)->arguments) { + auto argument_indices = ParametersInExpression(arg); + std::move(argument_indices.begin(), argument_indices.end(), + std::back_inserter(indices)); + } + return indices; +} + bool ExpressionHasFieldRefs(const Expression& expr) { if (expr.literal()) return false; @@ -574,7 +616,7 @@ Result FoldConstants(Expression expr) { if (std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression& argument) { return argument.literal(); })) { // all arguments are literal; we can evaluate this subexpression *now* - static const Datum ignored_input = Datum{}; + static const ExecBatch ignored_input = ExecBatch{}; ARROW_ASSIGN_OR_RAISE(Datum constant, ExecuteScalarExpression(expr, ignored_input)); @@ -1144,13 +1186,5 @@ Expression or_(const std::vector& operands) { Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } -Expression operator&&(Expression lhs, Expression rhs) { - return and_(std::move(lhs), std::move(rhs)); -} - -Expression operator||(Expression lhs, Expression rhs) { - return or_(std::move(lhs), std::move(rhs)); -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index f5ca2c2118d..eb1dcfd3091 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -44,13 +44,13 @@ class ARROW_EXPORT Expression { struct Call { std::string function_name; std::vector arguments; - std::shared_ptr options; - std::shared_ptr> hash; + std::shared_ptr options; + size_t hash; // post-Bind properties: - std::shared_ptr function; - const compute::Kernel* kernel = NULLPTR; - std::shared_ptr kernel_state; + std::shared_ptr function; + const Kernel* kernel = NULLPTR; + std::shared_ptr kernel_state; ValueDescr descr; }; @@ -64,8 +64,11 @@ class ARROW_EXPORT Expression { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, compute::ExecContext* = NULLPTR) const; + Result Bind(ValueDescr in, ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, ExecContext* = NULLPTR) const; + + Result BindFlattened(ValueDescr in, ExecContext* = NULLPTR) const; + Result BindFlattened(const Schema& in_schema, ExecContext* = NULLPTR) const; // XXX someday // Clone all KernelState in this bound expression. 
If any function referenced by this @@ -108,8 +111,12 @@ class ARROW_EXPORT Expression { struct Parameter { FieldRef ref; + + // post-bind properties ValueDescr descr; + int index; }; + const Parameter* parameter() const; Expression() = default; explicit Expression(Call call); @@ -143,10 +150,10 @@ Expression field_ref(FieldRef ref); ARROW_EXPORT Expression call(std::string function, std::vector arguments, - std::shared_ptr options = NULLPTR); + std::shared_ptr options = NULLPTR); -template ::value>::type> +template ::value>::type> Expression call(std::string function, std::vector arguments, Options options) { return call(std::move(function), std::move(arguments), @@ -157,6 +164,10 @@ Expression call(std::string function, std::vector arguments, ARROW_EXPORT std::vector FieldsInExpression(const Expression&); +/// Assemble parameter indices referenced by an Expression at any depth. +ARROW_EXPORT +std::vector ParametersInExpression(const Expression&); + /// Check if the expression references any fields. ARROW_EXPORT bool ExpressionHasFieldRefs(const Expression&); @@ -182,7 +193,7 @@ Result> ExtractKnownFieldVal /// equivalent Expressions may result in different canonicalized expressions. /// TODO this could be a strong canonicalization ARROW_EXPORT -Result Canonicalize(Expression, compute::ExecContext* = NULLPTR); +Result Canonicalize(Expression, ExecContext* = NULLPTR); /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all @@ -207,11 +218,22 @@ Result SimplifyWithGuarantee(Expression, // Execution -/// Execute a scalar expression against the provided state and input Datum. This +/// Ensure that a RecordBatch (which may have missing or incorrectly ordered columns) +/// precisely matches the schema. This is necessary when executing Expressions +/// since we look up fields by index. Missing fields will be replaced with null scalars. +ARROW_EXPORT Result MakeExecBatch(const Schema& full_schema, + const Datum& partial); + +/// Execute a scalar expression against the provided state and input ExecBatch. This /// expression must be bound. 
ARROW_EXPORT -Result ExecuteScalarExpression(const Expression&, const Datum& input, - compute::ExecContext* = NULLPTR); +Result ExecuteScalarExpression(const Expression&, const ExecBatch& input, + ExecContext* = NULLPTR); + +/// Convenience function for invoking against a RecordBatch +ARROW_EXPORT +Result ExecuteScalarExpression(const Expression&, const Schema& full_schema, + const Datum& partial_input, ExecContext* = NULLPTR); // Serialization diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 908e8962e43..c4c2dd1c951 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -166,6 +166,56 @@ TEST(ExpressionUtils, StripOrderPreservingCasts) { Expect(cast(field_ref("i32"), uint64()), no_change); } +TEST(ExpressionUtils, MakeExecBatch) { + auto Expect = [](std::shared_ptr partial_batch) { + SCOPED_TRACE(partial_batch->ToString()); + ASSERT_OK_AND_ASSIGN(auto batch, MakeExecBatch(*kBoringSchema, partial_batch)); + + ASSERT_EQ(batch.num_values(), kBoringSchema->num_fields()); + for (int i = 0; i < kBoringSchema->num_fields(); ++i) { + const auto& field = *kBoringSchema->field(i); + + SCOPED_TRACE("Field#" + std::to_string(i) + " " + field.ToString()); + + EXPECT_TRUE(batch[i].type()->Equals(field.type())) + << "Incorrect type " << batch[i].type()->ToString(); + + ASSERT_OK_AND_ASSIGN(auto col, FieldRef(field.name()).GetOneOrNone(*partial_batch)); + + if (batch[i].is_scalar()) { + EXPECT_FALSE(batch[i].scalar()->is_valid) + << "Non-null placeholder scalar was injected"; + + EXPECT_EQ(col, nullptr) + << "Placeholder scalar overwrote column " << col->ToString(); + } else { + AssertDatumsEqual(col, batch[i]); + } + } + }; + + auto GetField = [](std::string name) { return kBoringSchema->GetFieldByName(name); }; + + constexpr int64_t kNumRows = 3; + auto i32 = ArrayFromJSON(int32(), "[1, 2, 3]"); + auto f32 = ArrayFromJSON(float32(), "[1.5, 2.25, 3.125]"); + + // empty + Expect(RecordBatchFromJSON(kBoringSchema, "[]")); + + // subset + Expect(RecordBatch::Make(schema({GetField("i32"), GetField("f32")}), kNumRows, + {i32, f32})); + + // flipped subset + Expect(RecordBatch::Make(schema({GetField("f32"), GetField("i32")}), kNumRows, + {f32, i32})); + + auto duplicated_names = + RecordBatch::Make(schema({GetField("i32"), GetField("i32")}), kNumRows, {i32, i32}); + ASSERT_RAISES(Invalid, MakeExecBatch(*kBoringSchema, duplicated_names)); +} + TEST(Expression, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "alpha"); @@ -445,21 +495,18 @@ TEST(Expression, BindFieldRef) { ExpectBindsTo(field_ref("i32"), no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - // if the field is not found, a null scalar will be emitted - ExpectBindsTo(field_ref("no such field"), no_change, &expr); - EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); + // if the field is not found, an error will be raised + ASSERT_RAISES(Invalid, field_ref("no such field").Bind(*kBoringSchema)); // referencing a field by name is not supported if that name is not unique // in the input schema ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema( {field("alpha", int32()), field("alpha", float32())}))); - // referencing nested fields is supported - ASSERT_OK_AND_ASSIGN(expr, - field_ref(FieldRef("a", "b")) - .Bind(Schema({field("a", struct_({field("b", int32())}))}))); - EXPECT_TRUE(expr.IsBound()); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + // referencing nested fields is not supported + 
ASSERT_RAISES(NotImplemented, + field_ref(FieldRef("a", "b")) + .Bind(Schema({field("a", struct_({field("b", int32())}))}))); } TEST(Expression, BindCall) { @@ -525,7 +572,8 @@ TEST(Expression, ExecuteFieldRef) { auto expr = field_ref(ref); ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); - ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); + ASSERT_OK_AND_ASSIGN(Datum actual, + ExecuteScalarExpression(expr, Schema(in.type()->fields()), in)); AssertDatumsEqual(actual, expected, /*verbose=*/true); }; @@ -537,39 +585,45 @@ TEST(Expression, ExecuteFieldRef) { ])"), ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); - // more nested: - ExpectRefIs(FieldRef{"a", "a"}, - ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ - {"a": {"a": 6.125}}, - {"a": {"a": 0.0}}, - {"a": {"a": -1}} + ExpectRefIs("a", + ArrayFromJSON(struct_({ + field("a", float64()), + field("b", float64()), + }), + R"([ + {"a": 6.125, "b": 7.5}, + {"a": 0.0, "b": 2.125}, + {"a": -1, "b": 4.0} ])"), ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); - // absent fields are resolved as a null scalar: - ExpectRefIs(FieldRef{"b"}, ArrayFromJSON(struct_({field("a", float64())}), R"([ - {"a": 6.125}, - {"a": 0.0}, - {"a": -1} + ExpectRefIs("b", + ArrayFromJSON(struct_({ + field("a", float64()), + field("b", float64()), + }), + R"([ + {"a": 6.125, "b": 7.5}, + {"a": 0.0, "b": 2.125}, + {"a": -1, "b": 4.0} ])"), - MakeNullScalar(null())); - - // XXX this *should* fail in Bind but for now it will just error in - // ExecuteScalarExpression - ASSERT_OK_AND_ASSIGN(auto list_item, field_ref("item").Bind(list(int32()))); - EXPECT_RAISES_WITH_MESSAGE_THAT( - NotImplemented, HasSubstr("non-struct array"), - ExecuteScalarExpression(list_item, - ArrayFromJSON(list(int32()), "[[1,2], [], null, [5]]"))); + ArrayFromJSON(float64(), R"([7.5, 2.125, 4.0])")); } Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) { - auto call = expr.call(); - if (call == nullptr) { - // already tested execution of field_ref, execution of literal is trivial - return ExecuteScalarExpression(expr, input); + if (auto lit = expr.literal()) { + return *lit; } + if (auto ref = expr.field_ref()) { + if (input.type()) { + return ref->GetOneOrNone(*input.make_array()); + } + return ref->GetOneOrNone(*input.record_batch()); + } + + auto call = CallNotNull(expr); + std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(arguments[i], @@ -587,13 +641,16 @@ Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& } void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) { + std::shared_ptr schm; if (in.is_value()) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + schm = schema(in.type()->fields()); } else { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.schema())); + schm = in.schema(); } - ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); + ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, *schm, in)); ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in)); @@ -653,9 +710,9 @@ TEST(Expression, ExecuteDictionaryTransparent) { ASSERT_OK_AND_ASSIGN( expr, SimplifyWithGuarantee(expr, equal(field_ref("dict_str"), literal("eh")))); - ASSERT_OK_AND_ASSIGN( - auto res, - ExecuteScalarExpression(expr, ArrayFromJSON(struct_({field("i32", int32())}), R"([ + ASSERT_OK_AND_ASSIGN(auto res, ExecuteScalarExpression( + expr, *kBoringSchema, + 
ArrayFromJSON(struct_({field("i32", int32())}), R"([ {"i32": 0}, {"i32": 1}, {"i32": 2} @@ -841,7 +898,7 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { // NB: known_values will be cast ExpectReplacesTo(field_ref("i32"), {{"i32", Datum("3")}}, literal(3)); - ExpectReplacesTo(field_ref("b"), i32_is_3, field_ref("b")); + ExpectReplacesTo(field_ref("f32"), i32_is_3, field_ref("f32")); ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3, equal(literal(3), literal(1))); @@ -1082,7 +1139,8 @@ TEST(Expression, SingleComparisonGuarantees) { {"i32"})); ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(filter, input)); + ASSERT_OK_AND_ASSIGN(Datum evaluated, + ExecuteScalarExpression(filter, *kBoringSchema, input)); // ensure that the simplified filter is as simplified as it could be // (this is always possible for single comparisons) @@ -1193,7 +1251,8 @@ TEST(Expression, Filter) { auto expected_mask = batch->column(0); ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, batch)); + ASSERT_OK_AND_ASSIGN(Datum mask, + ExecuteScalarExpression(filter, *kBoringSchema, batch)); AssertDatumsEqual(expected_mask, mask); }; @@ -1286,7 +1345,8 @@ TEST(Projection, AugmentWithNull) { auto ExpectProject = [&](Expression proj, Datum expected) { ASSERT_OK_AND_ASSIGN(proj, proj.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto actual, ExecuteScalarExpression(proj, input)); + ASSERT_OK_AND_ASSIGN(auto actual, + ExecuteScalarExpression(proj, *kBoringSchema, input)); AssertDatumsEqual(Datum(expected), actual); }; @@ -1316,7 +1376,8 @@ TEST(Projection, AugmentWithKnownValues) { Expression guarantee) { ASSERT_OK_AND_ASSIGN(proj, proj.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(proj, SimplifyWithGuarantee(proj, guarantee)); - ASSERT_OK_AND_ASSIGN(auto actual, ExecuteScalarExpression(proj, input)); + ASSERT_OK_AND_ASSIGN(auto actual, + ExecuteScalarExpression(proj, *kBoringSchema, input)); AssertDatumsEqual(Datum(expected), actual); }; diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 86f1879cbe9..5f0d5441ed5 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -21,6 +21,7 @@ #include #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/test_util.h" #include "arrow/record_batch.h" #include "arrow/testing/future_util.h" @@ -51,30 +52,22 @@ TEST(ExecPlanConstruction, Empty) { TEST(ExecPlanConstruction, SingleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/0); + auto node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/0); ASSERT_OK(plan->Validate()); ASSERT_THAT(plan->sources(), ::testing::ElementsAre(node)); ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(node)); ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/0); - // Input not bound - ASSERT_RAISES(Invalid, plan->Validate()); - - ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/1); + node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/1); // Output not bound ASSERT_RAISES(Invalid, plan->Validate()); } TEST(ExecPlanConstruction, SourceSink) { ASSERT_OK_AND_ASSIGN(auto 
plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); - // Input / output not bound - ASSERT_RAISES(Invalid, plan->Validate()); + auto source = MakeDummyNode(plan.get(), "source", /*inputs=*/{}, /*num_outputs=*/1); + auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{source}, /*num_outputs=*/0); - sink->AddInput(source); ASSERT_OK(plan->Validate()); EXPECT_THAT(plan->sources(), ::testing::ElementsAre(source)); EXPECT_THAT(plan->sinks(), ::testing::ElementsAre(sink)); @@ -83,33 +76,21 @@ TEST(ExecPlanConstruction, SourceSink) { TEST(ExecPlanConstruction, MultipleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = - MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2); + auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2); - auto source2 = - MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1); + auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1); auto process1 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2); + MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2); - auto process2 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/1); + auto process2 = MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1, source2}, + /*num_outputs=*/1); auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1); - - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); + MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, process2, process1}, + /*num_outputs=*/1); - sink->AddInput(process3); - - process3->AddInput(process1); - process3->AddInput(process2); - process3->AddInput(process1); - - process2->AddInput(source1); - process2->AddInput(source2); - - process1->AddInput(source1); + auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0); ASSERT_OK(plan->Validate()); ASSERT_THAT(plan->sources(), ::testing::ElementsAre(source1, source2)); @@ -135,30 +116,27 @@ TEST(ExecPlan, DummyStartProducing) { StartStopTracker t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, + + auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2, t.start_producing_func(), t.stop_producing_func()); - auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, + + auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + auto process1 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, + MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2, t.start_producing_func(), t.stop_producing_func()); + auto process2 = - MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); - auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); + MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2}, + /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); - auto sink = MakeDummyNode(plan.get(), "sink", 
/*num_inputs=*/1, /*num_outputs=*/0, - t.start_producing_func(), t.stop_producing_func()); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, source1, process2}, + /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); - process1->AddInput(source1); - process2->AddInput(process1); - process2->AddInput(source2); - process3->AddInput(process1); - process3->AddInput(source1); - process3->AddInput(process2); - sink->AddInput(process3); + MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0, + t.start_producing_func(), t.stop_producing_func()); ASSERT_OK(plan->Validate()); ASSERT_EQ(t.started.size(), 0); @@ -171,63 +149,32 @@ TEST(ExecPlan, DummyStartProducing) { ASSERT_EQ(t.stopped.size(), 0); } -TEST(ExecPlan, DummyStartProducingCycle) { - // A trivial cycle - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/1); - node->AddInput(node); - ASSERT_OK(plan->Validate()); - ASSERT_RAISES(Invalid, plan->StartProducing()); - - // A less trivial one - ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); - auto process1 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/2); - auto process2 = - MakeDummyNode(plan.get(), "process2", /*num_inputs=*/1, /*num_outputs=*/1); - auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/2, /*num_outputs=*/2); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); - - process1->AddInput(source); - process2->AddInput(process1); - process3->AddInput(process2); - process3->AddInput(process1); - process1->AddInput(process3); - sink->AddInput(process3); - - ASSERT_OK(plan->Validate()); - ASSERT_RAISES(Invalid, plan->StartProducing()); -} - TEST(ExecPlan, DummyStartProducingError) { StartStopTracker t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, - t.start_producing_func(Status::NotImplemented("zzz")), - t.stop_producing_func()); - auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); + auto source1 = MakeDummyNode( + plan.get(), "source1", /*num_inputs=*/{}, /*num_outputs=*/2, + t.start_producing_func(Status::NotImplemented("zzz")), t.stop_producing_func()); + + auto source2 = + MakeDummyNode(plan.get(), "source2", /*num_inputs=*/{}, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + auto process1 = MakeDummyNode( - plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, + plan.get(), "process1", /*num_inputs=*/{source1}, /*num_outputs=*/2, t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func()); + auto process2 = - MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); - process1->AddInput(source1); - process2->AddInput(process1); - process2->AddInput(source2); + MakeDummyNode(plan.get(), "process2", /*num_inputs=*/{process1, source2}, + /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); - process3->AddInput(process1); - process3->AddInput(source1); - process3->AddInput(process2); - auto sink = 
MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0, - t.start_producing_func(), t.stop_producing_func()); - sink->AddInput(process3); + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/{process1, source1, process2}, + /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); + + MakeDummyNode(plan.get(), "sink", /*num_inputs=*/{process3}, /*num_outputs=*/0, + t.start_producing_func(), t.stop_producing_func()); ASSERT_OK(plan->Validate()); ASSERT_EQ(t.started.size(), 0); @@ -268,15 +215,12 @@ class SlowRecordBatchReader : public RecordBatchReader { static Result MakeSlowRecordBatchGenerator( RecordBatchVector batches, std::shared_ptr schema) { - auto gen = MakeVectorGenerator(batches); // TODO move this into testing/async_generator_util.h? - auto delayed_gen = MakeMappedGenerator>( - std::move(gen), [](const std::shared_ptr& batch) { - auto fut = Future>::Make(); - SleepABitAsync().AddCallback( - [fut, batch](const Status& status) mutable { fut.MarkFinished(batch); }); - return fut; - }); + auto delayed_gen = + MakeMappedGenerator(MakeVectorGenerator(std::move(batches)), + [](const std::shared_ptr& batch) { + return SleepABitAsync().Then([=] { return batch; }); + }); // Adding readahead implicitly adds parallelism by pulling reentrantly from // the delayed generator return MakeReadaheadGenerator(std::move(delayed_gen), /*max_readahead=*/64); @@ -305,36 +249,33 @@ class TestExecPlanExecution : public ::testing::Test { }; Result MakeSourceSink(std::shared_ptr reader, - const std::shared_ptr& schema) { + std::shared_ptr schema) { ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); auto source = MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); - auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); - sink->AddInput(source); + auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", source, std::move(schema)); return CollectorPlan{plan, sink}; } Result MakeSourceSink(RecordBatchGenerator generator, - const std::shared_ptr& schema) { + std::shared_ptr schema) { ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); auto source = MakeRecordBatchReaderNode(plan.get(), "source", schema, generator, io_executor_.get()); - auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); - sink->AddInput(source); + auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", source, std::move(schema)); return CollectorPlan{plan, sink}; } Result MakeSourceSink(const RecordBatchVector& batches, - const std::shared_ptr& schema) { + std::shared_ptr schema) { ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make(batches, schema)); - return MakeSourceSink(std::move(reader), schema); + return MakeSourceSink(std::move(reader), std::move(schema)); } Result StartAndCollect(ExecPlan* plan, RecordBatchCollectNode* sink) { RETURN_NOT_OK(plan->StartProducing()); - auto fut = CollectAsyncGenerator(sink->generator()); - return fut.result(); + return CollectAsyncGenerator(sink->generator()).result(); } template @@ -373,6 +314,8 @@ class TestExecPlanExecution : public ::testing::Test { std::shared_ptr io_executor_; }; +// FIXME Test "collecting" an error + TEST_F(TestExecPlanExecution, SourceSink) { TestSourceSink(RecordBatchReader::Make); } TEST_F(TestExecPlanExecution, SlowSourceSink) { @@ -396,5 +339,37 @@ TEST_F(TestExecPlanExecution, StressSlowSourceSinkParallel) { TestStressSourceSink(/*num_batches=*/300, MakeSlowRecordBatchGenerator); } +TEST_F(TestExecPlanExecution, SourceFilterSink) { + ASSERT_OK_AND_ASSIGN(auto plan, 
ExecPlan::Make()); + + const auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); + RecordBatchVector batches{ + RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, + {"a": 4, "b": false}])"), + RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, + {"a": 6, "b": false}, + {"a": 7, "b": false}])"), + }; + + ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(std::move(batches), schema)); + + auto source = + MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); + + ASSERT_OK_AND_ASSIGN(auto predicate, equal(field_ref("a"), literal(6)).Bind(*schema)); + + auto filter = MakeFilterNode(source, "filter", predicate); + + auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", filter, schema); + + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink)); + AssertBatchesEqual( + { + RecordBatchFromJSON(schema, R"([])"), + RecordBatchFromJSON(schema, R"([{"a": 6, "b": false}])"), + }, + got_batches); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index ae2c9446aa9..a64ab11c512 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -57,20 +57,20 @@ std::vector DescrFromSchemaColumns(const Schema& schema) { } struct DummyNode : ExecNode { - DummyNode(ExecPlan* plan, std::string label, int num_inputs, int num_outputs, + DummyNode(ExecPlan* plan, std::string label, NodeVector inputs, int num_outputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) - : ExecNode(plan, std::move(label), std::vector(num_inputs, descr()), {}, - descr(), num_outputs), + : ExecNode(plan, std::move(label), std::move(inputs), {}, descr(), num_outputs), start_producing_(std::move(start_producing)), stop_producing_(std::move(stop_producing)) { - for (int i = 0; i < num_inputs; ++i) { - input_labels_.push_back(std::to_string(i)); + input_labels_.resize(inputs_.size()); + for (size_t i = 0; i < input_labels_.size(); ++i) { + input_labels_[i] = std::to_string(i); } } const char* kind_name() override { return "Dummy"; } - void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} + void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override {} void ErrorReceived(ExecNode* input, Status error) override {} @@ -124,107 +124,10 @@ struct DummyNode : ExecNode { bool started_ = false; }; -struct RecordBatchReaderNode : ExecNode { - RecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr reader, Executor* io_executor) - : ExecNode(plan, std::move(label), {}, {}, - DescrFromSchemaColumns(*reader->schema()), /*num_outputs=*/1), - schema_(reader->schema()), - reader_(std::move(reader)), - io_executor_(io_executor) {} - - RecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr schema, - RecordBatchGenerator generator, Executor* io_executor) - : ExecNode(plan, std::move(label), {}, {}, DescrFromSchemaColumns(*schema), - /*num_outputs=*/1), - schema_(std::move(schema)), - generator_(std::move(generator)), - io_executor_(io_executor) {} - - const char* kind_name() override { return "RecordBatchReader"; } - - void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} - - void ErrorReceived(ExecNode* input, Status error) override {} - - void InputFinished(ExecNode* input, int seq_stop) override {} - - Status StartProducing() override { - next_batch_index_ = 0; - if (!generator_) { - auto it = 
MakeIteratorFromReader(reader_); - ARROW_ASSIGN_OR_RAISE(generator_, - MakeBackgroundGenerator(std::move(it), io_executor_)); - } - GenerateOne(std::unique_lock{mutex_}); - return Status::OK(); - } - - void PauseProducing(ExecNode* output) override {} - - void ResumeProducing(ExecNode* output) override {} - - void StopProducing(ExecNode* output) override { - ASSERT_EQ(output, outputs_[0]); - std::unique_lock lock(mutex_); - generator_ = nullptr; // null function - } - - void StopProducing() override { StopProducing(outputs_[0]); } - - private: - void GenerateOne(std::unique_lock&& lock) { - if (!generator_) { - // Stopped - return; - } - auto plan = this->plan()->shared_from_this(); - auto fut = generator_(); - const auto batch_index = next_batch_index_++; - - lock.unlock(); - // TODO we want to transfer always here - io_executor_->Transfer(std::move(fut)) - .AddCallback( - [plan, batch_index, this](const Result>& res) { - std::unique_lock lock(mutex_); - if (!res.ok()) { - for (auto out : outputs_) { - out->ErrorReceived(this, res.status()); - } - return; - } - const auto& batch = *res; - if (IsIterationEnd(batch)) { - lock.unlock(); - for (auto out : outputs_) { - out->InputFinished(this, batch_index); - } - } else { - lock.unlock(); - for (auto out : outputs_) { - out->InputReceived(this, batch_index, compute::ExecBatch(*batch)); - } - lock.lock(); - GenerateOne(std::move(lock)); - } - }); - } - - std::mutex mutex_; - const std::shared_ptr schema_; - const std::shared_ptr reader_; - RecordBatchGenerator generator_; - int next_batch_index_; - - Executor* const io_executor_; -}; - struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { - RecordBatchCollectNodeImpl(ExecPlan* plan, std::string label, + RecordBatchCollectNodeImpl(ExecPlan* plan, std::string label, ExecNode* input, std::shared_ptr schema) - : RecordBatchCollectNode(plan, std::move(label), {DescrFromSchemaColumns(*schema)}, - {"batches_to_collect"}, {}, 0), + : RecordBatchCollectNode(plan, std::move(label), {input}, {"collected"}, {}, 0), schema_(std::move(schema)) {} RecordBatchGenerator generator() override { return generator_; } @@ -256,12 +159,10 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { StopProducingUnlocked(); } - void InputReceived(ExecNode* input, int seq_num, - compute::ExecBatch exec_batch) override { + void InputReceived(ExecNode* input, int seq_num, ExecBatch exec_batch) override { std::unique_lock lock(mutex_); - if (stopped_) { - return; - } + if (stopped_) return; + auto maybe_batch = MakeBatch(std::move(exec_batch)); if (!maybe_batch.ok()) { lock.unlock(); @@ -339,8 +240,7 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { } } - // TODO factor this out as ExecBatch::ToRecordBatch()? 
- Result> MakeBatch(compute::ExecBatch&& exec_batch) { + Result> MakeBatch(ExecBatch&& exec_batch) { ArrayDataVector columns; columns.reserve(exec_batch.values.size()); for (auto&& value : exec_batch.values) { @@ -365,35 +265,48 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { util::optional>::Producer> producer_; }; +AsyncGenerator> Wrap(RecordBatchGenerator gen, + ::arrow::internal::Executor* io_executor) { + return MakeMappedGenerator( + MakeTransferredGenerator(std::move(gen), io_executor), + [](const std::shared_ptr& batch) -> util::optional { + return ExecBatch(*batch); + }); +} + } // namespace ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr reader, - Executor* io_executor) { - return plan->EmplaceNode(plan, std::move(label), - std::move(reader), io_executor); + const std::shared_ptr& schema, + RecordBatchGenerator generator, + ::arrow::internal::Executor* io_executor) { + return MakeSourceNode(plan, std::move(label), DescrFromSchemaColumns(*schema), + Wrap(std::move(generator), io_executor)); } ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr schema, - RecordBatchGenerator generator, - ::arrow::internal::Executor* io_executor) { - return plan->EmplaceNode( - plan, std::move(label), std::move(schema), std::move(generator), io_executor); + const std::shared_ptr& reader, + Executor* io_executor) { + auto gen = + MakeBackgroundGenerator(MakeIteratorFromReader(reader), io_executor).ValueOrDie(); + + return MakeRecordBatchReaderNode(plan, std::move(label), reader->schema(), + std::move(gen), io_executor); } -ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, +ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, int num_outputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) { - return plan->EmplaceNode(plan, std::move(label), num_inputs, num_outputs, - std::move(start_producing), + return plan->EmplaceNode(plan, std::move(label), std::move(inputs), + num_outputs, std::move(start_producing), std::move(stop_producing)); } -RecordBatchCollectNode* MakeRecordBatchCollectNode( - ExecPlan* plan, std::string label, const std::shared_ptr& schema) { - return arrow::internal::checked_cast( - plan->EmplaceNode(plan, std::move(label), schema)); +RecordBatchCollectNode* MakeRecordBatchCollectNode(ExecPlan* plan, std::string label, + ExecNode* input, + std::shared_ptr schema) { + return plan->EmplaceNode(plan, std::move(label), input, + std::move(schema)); } } // namespace compute diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index c2dc785a501..40f2b572be0 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -36,7 +36,7 @@ using StopProducingFunc = std::function; // Make a dummy node that has no execution behaviour ARROW_TESTING_EXPORT -ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, +ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); using RecordBatchGenerator = AsyncGenerator>; @@ -45,12 +45,12 @@ using RecordBatchGenerator = AsyncGenerator>; // background from a RecordBatchReader. 
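// Usage sketch (mirrors the SourceFilterSink test in plan_test.cc; `reader`,
// `schema` and `io_executor` are assumed to be set up by the calling test):
//
//   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
//   ExecNode* source =
//       MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor);
//   auto* sink = MakeRecordBatchCollectNode(plan.get(), "sink", source, schema);
//   ASSERT_OK(plan->Validate());
//   ASSERT_OK(plan->StartProducing());
//   // collected batches are then drained from sink->generator()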
ARROW_TESTING_EXPORT ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr reader, + const std::shared_ptr& reader, ::arrow::internal::Executor* io_executor); ARROW_TESTING_EXPORT ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr schema, + const std::shared_ptr& schema, RecordBatchGenerator generator, ::arrow::internal::Executor* io_executor); @@ -64,7 +64,8 @@ class RecordBatchCollectNode : public ExecNode { ARROW_TESTING_EXPORT RecordBatchCollectNode* MakeRecordBatchCollectNode(ExecPlan* plan, std::string label, - const std::shared_ptr& schema); + ExecNode* input, + std::shared_ptr schema); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index d6a1d4ccbc4..e723bd7838e 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -244,7 +244,7 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { for (const auto& descr : descrs) { auto id = descr.type->id(); - auto max_width = is_signed_integer(id) ? &max_width_signed : &max_width_unsigned; + auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned); *max_width = std::max(bit_width(id), *max_width); } diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 841b792ee34..fc6b38b37a9 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -165,13 +165,8 @@ Dataset::Dataset(std::shared_ptr schema, compute::Expression partition_e : schema_(std::move(schema)), partition_expression_(std::move(partition_expression)) {} -Result> Dataset::NewScan( - std::shared_ptr options) { - return std::make_shared(this->shared_from_this(), options); -} - Result> Dataset::NewScan() { - return NewScan(std::make_shared()); + return std::make_shared(this->shared_from_this()); } Result Dataset::GetFragments() { diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index d2cba730252..11210fdc27b 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -155,7 +155,6 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment { class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { public: /// \brief Begin to build a new Scan operation against this Dataset - Result> NewScan(std::shared_ptr options); Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. 
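For context, the NewScan() simplification above changes the caller-side pattern
roughly as follows (a sketch; Filter() and Finish() are the pre-existing
ScannerBuilder API, not additions of this patch):

    ASSERT_OK_AND_ASSIGN(auto builder, dataset->NewScan());
    ASSERT_OK(builder->Filter(equal(field_ref("a"), literal(6))));
    ASSERT_OK_AND_ASSIGN(auto scanner, builder->Finish());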
diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index f0409abe85b..e6192523f53 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -100,13 +100,6 @@ class TestIpcFileSystemDataset : public testing::Test, format_ = ipc_format; SetWriteOptions(ipc_format->DefaultWriteOptions()); } - - std::shared_ptr MakeScanner(const std::shared_ptr& dataset, - const std::shared_ptr& scan_options) { - ScannerBuilder builder(dataset, scan_options); - EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish()); - return scanner; - } }; TEST_F(TestIpcFileSystemDataset, WriteWithIdenticalPartitioningSchema) { @@ -132,7 +125,7 @@ TEST_F(TestIpcFileSystemDataset, WriteExceedsMaxPartitions) { // require that no batch be grouped into more than 2 written batches: write_options_.max_partitions = 2; - auto scanner = MakeScanner(dataset_, scan_options_); + EXPECT_OK_AND_ASSIGN(auto scanner, ScannerBuilder(dataset_, scan_options_).Finish()); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("This exceeds the maximum"), FileSystemDataset::Write(write_options_, scanner)); } diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index b80d1bb57f0..5bf89330429 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -169,7 +169,8 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { TEST_F(TestFileSystemDataset, RootPartitionPruning) { auto root_partition = equal(field_ref("i32"), literal(5)); - MakeDataset({fs::File("a"), fs::File("b")}, root_partition); + MakeDataset({fs::File("a"), fs::File("b")}, root_partition, {}, + schema({field("i32", int32()), field("f32", float32())})); auto GetFragments = [&](compute::Expression filter) { return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); @@ -191,8 +192,9 @@ TEST_F(TestFileSystemDataset, RootPartitionPruning) { AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), {"a", "b"}); - // No partition should match - MakeDataset({fs::File("a"), fs::File("b")}); + // No root partition: don't prune any fragments + MakeDataset({fs::File("a"), fs::File("b")}, literal(true), {}, + schema({field("i32", int32()), field("f32", float32())})); AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), {"a", "b"}); } diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 09e05cdbf75..eda409ea4ea 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -411,18 +411,6 @@ Result SyncScanner::ScanInternal() { return GetScanTaskIterator(std::move(fragment_it), scan_options_); } -Result ScanTaskIteratorFromRecordBatch( - std::vector> batches, - std::shared_ptr options) { - if (batches.empty()) { - return MakeVectorIterator(ScanTaskVector()); - } - auto schema = batches[0]->schema(); - auto fragment = - std::make_shared(std::move(schema), std::move(batches)); - return fragment->Scan(std::move(options)); -} - class ARROW_DS_EXPORT AsyncScanner : public Scanner, public std::enable_shared_from_this { public: @@ -459,10 +447,12 @@ inline Result DoFilterAndProjectRecordBatchAsync( SimplifyWithGuarantee(scanner->options()->filter, in.fragment.value->partition_expression())); + const auto& schema = *scanner->options()->dataset_schema; + compute::ExecContext exec_context{scanner->options()->pool}; - ARROW_ASSIGN_OR_RAISE( - Datum mask, ExecuteScalarExpression(simplified_filter, Datum(in.record_batch.value), - &exec_context)); + 
ARROW_ASSIGN_OR_RAISE(Datum mask,
+                        ExecuteScalarExpression(simplified_filter, schema,
+                                                in.record_batch.value, &exec_context));
 
   Datum filtered;
   if (mask.is_scalar()) {
@@ -483,9 +473,10 @@ inline Result<EnumeratedRecordBatch> DoFilterAndProjectRecordBatchAsync(
   ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection,
                         SimplifyWithGuarantee(scanner->options()->projection,
                                               in.fragment.value->partition_expression()));
+
   ARROW_ASSIGN_OR_RAISE(
       Datum projected,
-      ExecuteScalarExpression(simplified_projection, filtered, &exec_context));
+      ExecuteScalarExpression(simplified_projection, schema, filtered, &exec_context));
 
   DCHECK_EQ(projected.type()->id(), Type::STRUCT);
   if (projected.shape() == ValueDescr::SCALAR) {
@@ -510,7 +501,7 @@ inline EnumeratedRecordBatchGenerator FilterAndProjectRecordBatchAsync(
   auto mapper = [scanner](const EnumeratedRecordBatch& in) {
     return DoFilterAndProjectRecordBatchAsync(scanner, in);
   };
-  return MakeMappedGenerator<EnumeratedRecordBatch>(std::move(rbs), mapper);
+  return MakeMappedGenerator(std::move(rbs), mapper);
 }
 
 Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
@@ -525,8 +516,7 @@ Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
     return EnumeratedRecordBatch{record_batch, fragment};
   };
 
-  auto combined_gen = MakeMappedGenerator<EnumeratedRecordBatch>(enumerated_batch_gen,
-                                                                 std::move(combine_fn));
+  auto combined_gen = MakeMappedGenerator(enumerated_batch_gen, std::move(combine_fn));
 
   return FilterAndProjectRecordBatchAsync(scanner, std::move(combined_gen));
 }
@@ -534,7 +524,7 @@ Result<AsyncGenerator<EnumeratedRecordBatchGenerator>> FragmentsToBatches(
     std::shared_ptr<AsyncScanner> scanner, FragmentGenerator fragment_gen) {
   auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen));
-  return MakeMappedGenerator<EnumeratedRecordBatchGenerator>(
+  return MakeMappedGenerator(
       std::move(enumerated_fragment_gen),
       [scanner](const Enumerated<std::shared_ptr<Fragment>>& fragment) {
         return FragmentToBatches(scanner, fragment, scanner->options());
@@ -566,12 +556,11 @@ Result<AsyncGenerator<AsyncGenerator<util::optional<int64_t>>>> FragmentsToRowCount(
        [](const EnumeratedRecordBatch& enumerated) -> util::optional<int64_t> {
          return enumerated.record_batch.value->num_rows();
        };
-    return MakeMappedGenerator<util::optional<int64_t>>(batch_gen,
-                                                        std::move(count_fn));
+    return MakeMappedGenerator(batch_gen, std::move(count_fn));
   }));
  };
-  return MakeMappedGenerator<AsyncGenerator<util::optional<int64_t>>>(
-      std::move(enumerated_fragment_gen), std::move(count_fragment_fn));
+  return MakeMappedGenerator(std::move(enumerated_fragment_gen),
+                             std::move(count_fragment_fn));
 }
 
 }  // namespace
@@ -664,7 +653,7 @@ Result<TaggedRecordBatchGenerator> AsyncScanner::ScanBatchesAsync(
     return TaggedRecordBatch{enumerated_batch.record_batch.value,
                              enumerated_batch.fragment.value};
   };
-  return MakeMappedGenerator<TaggedRecordBatch>(std::move(sequenced), unenumerate_fn);
+  return MakeMappedGenerator(std::move(sequenced), unenumerate_fn);
 }
 
 struct AsyncTableAssemblyState {
@@ -725,8 +714,8 @@ Future<std::shared_ptr<Table>> AsyncScanner::ToTableAsync(
     return batch;
   };
 
-  auto table_building_gen = MakeMappedGenerator<EnumeratedRecordBatch>(
-      positioned_batch_gen, table_building_task);
+  auto table_building_gen =
+      MakeMappedGenerator(positioned_batch_gen, table_building_task);
 
   return DiscardAllFromAsyncGenerator(table_building_gen).Then([state, scan_options]() {
     return Table::FromRecordBatches(scan_options->projected_schema, state->Finish());
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 29fd5aad994..3ee18c5c049 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -194,20 +194,22 @@ using EnumeratedRecordBatchIterator = Iterator<EnumeratedRecordBatch>;
 template <>
 struct IterationTraits<dataset::TaggedRecordBatch> {
   static dataset::TaggedRecordBatch End() {
-    return dataset::TaggedRecordBatch{NULL, NULL};
+    return dataset::TaggedRecordBatch{NULLPTR, NULLPTR};
   }
   static 
bool IsEnd(const dataset::TaggedRecordBatch& val) { - return val.record_batch == NULL; + return val.record_batch == NULLPTR; } }; template <> struct IterationTraits { static dataset::EnumeratedRecordBatch End() { - return dataset::EnumeratedRecordBatch{{NULL, -1, false}, {NULL, -1, false}}; + return dataset::EnumeratedRecordBatch{ + IterationEnd>>(), + IterationEnd>>()}; } static bool IsEnd(const dataset::EnumeratedRecordBatch& val) { - return val.fragment.value == NULL; + return val.fragment.value == NULLPTR; } }; @@ -402,7 +404,7 @@ class ARROW_DS_EXPORT ScannerBuilder { private: std::shared_ptr dataset_; std::shared_ptr fragment_; - std::shared_ptr scan_options_; + std::shared_ptr scan_options_ = std::make_shared(); }; /// @} @@ -422,9 +424,5 @@ class ARROW_DS_EXPORT InMemoryScanTask : public ScanTask { std::vector> record_batches_; }; -ARROW_DS_EXPORT Result ScanTaskIteratorFromRecordBatch( - std::vector> batches, - std::shared_ptr options); - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index 30fb4e07cef..27b32aa6f19 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -40,10 +40,11 @@ namespace dataset { inline Result> FilterSingleBatch( const std::shared_ptr& in, const compute::Expression& filter, - MemoryPool* pool) { - compute::ExecContext exec_context{pool}; - ARROW_ASSIGN_OR_RAISE(Datum mask, - ExecuteScalarExpression(filter, Datum(in), &exec_context)); + const std::shared_ptr& options) { + compute::ExecContext exec_context{options->pool}; + ARROW_ASSIGN_OR_RAISE( + Datum mask, + ExecuteScalarExpression(filter, *options->dataset_schema, in, &exec_context)); if (mask.is_scalar()) { const auto& mask_scalar = mask.scalar_as(); @@ -59,28 +60,29 @@ inline Result> FilterSingleBatch( return filtered.record_batch(); } -inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, - compute::Expression filter, - MemoryPool* pool) { +inline RecordBatchIterator FilterRecordBatch( + RecordBatchIterator it, compute::Expression filter, + const std::shared_ptr& options) { return MakeMaybeMapIterator( [=](std::shared_ptr in) -> Result> { - return FilterSingleBatch(in, filter, pool); + return FilterSingleBatch(in, filter, options); }, std::move(it)); } inline Result> ProjectSingleBatch( const std::shared_ptr& in, const compute::Expression& projection, - MemoryPool* pool) { - compute::ExecContext exec_context{pool}; - ARROW_ASSIGN_OR_RAISE(Datum projected, - ExecuteScalarExpression(projection, Datum(in), &exec_context)); + const std::shared_ptr& options) { + compute::ExecContext exec_context{options->pool}; + ARROW_ASSIGN_OR_RAISE( + Datum projected, + ExecuteScalarExpression(projection, *options->dataset_schema, in, &exec_context)); DCHECK_EQ(projected.type()->id(), Type::STRUCT); if (projected.shape() == ValueDescr::SCALAR) { // Only virtual columns are projected. 
Broadcast to an array - ARROW_ASSIGN_OR_RAISE(projected, - MakeArrayFromScalar(*projected.scalar(), in->num_rows(), pool)); + ARROW_ASSIGN_OR_RAISE(projected, MakeArrayFromScalar(*projected.scalar(), + in->num_rows(), options->pool)); } ARROW_ASSIGN_OR_RAISE(auto out, @@ -89,12 +91,12 @@ inline Result> ProjectSingleBatch( return out->ReplaceSchemaMetadata(in->schema()->metadata()); } -inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, - compute::Expression projection, - MemoryPool* pool) { +inline RecordBatchIterator ProjectRecordBatch( + RecordBatchIterator it, compute::Expression projection, + const std::shared_ptr& options) { return MakeMaybeMapIterator( [=](std::shared_ptr in) -> Result> { - return ProjectSingleBatch(in, projection, pool); + return ProjectSingleBatch(in, projection, options); }, std::move(it)); } @@ -117,10 +119,9 @@ class FilterAndProjectScanTask : public ScanTask { SimplifyWithGuarantee(options()->projection, partition_)); RecordBatchIterator filter_it = - FilterRecordBatch(std::move(it), simplified_filter, options_->pool); + FilterRecordBatch(std::move(it), simplified_filter, options_); - return ProjectRecordBatch(std::move(filter_it), simplified_projection, - options_->pool); + return ProjectRecordBatch(std::move(filter_it), simplified_projection, options_); } Result ToFilteredAndProjectedIterator( @@ -133,10 +134,9 @@ class FilterAndProjectScanTask : public ScanTask { SimplifyWithGuarantee(options()->projection, partition_)); RecordBatchIterator filter_it = - FilterRecordBatch(std::move(it), simplified_filter, options_->pool); + FilterRecordBatch(std::move(it), simplified_filter, options_); - return ProjectRecordBatch(std::move(filter_it), simplified_projection, - options_->pool); + return ProjectRecordBatch(std::move(filter_it), simplified_projection, options_); } Result> FilterAndProjectBatch( @@ -147,8 +147,8 @@ class FilterAndProjectScanTask : public ScanTask { ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection, SimplifyWithGuarantee(options()->projection, partition_)); ARROW_ASSIGN_OR_RAISE(auto filtered, - FilterSingleBatch(batch, simplified_filter, options_->pool)); - return ProjectSingleBatch(filtered, simplified_projection, options_->pool); + FilterSingleBatch(batch, simplified_filter, options_)); + return ProjectSingleBatch(filtered, simplified_projection, options_); } inline Future SafeExecute(internal::Executor* executor) override { diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 42704fea9b5..e1d85d23342 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -547,8 +547,8 @@ class FileFormatScanMixin : public FileFormatFixtureMixin, // Scan the fragment through the scanner. 
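  // Usage sketch: tests can drain the returned iterator directly, e.g.
  //
  //   for (auto maybe_batch : Batches(fragment)) {
  //     ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
  //     // ... assert on the batch's contents
  //   }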
RecordBatchIterator Batches(std::shared_ptr<Fragment> fragment) {
-    EXPECT_OK_AND_ASSIGN(auto schema, fragment->ReadPhysicalSchema());
-    auto dataset = std::make_shared<FragmentDataset>(schema, FragmentVector{fragment});
+    auto dataset = std::make_shared<FragmentDataset>(opts_->dataset_schema,
+                                                     FragmentVector{fragment});
     ScannerBuilder builder(dataset, opts_);
     ARROW_EXPECT_OK(builder.UseAsync(GetParam().use_async));
     ARROW_EXPECT_OK(builder.UseThreads(GetParam().use_threads));
@@ -761,12 +761,11 @@ class JSONRecordBatchFileFormat : public FileFormat {
     ARROW_ASSIGN_OR_RAISE(auto file, fragment->source().Open());
     ARROW_ASSIGN_OR_RAISE(int64_t size, file->GetSize());
     ARROW_ASSIGN_OR_RAISE(auto buffer, file->Read(size));
-
-    util::string_view view{*buffer};
-
     ARROW_ASSIGN_OR_RAISE(auto schema, Inspect(fragment->source()));
-    std::shared_ptr<RecordBatch> batch = RecordBatchFromJSON(schema, view);
-    return ScanTaskIteratorFromRecordBatch({batch}, std::move(options));
+
+    RecordBatchVector batches{RecordBatchFromJSON(schema, util::string_view{*buffer})};
+    return std::make_shared<InMemoryFragment>(std::move(schema), std::move(batches))
+        ->Scan(std::move(options));
   }
 
   Result<std::shared_ptr<FileWriter>> MakeWriter(
@@ -1052,7 +1051,7 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin {
     ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish());
 
     scan_options_ = std::make_shared<ScanOptions>();
-    scan_options_->dataset_schema = source_schema_;
+    scan_options_->dataset_schema = dataset_->schema();
     ASSERT_OK(SetProjection(scan_options_.get(), source_schema_->field_names()));
   }
 
diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h
index 0172a852434..6b68d70ef5c 100644
--- a/cpp/src/arrow/result.h
+++ b/cpp/src/arrow/result.h
@@ -496,9 +496,9 @@ inline Status GenericToStatus(Result<T>&& res) {
 
 }  // namespace internal
 
-template <typename T>
-Result<T> ToResult(T t) {
-  return Result<T>(std::move(t));
+template <typename T, typename R = typename EnsureResult<T>::type>
+R ToResult(T t) {
+  return R(std::move(t));
 }
 
 template
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index 084720f9908..29635ec5bd2 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -259,43 +259,27 @@ class MappingGenerator {
 /// Note: Errors returned from the `map` function will be propagated
 ///
 /// If the source generator is async-reentrant then this generator will be also
-template <typename T, typename V>
-AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator,
-                                      std::function<Result<V>(const T&)> map) {
-  std::function<Future<V>(const T&)> future_map = [map](const T& val) -> Future<V> {
-    return Future<V>::MakeFinished(map(val));
-  };
-  return MappingGenerator<T, V>(std::move(source_generator), std::move(future_map));
-}
-template <typename T, typename V>
-AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator,
-                                      std::function<V(const T&)> map) {
-  std::function<Future<V>(const T&)> maybe_future_map = [map](const T& val) -> Future<V> {
-    return Future<V>::MakeFinished(map(val));
-  };
-  return MappingGenerator<T, V>(std::move(source_generator), std::move(maybe_future_map));
-}
-template <typename T, typename V>
-AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator,
-                                      std::function<Future<V>(const T&)> map) {
-  return MappingGenerator<T, V>(std::move(source_generator), std::move(map));
-}
-
-template <typename V, typename T, typename MapFunc>
-AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator, MapFunc map) {
+template <typename T, typename MapFn,
+          typename Mapped = detail::result_of_t<MapFn(const T&)>,
+          typename V = typename EnsureFuture<Mapped>::type::ValueType>
+AsyncGenerator<V> MakeMappedGenerator(AsyncGenerator<T> source_generator, MapFn map) {
   struct MapCallback {
-    MapFunc map;
+    MapFn map_;
 
-    Future<V> operator()(const T& val) { return EnsureFuture(map(val)); }
+    Future<V> operator()(const T& val) { return EnsureFuture(map_(val)); }
+
+    Future<V> 
EnsureFuture(V mapped) { + return Future::MakeFinished(std::move(mapped)); + } - Future EnsureFuture(Result val) { - return Future::MakeFinished(std::move(val)); + Future EnsureFuture(Result mapped) { + return Future::MakeFinished(std::move(mapped)); } - Future EnsureFuture(V val) { return Future::MakeFinished(std::move(val)); } - Future EnsureFuture(Future val) { return val; } + + Future EnsureFuture(Future mapped) { return mapped; } }; - std::function(const T&)> map_fn = MapCallback{map}; - return MappingGenerator(std::move(source_generator), map_fn); + + return MappingGenerator(std::move(source_generator), MapCallback{std::move(map)}); } /// \see MakeSequencingGenerator @@ -308,30 +292,28 @@ class SequencingGenerator { std::move(is_next), std::move(initial_value))) {} Future operator()() { - { - auto guard = state_->mutex.Lock(); - // We can send a result immediately if the top of the queue is either an - // error or the next item - if (!state_->queue.empty() && - (!state_->queue.top().ok() || - state_->is_next(state_->previous_value, *state_->queue.top()))) { - auto result = std::move(state_->queue.top()); - if (result.ok()) { - state_->previous_value = *result; - } - state_->queue.pop(); - return Future::MakeFinished(result); - } - if (state_->finished) { - return AsyncGeneratorEnd(); + auto guard = state_->mutex.Lock(); + // We can send a result immediately if the top of the queue is either an + // error or the next item + if (!state_->queue.empty() && + (!state_->queue.top().ok() || + state_->is_next(state_->previous_value, *state_->queue.top()))) { + auto result = std::move(state_->queue.top()); + if (result.ok()) { + state_->previous_value = *result; } - // The next item is not in the queue so we will need to wait - auto new_waiting_fut = Future::Make(); - state_->waiting_future = new_waiting_fut; - guard.Unlock(); - state_->source().AddCallback(Callback{state_}); - return new_waiting_fut; + state_->queue.pop(); + return Future::MakeFinished(result); } + if (state_->finished) { + return AsyncGeneratorEnd(); + } + // The next item is not in the queue so we will need to wait + auto new_waiting_fut = Future::Make(); + state_->waiting_future = new_waiting_fut; + guard.Unlock(); + state_->source().AddCallback(Callback{state_}); + return new_waiting_fut; } private: @@ -365,11 +347,8 @@ class SequencingGenerator { util::Mutex mutex; }; - class Callback { - public: - explicit Callback(std::shared_ptr state) : state_(std::move(state)) {} - - void operator()(const Result result) { + struct Callback { + void operator()(const Result& result) { Future to_deliver; bool finished; { @@ -412,7 +391,6 @@ class SequencingGenerator { } } - private: const std::shared_ptr state_; }; @@ -1150,17 +1128,17 @@ class EnumeratingGenerator { Future> operator()() { if (state_->finished) { return AsyncGeneratorEnd>(); - } else { - auto state = state_; - return state->source().Then([state](const T& next) { - auto finished = IsIterationEnd(next); - auto prev = Enumerated{state->prev_value, state->prev_index, finished}; - state->prev_value = next; - state->prev_index++; - state->finished = finished; - return prev; - }); } + + auto state = state_; + return state->source().Then([state](const T& next) { + auto finished = IsIterationEnd(next); + auto prev = Enumerated{state->prev_value, state->prev_index, finished}; + state->prev_value = next; + state->prev_index++; + state->finished = finished; + return prev; + }); } private: @@ -1247,27 +1225,27 @@ class BackgroundGenerator { Future operator()() { auto guard = 
state_->mutex.Lock(); - Future waiting_future; - if (state_->queue.empty()) { - if (state_->finished) { - return AsyncGeneratorEnd(); - } else { - waiting_future = Future::Make(); - state_->waiting_future = waiting_future; - } - } else { + if (!state_->queue.empty()) { auto next = Future::MakeFinished(std::move(state_->queue.front())); state_->queue.pop(); + if (state_->NeedsRestart()) { return state_->RestartTask(state_, std::move(guard), std::move(next)); } return next; } + + if (state_->finished) { + return AsyncGeneratorEnd(); + } + + state_->waiting_future = Future::Make(); + // This should only trigger the very first time this method is called if (state_->NeedsRestart()) { - return state_->RestartTask(state_, std::move(guard), std::move(waiting_future)); + return state_->RestartTask(state_, std::move(guard), *state_->waiting_future); } - return waiting_future; + return *state_->waiting_future; } protected: @@ -1298,22 +1276,23 @@ class BackgroundGenerator { // task_finished future for it state->task_finished = Future<>::Make(); state->reading = true; + auto spawn_status = io_executor->Spawn( [state]() { BackgroundGenerator::WorkerTask(std::move(state)); }); - if (!spawn_status.ok()) { - // If we can't spawn a new task then send an error to the consumer (either via a - // waiting future or the queue) and mark ourselves finished - state->finished = true; - state->task_finished = Future<>(); - if (waiting_future.has_value()) { - auto to_deliver = std::move(waiting_future.value()); - waiting_future.reset(); - guard.Unlock(); - to_deliver.MarkFinished(spawn_status); - } else { - ClearQueue(); - queue.push(spawn_status); - } + if (spawn_status.ok()) return; + + // If we can't spawn a new task then send an error to the consumer (either via a + // waiting future or the queue) and mark ourselves finished + state->finished = true; + state->task_finished = Future<>(); + if (waiting_future.has_value()) { + auto to_deliver = std::move(waiting_future.value()); + waiting_future.reset(); + guard.Unlock(); + to_deliver.MarkFinished(spawn_status); + } else { + ClearQueue(); + queue.push(spawn_status); } } @@ -1404,7 +1383,7 @@ class BackgroundGenerator { break; } - if (!next.ok() || IsIterationEnd(*next)) { + if (IsIterationEnd(next)) { // Terminal item. Mark finished to true, send this last item, and quit state->finished = true; if (!next.ok()) { @@ -1431,7 +1410,7 @@ class BackgroundGenerator { // callbacks off of this thread so we can continue looping. 
Still, best not to // rely on that if (waiting_future.is_valid()) { - waiting_future.MarkFinished(next); + waiting_future.MarkFinished(std::move(next)); } } // Once we've sent our last item we can notify any waiters that we are done and so diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index 14b528ade5e..61efb0043be 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -51,7 +51,7 @@ AsyncGenerator FailsAt(AsyncGenerator src, int failing_index) { template AsyncGenerator SlowdownABit(AsyncGenerator source) { - return MakeMappedGenerator(std::move(source), [](const T& res) -> Future { + return MakeMappedGenerator(std::move(source), [](const T& res) { return SleepABitAsync().Then([res]() { return res; }); }); } diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index d08c598a32b..741550b3e64 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -36,6 +36,9 @@ namespace arrow { +template +struct EnsureFuture; + namespace detail { template @@ -976,4 +979,19 @@ Future Loop(Iterate iterate) { return break_fut; } +template +struct EnsureFuture { + using type = Future; +}; + +template +struct EnsureFuture> { + using type = Future; +}; + +template +struct EnsureFuture> { + using type = Future; +}; + } // namespace arrow diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h index b82021e4b21..c84cdb21e03 100644 --- a/cpp/src/arrow/util/iterator.h +++ b/cpp/src/arrow/util/iterator.h @@ -66,6 +66,12 @@ bool IsIterationEnd(const T& val) { return IterationTraits::IsEnd(val); } +template +bool IsIterationEnd(const Result& maybe_val) { + if (!maybe_val.ok()) return true; + return IsIterationEnd(*maybe_val); +} + template struct IterationTraits> { /// \brief by default when iterating through a sequence of optional, From a367f6031c7b9e2d5e3f7f4b12b05ca7b4bb8200 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 26 May 2021 10:47:36 -0400 Subject: [PATCH 02/28] use Loop in GeneratorNode --- cpp/src/arrow/compute/exec/exec_plan.cc | 91 +++++++++++++------------ 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index cd14bcc2d82..49947305456 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -215,10 +215,49 @@ struct GeneratorNode : ExecNode { void InputFinished(ExecNode*, int) override { DCHECK(false); } Status StartProducing() override { - if (!generator_) { + if (finished_) { return Status::Invalid("Restarted GeneratorNode '", label(), "'"); } - GenerateOne(std::unique_lock{mutex_}); + + auto gen = std::move(generator_); + + /// XXX should we wait on this future anywhere? In StopProducing() maybe? + auto done_fut = + Loop([gen, this] { + std::unique_lock lock(mutex_); + int seq = next_batch_index_++; + lock.unlock(); + + return gen().Then( + [=](const util::optional& batch) -> ControlFlow { + std::unique_lock lock(mutex_); + if (!batch || finished_) { + finished_ = true; + return Break(seq); + } + lock.unlock(); + + outputs_[0]->InputReceived(this, seq, *batch); + return Continue(); + }, + [=](const Status& error) -> ControlFlow { + std::unique_lock lock(mutex_); + if (!finished_) { + finished_ = true; + lock.unlock(); + // unless we were already finished, push the error to our output + // XXX is this correct? 
Is it reasonable for a consumer to ignore errors
+                  // from a finished producer?
+                  outputs_[0]->ErrorReceived(this, error);
+                }
+                return Break(seq);
+              });
+        }).Then([&](int seq) {
+          /// XXX this is probably redundant: do we always call InputFinished after
+          /// ErrorReceived, or will ErrorReceived be sufficient?
+          outputs_[0]->InputFinished(this, seq);
+        });
+
     return Status::OK();
   }
 
@@ -229,53 +268,16 @@
   void PauseProducing(ExecNode* output) override {}
 
   void ResumeProducing(ExecNode* output) override {}
 
   void StopProducing(ExecNode* output) override {
     DCHECK_EQ(output, outputs_[0]);
     std::unique_lock<std::mutex> lock(mutex_);
-    generator_ = nullptr;  // null function
+    finished_ = true;
   }
 
   void StopProducing() override { StopProducing(outputs_[0]); }
 
  private:
-  void GenerateOne(std::unique_lock<std::mutex>&& lock) {
-    if (!generator_) {
-      // Stopped
-      return;
-    }
-
-    auto fut = generator_();
-    const auto batch_index = next_batch_index_++;
-
-    lock.unlock();
-    fut.AddCallback([batch_index, this](const Result<util::optional<ExecBatch>>& res) {
-      std::unique_lock<std::mutex> lock(mutex_);
-      if (!res.ok()) {
-        for (auto out : outputs_) {
-          out->ErrorReceived(this, res.status());
-        }
-        return;
-      }
-
-      const auto& batch = *res;
-      if (IsIterationEnd(batch)) {
-        lock.unlock();
-        for (auto out : outputs_) {
-          out->InputFinished(this, batch_index);
-        }
-        return;
-      }
-
-      lock.unlock();
-      for (auto out : outputs_) {
-        out->InputReceived(this, batch_index, compute::ExecBatch(*batch));
-      }
-      lock.lock();
-
-      GenerateOne(std::move(lock));
-    });
-  }
-
   std::mutex mutex_;
+  bool finished_{false};
+  int next_batch_index_{0};
   AsyncGenerator<util::optional<ExecBatch>> generator_;
-  int next_batch_index_ = 0;
 };
 
 ExecNode* MakeSourceNode(ExecPlan* plan, std::string label,
From a6615d9fb723c73e5dd2399fdb2c3430f4208ef4 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 26 May 2021 12:02:08 -0400
Subject: [PATCH 03/28] Make CollectNode public (as SinkNode)

---
 cpp/src/arrow/compute/exec/exec_plan.cc |  27 +++--
 cpp/src/arrow/compute/exec/plan_test.cc |  88 +++++++-------
 cpp/src/arrow/compute/exec/test_util.cc | 154 +-----------------------
 cpp/src/arrow/compute/exec/test_util.h  |  15 +--
 4 files changed, 72 insertions(+), 212 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 49947305456..e8107ffc78c 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -201,22 +201,23 @@ Status ExecNode::Validate() const {
   return Status::OK();
 }
 
-struct GeneratorNode : ExecNode {
-  GeneratorNode(ExecPlan* plan, std::string label, ExecNode::BatchDescr output_descr,
-                AsyncGenerator<util::optional<ExecBatch>> generator)
+struct SourceNode : ExecNode {
+  SourceNode(ExecPlan* plan, std::string label, ExecNode::BatchDescr output_descr,
+             AsyncGenerator<util::optional<ExecBatch>> generator)
       : ExecNode(plan, std::move(label), {}, {}, std::move(output_descr),
                  /*num_outputs=*/1),
        generator_(std::move(generator)) {}
 
-  const char* kind_name() override { return "GeneratorNode"; }
+  const char* kind_name() override { return "SourceNode"; }
 
-  void InputReceived(ExecNode*, int, compute::ExecBatch) override { DCHECK(false); }
-  void ErrorReceived(ExecNode*, Status) override { DCHECK(false); }
-  void InputFinished(ExecNode*, int) override { DCHECK(false); }
+  static void NoInputs() { DCHECK(false) << "no inputs; this should 
never be called"; } + void InputReceived(ExecNode*, int, ExecBatch) override { NoInputs(); } + void ErrorReceived(ExecNode*, Status) override { NoInputs(); } + void InputFinished(ExecNode*, int) override { NoInputs(); } Status StartProducing() override { if (finished_) { - return Status::Invalid("Restarted GeneratorNode '", label(), "'"); + return Status::Invalid("Restarted SourceNode '", label(), "'"); } auto gen = std::move(generator_); @@ -283,8 +284,8 @@ struct GeneratorNode : ExecNode { ExecNode* MakeSourceNode(ExecPlan* plan, std::string label, ExecNode::BatchDescr output_descr, AsyncGenerator> generator) { - return plan->EmplaceNode(plan, std::move(label), std::move(output_descr), - std::move(generator)); + return plan->EmplaceNode(plan, std::move(label), std::move(output_descr), + std::move(generator)); } struct FilterNode : ExecNode { @@ -483,7 +484,11 @@ struct SinkNode : ExecNode { }; AsyncGenerator> MakeSinkNode(ExecNode* input, - std::string label); + std::string label) { + AsyncGenerator> out; + (void)input->plan()->EmplaceNode(input, std::move(label), &out); + return out; +} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 5f0d5441ed5..4421ed72969 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -20,6 +20,7 @@ #include #include +#include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/test_util.h" @@ -29,6 +30,7 @@ #include "arrow/testing/random.h" #include "arrow/util/logging.h" #include "arrow/util/thread_pool.h" +#include "arrow/util/vector.h" namespace arrow { @@ -44,6 +46,22 @@ void AssertBatchesEqual(const RecordBatchVector& expected, } } +void AssertBatchesEqual(const std::vector>& expected, + const std::vector>& actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + AssertBatchesEqual(*expected[i], *actual[i]); + } +} + +void AssertBatchesEqual(const RecordBatchVector& expected, + const std::vector>& actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + AssertBatchesEqual(ExecBatch(*expected[i]), *actual[i]); + } +} + TEST(ExecPlanConstruction, Empty) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); @@ -243,43 +261,26 @@ class TestExecPlanExecution : public ::testing::Test { return batches; } - struct CollectorPlan { - std::shared_ptr plan; - RecordBatchCollectNode* sink; - }; - - Result MakeSourceSink(std::shared_ptr reader, - std::shared_ptr schema) { - ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); - auto source = - MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); - auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", source, std::move(schema)); - return CollectorPlan{plan, sink}; - } - - Result MakeSourceSink(RecordBatchGenerator generator, - std::shared_ptr schema) { - ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); - auto source = MakeRecordBatchReaderNode(plan.get(), "source", schema, generator, - io_executor_.get()); - auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", source, std::move(schema)); - return CollectorPlan{plan, sink}; + Result>> StartAndCollect( + ExecPlan* plan, AsyncGenerator> gen) { + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + return CollectAsyncGenerator(gen).result(); } - Result MakeSourceSink(const RecordBatchVector& 
batches, - std::shared_ptr schema) { - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make(batches, schema)); - return MakeSourceSink(std::move(reader), std::move(schema)); + ExecNode* MakeSource(ExecPlan* plan, std::shared_ptr reader, + std::shared_ptr schema) { + return MakeRecordBatchReaderNode(plan, "source", reader, io_executor_.get()); } - Result StartAndCollect(ExecPlan* plan, - RecordBatchCollectNode* sink) { - RETURN_NOT_OK(plan->StartProducing()); - return CollectAsyncGenerator(sink->generator()).result(); + ExecNode* MakeSource(ExecPlan* plan, RecordBatchGenerator generator, + std::shared_ptr schema) { + return MakeRecordBatchReaderNode(plan, "source", schema, generator, + io_executor_.get()); } template - void TestSourceSink(RecordBatchReaderFactory reader_factory) { + void TestSourceSink(RecordBatchReaderFactory factory) { auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); RecordBatchVector batches{ RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, @@ -289,24 +290,27 @@ class TestExecPlanExecution : public ::testing::Test { {"a": 7, "b": false}])"), }; - ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(batches, schema)); - ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); - ASSERT_OK(cp.plan->Validate()); + ASSERT_OK_AND_ASSIGN(auto reader_or_gen, factory(batches, schema)); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(cp.plan.get(), cp.sink)); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto source = MakeSource(plan.get(), reader_or_gen, schema); + auto sink_gen = MakeSinkNode(source, "sink"); + + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); AssertBatchesEqual(batches, got_batches); } template - void TestStressSourceSink(int num_batches, RecordBatchReaderFactory batch_factory) { + void TestStressSourceSink(int num_batches, RecordBatchReaderFactory factory) { auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); auto batches = MakeRandomBatches(schema, num_batches); - ASSERT_OK_AND_ASSIGN(auto reader, batch_factory(batches, schema)); - ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); - ASSERT_OK(cp.plan->Validate()); + ASSERT_OK_AND_ASSIGN(auto reader_or_gen, factory(batches, schema)); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto source = MakeSource(plan.get(), reader_or_gen, schema); + auto sink_gen = MakeSinkNode(source, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(cp.plan.get(), cp.sink)); + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); AssertBatchesEqual(batches, got_batches); } @@ -360,9 +364,11 @@ TEST_F(TestExecPlanExecution, SourceFilterSink) { auto filter = MakeFilterNode(source, "filter", predicate); - auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", filter, schema); + auto sink_gen = MakeSinkNode(filter, "sink"); + + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink)); + ASSERT_EQ(got_batches.size(), 2); AssertBatchesEqual( { RecordBatchFromJSON(schema, R"([])"), diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index a64ab11c512..fda39a66d28 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -33,6 +33,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/datum.h" #include "arrow/record_batch.h" +#include "arrow/testing/gtest_util.h" #include 
"arrow/type.h" #include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" @@ -44,6 +45,11 @@ namespace arrow { using internal::Executor; namespace compute { + +void AssertBatchesEqual(const ExecBatch& expected, const ExecBatch& actual) { + ASSERT_THAT(actual.values, testing::ElementsAreArray(expected.values)); +} + namespace { // TODO expose this as `static ValueDescr::FromSchemaColumns`? @@ -124,147 +130,6 @@ struct DummyNode : ExecNode { bool started_ = false; }; -struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { - RecordBatchCollectNodeImpl(ExecPlan* plan, std::string label, ExecNode* input, - std::shared_ptr schema) - : RecordBatchCollectNode(plan, std::move(label), {input}, {"collected"}, {}, 0), - schema_(std::move(schema)) {} - - RecordBatchGenerator generator() override { return generator_; } - - const char* kind_name() override { return "RecordBatchReader"; } - - Status StartProducing() override { - num_received_ = 0; - num_emitted_ = 0; - emit_stop_ = -1; - stopped_ = false; - producer_.emplace(generator_.producer()); - return Status::OK(); - } - - // sink nodes have no outputs from which to feel backpressure - void ResumeProducing(ExecNode* output) override { - FAIL() << "no outputs; this should never be called"; - } - void PauseProducing(ExecNode* output) override { - FAIL() << "no outputs; this should never be called"; - } - void StopProducing(ExecNode* output) override { - FAIL() << "no outputs; this should never be called"; - } - - void StopProducing() override { - std::unique_lock lock(mutex_); - StopProducingUnlocked(); - } - - void InputReceived(ExecNode* input, int seq_num, ExecBatch exec_batch) override { - std::unique_lock lock(mutex_); - if (stopped_) return; - - auto maybe_batch = MakeBatch(std::move(exec_batch)); - if (!maybe_batch.ok()) { - lock.unlock(); - producer_->Push(std::move(maybe_batch)); - return; - } - - // TODO would be nice to factor this out in a ReorderQueue - auto batch = *std::move(maybe_batch); - if (seq_num <= static_cast(received_batches_.size())) { - received_batches_.resize(seq_num + 1, nullptr); - } - DCHECK_EQ(received_batches_[seq_num], nullptr); - received_batches_[seq_num] = std::move(batch); - ++num_received_; - - if (seq_num != num_emitted_) { - // Cannot emit yet as there is a hole at `num_emitted_` - DCHECK_GT(seq_num, num_emitted_); - DCHECK_EQ(received_batches_[num_emitted_], nullptr); - return; - } - if (num_received_ == emit_stop_) { - StopProducingUnlocked(); - } - - // Emit batches in order as far as possible - // First collect these batches, then unlock before producing. - const auto seq_start = seq_num; - while (seq_num < static_cast(received_batches_.size()) && - received_batches_[seq_num] != nullptr) { - ++seq_num; - } - DCHECK_GT(seq_num, seq_start); - // By moving the values now, we make sure another thread won't emit the same values - // below - RecordBatchVector to_emit( - std::make_move_iterator(received_batches_.begin() + seq_start), - std::make_move_iterator(received_batches_.begin() + seq_num)); - - lock.unlock(); - for (auto&& batch : to_emit) { - producer_->Push(std::move(batch)); - } - lock.lock(); - - DCHECK_EQ(seq_start, num_emitted_); // num_emitted_ wasn't bumped in the meantime - num_emitted_ = seq_num; - } - - void ErrorReceived(ExecNode* input, Status error) override { - // XXX do we care about properly sequencing the error? 
- producer_->Push(std::move(error)); - std::unique_lock lock(mutex_); - StopProducingUnlocked(); - } - - void InputFinished(ExecNode* input, int seq_stop) override { - std::unique_lock lock(mutex_); - DCHECK_GE(seq_stop, static_cast(received_batches_.size())); - received_batches_.reserve(seq_stop); - emit_stop_ = seq_stop; - if (emit_stop_ == num_received_) { - DCHECK_EQ(emit_stop_, num_emitted_); - StopProducingUnlocked(); - } - } - - private: - void StopProducingUnlocked() { - if (!stopped_) { - stopped_ = true; - producer_->Close(); - inputs_[0]->StopProducing(this); - } - } - - Result> MakeBatch(ExecBatch&& exec_batch) { - ArrayDataVector columns; - columns.reserve(exec_batch.values.size()); - for (auto&& value : exec_batch.values) { - if (!value.is_array()) { - return Status::TypeError("Expected array input"); - } - columns.push_back(std::move(value).array()); - } - return RecordBatch::Make(schema_, exec_batch.length, std::move(columns)); - } - - const std::shared_ptr schema_; - - std::mutex mutex_; - RecordBatchVector received_batches_; - int num_received_; - int num_emitted_; - int emit_stop_; - bool stopped_; - - PushGenerator> generator_; - util::optional>::Producer> producer_; -}; - AsyncGenerator> Wrap(RecordBatchGenerator gen, ::arrow::internal::Executor* io_executor) { return MakeMappedGenerator( @@ -302,12 +167,5 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector schema) { - return plan->EmplaceNode(plan, std::move(label), input, - std::move(schema)); -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 40f2b572be0..543df257353 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -26,6 +26,7 @@ #include "arrow/record_batch.h" #include "arrow/testing/visibility.h" #include "arrow/util/async_generator.h" +#include "arrow/util/string_view.h" #include "arrow/util/type_fwd.h" namespace arrow { @@ -54,18 +55,8 @@ ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, RecordBatchGenerator generator, ::arrow::internal::Executor* io_executor); -class RecordBatchCollectNode : public ExecNode { - public: - virtual RecordBatchGenerator generator() = 0; - - protected: - using ExecNode::ExecNode; -}; - -ARROW_TESTING_EXPORT -RecordBatchCollectNode* MakeRecordBatchCollectNode(ExecPlan* plan, std::string label, - ExecNode* input, - std::shared_ptr schema); +ARROW_TESTING_EXPORT void AssertBatchesEqual(const ExecBatch& expected, + const ExecBatch& actual); } // namespace compute } // namespace arrow From d996e050f98ee389a510cdb342a9d9ba863ccd71 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 26 May 2021 12:32:31 -0400 Subject: [PATCH 04/28] add ProjectNode --- cpp/src/arrow/compute/exec/exec_plan.cc | 69 +++++++++++++++++++++++++ cpp/src/arrow/compute/exec/exec_plan.h | 7 +++ cpp/src/arrow/compute/exec/plan_test.cc | 44 ++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index e8107ffc78c..317884d5490 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -368,6 +368,75 @@ ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter) std::move(filter)); } +struct ProjectNode : ExecNode { + ProjectNode(ExecNode* input, std::string label, std::vector exprs) + : ExecNode(input->plan(), std::move(label), {input}, {"target"}, + 
/*output_descr=*/{input->output_descr()},
+                 /*num_outputs=*/1),
+        exprs_(std::move(exprs)) {}
+
+  const char* kind_name() override { return "ProjectNode"; }
+
+  Result<ExecBatch> DoProject(const ExecBatch& target) {
+    // XXX get a non-default exec context
+    std::vector<Datum> values{exprs_.size()};
+    for (size_t i = 0; i < exprs_.size(); ++i) {
+      ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(exprs_[i], target));
+    }
+    return ExecBatch::Make(std::move(values));
+  }
+
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    auto maybe_filtered = DoProject(std::move(batch));
+    if (!maybe_filtered.ok()) {
+      outputs_[0]->ErrorReceived(this, maybe_filtered.status());
+      inputs_[0]->StopProducing(this);
+      return;
+    }
+
+    outputs_[0]->InputReceived(this, seq, maybe_filtered.MoveValueUnsafe());
+  }
+
+  void ErrorReceived(ExecNode* input, Status error) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->ErrorReceived(this, std::move(error));
+    inputs_[0]->StopProducing(this);
+  }
+
+  void InputFinished(ExecNode* input, int seq) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->InputFinished(this, seq);
+    inputs_[0]->StopProducing(this);
+  }
+
+  Status StartProducing() override {
+    // XXX validate inputs_[0]->output_descr() against exprs_
+    return Status::OK();
+  }
+
+  void PauseProducing(ExecNode* output) override {}
+
+  void ResumeProducing(ExecNode* output) override {}
+
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    inputs_[0]->StopProducing(this);
+  }
+
+  void StopProducing() override { StopProducing(outputs_[0]); }
+
+ private:
+  std::vector<Expression> exprs_;
+};
+
+ExecNode* MakeProjectNode(ExecNode* input, std::string label,
+                          std::vector<Expression> exprs) {
+  return input->plan()->EmplaceNode<ProjectNode>(input, std::move(label),
+                                                 std::move(exprs));
+}
+
 struct SinkNode : ExecNode {
   SinkNode(ExecNode* input, std::string label,
            AsyncGenerator<util::optional<ExecBatch>>* generator)
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index eabc34d6d04..c45c7b424b8 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -244,5 +244,12 @@ AsyncGenerator<util::optional<ExecBatch>> MakeSinkNode(ExecNode* input,
 ARROW_EXPORT
 ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter);
 
+/// \brief Make a node which executes expressions on input batches, producing new batches.
+/// +/// Expressions must be bound; no field references will be looked up by name +ARROW_EXPORT +ExecNode* MakeProjectNode(ExecNode* input, std::string label, + std::vector exprs); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 4421ed72969..8668e281200 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -377,5 +377,49 @@ TEST_F(TestExecPlanExecution, SourceFilterSink) { got_batches); } +TEST_F(TestExecPlanExecution, SourceProjectSink) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + const auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); + RecordBatchVector batches{ + RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, + {"a": 4, "b": false}])"), + RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, + {"a": 6, "b": false}, + {"a": 7, "b": false}])"), + }; + + ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(std::move(batches), schema)); + + auto source = + MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); + + std::vector exprs{ + not_(field_ref("b")), + call("add", {field_ref("a"), literal(1)}), + }; + for (auto& expr : exprs) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*schema)); + } + + auto projection = MakeProjectNode(source, "project", exprs); + + auto sink_gen = MakeSinkNode(projection, "sink"); + + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); + + auto out_schema = ::arrow::schema({field("!b", boolean()), field("a + 1", int32())}); + ASSERT_EQ(got_batches.size(), 2); + AssertBatchesEqual( + { + RecordBatchFromJSON(out_schema, R"([{"!b": false, "a + 1": null}, + {"!b": true, "a + 1": 5}])"), + RecordBatchFromJSON(out_schema, R"([{"!b": null, "a + 1": 6}, + {"!b": true, "a + 1": 7}, + {"!b": true, "a + 1": 8}])"), + }, + got_batches); +} + } // namespace compute } // namespace arrow From f2d4626808a69ee671739b12dd562e00a4d78ef2 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 26 May 2021 13:31:35 -0400 Subject: [PATCH 05/28] add sketch of ScanNode --- cpp/src/arrow/compute/exec/exec_plan.h | 5 ++--- cpp/src/arrow/dataset/scanner.cc | 26 ++++++++++++++++++++++++++ cpp/src/arrow/dataset/scanner.h | 6 ++++++ 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index c45c7b424b8..ff55c631647 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -212,7 +212,7 @@ class ARROW_EXPORT ExecNode { virtual void StopProducing() = 0; protected: - ExecNode(ExecPlan* plan, std::string label, NodeVector inputs, + ExecNode(ExecPlan*, std::string label, NodeVector inputs, std::vector input_labels, BatchDescr output_descr, int num_outputs); @@ -229,8 +229,7 @@ class ARROW_EXPORT ExecNode { /// \brief Adapt an AsyncGenerator as a source node ARROW_EXPORT -ExecNode* MakeSourceNode(ExecPlan* plan, std::string label, - ExecNode::BatchDescr output_descr, +ExecNode* MakeSourceNode(ExecPlan*, std::string label, ExecNode::BatchDescr output_descr, AsyncGenerator>); /// \brief Add a sink node which forwards to an AsyncGenerator diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index eda409ea4ea..d86128ac1df 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -761,6 +761,32 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr schema, 
DCHECK_OK(Filter(scan_options_->filter)); } +Result ScannerBuilder::MakeScanNode(compute::ExecPlan* plan) { + ARROW_ASSIGN_OR_RAISE(auto scanner, Finish()); + + ARROW_ASSIGN_OR_RAISE(auto unordered_gen, scanner->ScanBatchesUnorderedAsync()); + + auto schema = scanner->options()->projected_schema; + + auto gen = MakeMappedGenerator( + std::move(unordered_gen), + [schema](const EnumeratedRecordBatch& partial) + -> Result> { + // FIXME the batches are still being fully filtered/projected. Need to add an + // option to skip wrapping with FilterAndProjectScanTask + ARROW_ASSIGN_OR_RAISE( + util::optional batch, + compute::MakeExecBatch(*schema, partial.record_batch.value)); + return batch; + }); + + std::vector output_descr; + for (const auto& field : schema->fields()) { + output_descr.push_back(ValueDescr::Array(field->type())); + } + return MakeSourceNode(plan, "dataset_scan", output_descr, std::move(gen)); +} + namespace { class OneShotScanTask : public ScanTask { public: diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 3ee18c5c049..cc8b7be3ee3 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/dataset/dataset.h" #include "arrow/dataset/projector.h" @@ -398,6 +399,11 @@ class ARROW_DS_EXPORT ScannerBuilder { /// \brief Return the constructed now-immutable Scanner object Result> Finish(); + /// \brief Construct a source ExecNode which yields batches from a dataset scan. + /// + /// Does not construct associated filter or project nodes + Result MakeScanNode(compute::ExecPlan*); + const std::shared_ptr& schema() const; const std::shared_ptr& projected_schema() const; From 0351ccde1592ea0ffea1b522e87ca0061cc99c6e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 27 May 2021 11:53:57 -0400 Subject: [PATCH 06/28] flesh out ScanNode, tag ExecBatches with guarantees --- cpp/src/arrow/compute/exec.h | 4 + cpp/src/arrow/compute/exec/exec_plan.cc | 23 ++- cpp/src/arrow/dataset/dataset_internal.h | 30 ++++ cpp/src/arrow/dataset/scanner.cc | 179 +++++++++++---------- cpp/src/arrow/dataset/scanner.h | 1 - cpp/src/arrow/dataset/scanner_test.cc | 195 +++++++++++++++++++++++ cpp/src/arrow/dataset/test_util.h | 19 --- cpp/src/arrow/result.h | 5 + 8 files changed, 345 insertions(+), 111 deletions(-) diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index d188d2d246d..49484383e0f 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -28,6 +28,7 @@ #include #include "arrow/array/data.h" +#include "arrow/compute/exec/expression.h" #include "arrow/datum.h" #include "arrow/memory_pool.h" #include "arrow/result.h" @@ -186,6 +187,9 @@ struct ARROW_EXPORT ExecBatch { /// ExecBatch::length is equal to the length of this array. std::shared_ptr selection_vector; + /// A predicate Expression guaranteed to evaluate to true for all rows in this batch. + Expression guarantee = literal(true); + /// The semantic length of the ExecBatch. When the values are all scalars, /// the length should be set to 1, otherwise the length is taken from the /// array values, except when there is a selection vector. 
When there is a diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 317884d5490..085eccb6161 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -227,6 +227,9 @@ struct SourceNode : ExecNode { Loop([gen, this] { std::unique_lock lock(mutex_); int seq = next_batch_index_++; + if (finished_) { + return Future>::MakeFinished(Break(seq)); + } lock.unlock(); return gen().Then( @@ -298,8 +301,11 @@ struct FilterNode : ExecNode { const char* kind_name() override { return "FilterNode"; } Result DoFilter(const ExecBatch& target) { + ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, + SimplifyWithGuarantee(filter_, target.guarantee)); + // XXX get a non-default exec context - ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(filter_, target)); + ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target)); if (mask.is_scalar()) { const auto& mask_scalar = mask.scalar_as(); @@ -328,6 +334,7 @@ struct FilterNode : ExecNode { return; } + maybe_filtered->guarantee = batch.guarantee; outputs_[0]->InputReceived(this, seq, maybe_filtered.MoveValueUnsafe()); } @@ -381,7 +388,10 @@ struct ProjectNode : ExecNode { // XXX get a non-default exec context std::vector values{exprs_.size()}; for (size_t i = 0; i < exprs_.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(exprs_[i], target)); + ARROW_ASSIGN_OR_RAISE(Expression simplified_expr, + SimplifyWithGuarantee(exprs_[i], target.guarantee)); + + ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target)); } return ExecBatch::Make(std::move(values)); } @@ -389,14 +399,15 @@ struct ProjectNode : ExecNode { void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { DCHECK_EQ(input, inputs_[0]); - auto maybe_filtered = DoProject(std::move(batch)); - if (!maybe_filtered.ok()) { - outputs_[0]->ErrorReceived(this, maybe_filtered.status()); + auto maybe_projected = DoProject(std::move(batch)); + if (!maybe_projected.ok()) { + outputs_[0]->ErrorReceived(this, maybe_projected.status()); inputs_[0]->StopProducing(this); return; } - outputs_[0]->InputReceived(this, seq, maybe_filtered.MoveValueUnsafe()); + maybe_projected->guarantee = batch.guarantee; + outputs_[0]->InputReceived(this, seq, maybe_projected.MoveValueUnsafe()); } void ErrorReceived(ExecNode* input, Status error) override { diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 4336f9c157e..952ad3e83ca 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -204,5 +204,35 @@ arrow::Result> GetFragmentScanOptions( return internal::checked_pointer_cast(source); } +class FragmentDataset : public Dataset { + public: + FragmentDataset(std::shared_ptr schema, FragmentVector fragments) + : Dataset(std::move(schema)), fragments_(std::move(fragments)) {} + + std::string type_name() const override { return "fragment"; } + + Result> ReplaceSchema( + std::shared_ptr schema) const override { + return std::make_shared(std::move(schema), fragments_); + } + + protected: + Result GetFragmentsImpl(compute::Expression predicate) override { + // TODO(ARROW-12891) Provide subtree pruning for any vector of fragments + FragmentVector fragments; + for (const auto& fragment : fragments_) { + ARROW_ASSIGN_OR_RAISE( + auto simplified_filter, + compute::SimplifyWithGuarantee(predicate, fragment->partition_expression())); + + if 
(simplified_filter.IsSatisfiable()) { + fragments.push_back(fragment); + } + } + return MakeVectorIterator(std::move(fragments)); + } + FragmentVector fragments_; +}; + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index d86128ac1df..b39eefb29df 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -317,10 +317,6 @@ class ARROW_DS_EXPORT SyncScanner : public Scanner { SyncScanner(std::shared_ptr dataset, std::shared_ptr scan_options) : Scanner(std::move(scan_options)), dataset_(std::move(dataset)) {} - SyncScanner(std::shared_ptr fragment, - std::shared_ptr scan_options) - : Scanner(std::move(scan_options)), fragment_(std::move(fragment)) {} - Result ScanBatches() override; Result Scan() override; Status Scan(std::function visitor) override; @@ -337,8 +333,6 @@ class ARROW_DS_EXPORT SyncScanner : public Scanner { Result ScanInternal(); std::shared_ptr dataset_; - // TODO(ARROW-8065) remove fragment_ after a Dataset is constuctible from fragments - std::shared_ptr fragment_; }; Result SyncScanner::ScanBatches() { @@ -370,10 +364,6 @@ Result SyncScanner::ScanBatchesUnorderedAsync() } Result SyncScanner::GetFragments() { - if (fragment_ != nullptr) { - return MakeVectorIterator(FragmentVector{fragment_}); - } - // Transform Datasets in a flat Iterator. This // iterator is lazily constructed, i.e. Dataset::GetFragments is // not invoked until a Fragment is requested. @@ -442,14 +432,14 @@ class ARROW_DS_EXPORT AsyncScanner : public Scanner, namespace { inline Result DoFilterAndProjectRecordBatchAsync( - const std::shared_ptr& scanner, const EnumeratedRecordBatch& in) { - ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_filter, - SimplifyWithGuarantee(scanner->options()->filter, - in.fragment.value->partition_expression())); + const std::shared_ptr& options, const EnumeratedRecordBatch& in) { + ARROW_ASSIGN_OR_RAISE( + compute::Expression simplified_filter, + SimplifyWithGuarantee(options->filter, in.fragment.value->partition_expression())); - const auto& schema = *scanner->options()->dataset_schema; + const auto& schema = *options->dataset_schema; - compute::ExecContext exec_context{scanner->options()->pool}; + compute::ExecContext exec_context{options->pool}; ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, schema, in.record_batch.value, &exec_context)); @@ -471,7 +461,7 @@ inline Result DoFilterAndProjectRecordBatchAsync( } ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection, - SimplifyWithGuarantee(scanner->options()->projection, + SimplifyWithGuarantee(options->projection, in.fragment.value->partition_expression())); ARROW_ASSIGN_OR_RAISE( @@ -484,7 +474,7 @@ inline Result DoFilterAndProjectRecordBatchAsync( ARROW_ASSIGN_OR_RAISE( projected, MakeArrayFromScalar(*projected.scalar(), filtered.record_batch()->num_rows(), - scanner->options()->pool)); + options->pool)); } ARROW_ASSIGN_OR_RAISE(auto out, RecordBatch::FromStructArray(projected.array_as())); @@ -497,17 +487,16 @@ inline Result DoFilterAndProjectRecordBatchAsync( } inline EnumeratedRecordBatchGenerator FilterAndProjectRecordBatchAsync( - const std::shared_ptr& scanner, EnumeratedRecordBatchGenerator rbs) { - auto mapper = [scanner](const EnumeratedRecordBatch& in) { - return DoFilterAndProjectRecordBatchAsync(scanner, in); + const std::shared_ptr& options, EnumeratedRecordBatchGenerator rbs) { + auto mapper = [options](const EnumeratedRecordBatch& in) { + return 
DoFilterAndProjectRecordBatchAsync(options, in); }; return MakeMappedGenerator(std::move(rbs), mapper); } Result FragmentToBatches( - std::shared_ptr scanner, const Enumerated>& fragment, - const std::shared_ptr& options) { + const std::shared_ptr& options, bool filter_and_project = true) { ARROW_ASSIGN_OR_RAISE(auto batch_gen, fragment.value->ScanBatchesAsync(options)); auto enumerated_batch_gen = MakeEnumeratedGenerator(std::move(batch_gen)); @@ -518,27 +507,35 @@ Result FragmentToBatches( auto combined_gen = MakeMappedGenerator(enumerated_batch_gen, std::move(combine_fn)); - return FilterAndProjectRecordBatchAsync(scanner, std::move(combined_gen)); + if (filter_and_project) { + return FilterAndProjectRecordBatchAsync(options, std::move(combined_gen)); + } + return combined_gen; } Result> FragmentsToBatches( - std::shared_ptr scanner, FragmentGenerator fragment_gen) { + FragmentGenerator fragment_gen, const std::shared_ptr& options, + bool filter_and_project = true) { auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen)); - return MakeMappedGenerator( - std::move(enumerated_fragment_gen), - [scanner](const Enumerated>& fragment) { - return FragmentToBatches(scanner, fragment, scanner->options()); - }); + return MakeMappedGenerator(std::move(enumerated_fragment_gen), + [=](const Enumerated>& fragment) { + return FragmentToBatches(fragment, options, + filter_and_project); + }); } Result>>> FragmentsToRowCount( - std::shared_ptr scanner, FragmentGenerator fragment_gen) { + FragmentGenerator fragment_gen, + std::shared_ptr options_with_projection) { // Must use optional to avoid breaking the pipeline on empty batches auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen)); - auto options = std::make_shared(*scanner->options()); + + // Drop projection since we only need to count rows + auto options = std::make_shared(*options_with_projection); RETURN_NOT_OK(SetProjection(options.get(), std::vector())); + auto count_fragment_fn = - [scanner, options](const Enumerated>& fragment) + [options](const Enumerated>& fragment) -> Result>> { auto count_fut = fragment.value->CountRows(options->filter, options); return MakeFromFuture( @@ -550,8 +547,7 @@ Result>>> FragmentsToRowCo Future>::MakeFinished(val)); } // Slow path - ARROW_ASSIGN_OR_RAISE(auto batch_gen, - FragmentToBatches(scanner, fragment, options)); + ARROW_ASSIGN_OR_RAISE(auto batch_gen, FragmentToBatches(fragment, options)); auto count_fn = [](const EnumeratedRecordBatch& enumerated) -> util::optional { return enumerated.record_batch.value->num_rows(); @@ -563,6 +559,19 @@ Result>>> FragmentsToRowCo std::move(count_fragment_fn)); } +Result ScanBatchesUnorderedAsyncImpl( + const std::shared_ptr& options, FragmentGenerator fragment_gen, + internal::Executor* cpu_executor, bool filter_and_project = true) { + ARROW_ASSIGN_OR_RAISE( + auto batch_gen_gen, + FragmentsToBatches(std::move(fragment_gen), options, filter_and_project)); + auto batch_gen_gen_readahead = + MakeSerialReadaheadGenerator(std::move(batch_gen_gen), options->fragment_readahead); + auto merged_batch_gen = MakeMergedGenerator(std::move(batch_gen_gen_readahead), + options->fragment_readahead); + return MakeReadaheadGenerator(std::move(merged_batch_gen), options->fragment_readahead); +} + } // namespace Result AsyncScanner::GetFragments() const { @@ -596,16 +605,9 @@ Result AsyncScanner::ScanBatchesUnorderedAsync() Result AsyncScanner::ScanBatchesUnorderedAsync( internal::Executor* cpu_executor) { - auto self = shared_from_this(); 
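+  // NOTE: the readahead/merge pipeline formerly built here now lives in
+  // ScanBatchesUnorderedAsyncImpl (above) so that MakeScanNode can reuse it
+  // without holding a Scanner.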
ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments()); - ARROW_ASSIGN_OR_RAISE(auto batch_gen_gen, - FragmentsToBatches(self, std::move(fragment_gen))); - auto batch_gen_gen_readahead = MakeSerialReadaheadGenerator( - std::move(batch_gen_gen), scan_options_->fragment_readahead); - auto merged_batch_gen = MakeMergedGenerator(std::move(batch_gen_gen_readahead), - scan_options_->fragment_readahead); - return MakeReadaheadGenerator(std::move(merged_batch_gen), - scan_options_->fragment_readahead); + return ScanBatchesUnorderedAsyncImpl(scan_options_, std::move(fragment_gen), + cpu_executor); } Result AsyncScanner::ScanBatchesAsync() { @@ -723,10 +725,9 @@ Future> AsyncScanner::ToTableAsync( } Result AsyncScanner::CountRows() { - auto self = shared_from_this(); ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments()); ARROW_ASSIGN_OR_RAISE(auto count_gen_gen, - FragmentsToRowCount(self, std::move(fragment_gen))); + FragmentsToRowCount(std::move(fragment_gen), scan_options_)); auto count_gen = MakeConcatenatedGenerator(std::move(count_gen_gen)); int64_t total = 0; auto sum_fn = [&total](util::optional count) -> Status { @@ -744,9 +745,7 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr dataset) ScannerBuilder::ScannerBuilder(std::shared_ptr dataset, std::shared_ptr scan_options) - : dataset_(std::move(dataset)), - fragment_(nullptr), - scan_options_(std::move(scan_options)) { + : dataset_(std::move(dataset)), scan_options_(std::move(scan_options)) { scan_options_->dataset_schema = dataset_->schema(); DCHECK_OK(Filter(scan_options_->filter)); } @@ -754,38 +753,9 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr dataset, ScannerBuilder::ScannerBuilder(std::shared_ptr schema, std::shared_ptr fragment, std::shared_ptr scan_options) - : dataset_(nullptr), - fragment_(std::move(fragment)), - scan_options_(std::move(scan_options)) { - scan_options_->dataset_schema = std::move(schema); - DCHECK_OK(Filter(scan_options_->filter)); -} - -Result ScannerBuilder::MakeScanNode(compute::ExecPlan* plan) { - ARROW_ASSIGN_OR_RAISE(auto scanner, Finish()); - - ARROW_ASSIGN_OR_RAISE(auto unordered_gen, scanner->ScanBatchesUnorderedAsync()); - - auto schema = scanner->options()->projected_schema; - - auto gen = MakeMappedGenerator( - std::move(unordered_gen), - [schema](const EnumeratedRecordBatch& partial) - -> Result> { - // FIXME the batches are still being fully filtered/projected. Need to add an - // option to skip wrapping with FilterAndProjectScanTask - ARROW_ASSIGN_OR_RAISE( - util::optional batch, - compute::MakeExecBatch(*schema, partial.record_batch.value)); - return batch; - }); - - std::vector output_descr; - for (const auto& field : schema->fields()) { - output_descr.push_back(ValueDescr::Array(field->type())); - } - return MakeSourceNode(plan, "dataset_scan", output_descr, std::move(gen)); -} + : ScannerBuilder(std::make_shared( + std::move(schema), FragmentVector{std::move(fragment)}), + std::move(scan_options)) {} namespace { class OneShotScanTask : public ScanTask { @@ -913,10 +883,6 @@ Result> ScannerBuilder::Finish() { RETURN_NOT_OK(Project(scan_options_->dataset_schema->field_names())); } - if (dataset_ == nullptr) { - // AsyncScanner does not support this method of running. 
It may in the future - return std::make_shared(fragment_, scan_options_); - } if (scan_options_->use_async) { return std::make_shared(dataset_, scan_options_); } else { @@ -1134,5 +1100,48 @@ Result SyncScanner::CountRows() { return count; } +Result ScannerBuilder::MakeScanNode(compute::ExecPlan* plan) { + if (!scan_options_->use_async) { + return Status::NotImplemented("ScanNodes without asynchrony"); + } + + // using a generator for speculative forward compatibility with async fragment discovery + ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter)); + ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector()); + auto fragments_gen = MakeVectorGenerator(std::move(fragments_vec)); + + ARROW_ASSIGN_OR_RAISE(auto batch_gen, + ScanBatchesUnorderedAsyncImpl( + scan_options_, std::move(fragments_gen), + internal::GetCpuThreadPool(), /*filter_and_project=*/false)); + + const auto& schema = dataset_->schema(); + + auto gen = MakeMappedGenerator( + std::move(batch_gen), + [schema](const EnumeratedRecordBatch& partial) + -> Result> { + ARROW_ASSIGN_OR_RAISE( + util::optional batch, + compute::MakeExecBatch(*schema, partial.record_batch.value)); + + // TODO fragments may be able to attach more guarantees to batches than this, + // for example parquet's row group stats. + batch->guarantee = partial.fragment.value->partition_expression(); + + // tag rows with fragment- and batch-of-origin + batch->values.emplace_back(partial.fragment.index); + batch->values.emplace_back(partial.record_batch.index); + return batch; + }); + + std::vector output_descr; + for (const auto& field : schema->fields()) { + output_descr.push_back(ValueDescr::Array(field->type())); + } + + return MakeSourceNode(plan, "dataset_scan", std::move(output_descr), std::move(gen)); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index cc8b7be3ee3..d899cd82df0 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -409,7 +409,6 @@ class ARROW_DS_EXPORT ScannerBuilder { private: std::shared_ptr dataset_; - std::shared_ptr fragment_; std::shared_ptr scan_options_ = std::make_shared(); }; diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 87fc2c902c3..2400d01092f 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1087,5 +1087,200 @@ TEST(ScanOptions, TestMaterializedFields) { EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i64", "i32")); } +TEST(ScanNode, Trivial) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + const auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); + RecordBatchVector batches{ + RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, + {"a": 4, "b": false}])"), + RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, + {"a": 6, "b": false}, + {"a": 7, "b": false}])"), + }; + + auto dataset = std::make_shared(schema, batches); + + ScannerBuilder scanner_builder(dataset); + ASSERT_OK(scanner_builder.UseAsync(true)); + ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); + auto sink_gen = MakeSinkNode(scan, "sink"); + ASSERT_OK(plan->Validate()); + ASSERT_OK(plan->StartProducing()); + + auto got_batches_fut = CollectAsyncGenerator(sink_gen); + ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + + ASSERT_EQ(got_batches.size(), batches.size()); + for (size_t i = 0; i < batches.size(); ++i) { + 
SCOPED_TRACE("Batch " + std::to_string(i)); + const compute::ExecBatch& actual = *got_batches[i]; + const RecordBatch& expected = *batches[i]; + AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); + AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); + // InMemoryDataset(RecordBatchVector) produces a fragment wrapping each batch + AssertDatumsEqual(Datum(int(i)), actual[2], /*verbose=*/true); + AssertDatumsEqual(Datum(0), actual[3], /*verbose=*/true); + } +} + +TEST(ScanNode, FilteredOnVirtualColumn) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + const auto dataset_schema = ::arrow::schema({ + field("a", int32()), + field("b", boolean()), + field("c", int32()), + }); + const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"}); + RecordBatchVector batches{ + RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true}, + {"a": 4, "b": false}])"), + RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null}, + {"a": 6, "b": false}, + {"a": 7, "b": false}])"), + }; + + auto dataset = std::make_shared( + dataset_schema, + FragmentVector{ + std::make_shared(physical_schema, batches, + equal(field_ref("c"), literal(23))), + std::make_shared(physical_schema, batches, + equal(field_ref("c"), literal(47))), + }); + + ScannerBuilder scanner_builder(dataset); + ASSERT_OK(scanner_builder.UseAsync(true)); + ASSERT_OK(scanner_builder.Filter(greater(field_ref("c"), literal(30)))); + ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); + auto sink_gen = MakeSinkNode(scan, "sink"); + ASSERT_OK(plan->Validate()); + ASSERT_OK(plan->StartProducing()); + + auto got_batches_fut = CollectAsyncGenerator(sink_gen); + ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + + ASSERT_EQ(got_batches.size(), 2); + for (size_t i = 0; i < batches.size(); ++i) { + const compute::ExecBatch& actual = *got_batches[i]; + const RecordBatch& expected = *batches[i]; + AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); + AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); + + // Note: placeholder for partition field "c" + AssertDatumsEqual(Datum(std::make_shared()), actual[2], + /*verbose=*/true); + + // Only one fragment in this scan, its index is 0 + AssertDatumsEqual(Datum(0), actual[3], /*verbose=*/true); + AssertDatumsEqual(Datum(int(i)), actual[4], /*verbose=*/true); + + EXPECT_EQ(actual.guarantee, equal(field_ref("c"), literal(47))); + } +} + +TEST(ScanNode, FilteredOnPhysicalColumn) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + const auto dataset_schema = ::arrow::schema({ + field("a", int32()), + field("b", boolean()), + field("c", int32()), + }); + const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"}); + RecordBatchVector batches{ + RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true}, + {"a": 4, "b": false}])"), + RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null}, + {"a": 6, "b": false}, + {"a": 7, "b": false}])"), + }; + + auto dataset = std::make_shared( + dataset_schema, + FragmentVector{ + std::make_shared(physical_schema, batches, + equal(field_ref("c"), literal(23))), + std::make_shared(physical_schema, batches, + equal(field_ref("c"), literal(47))), + }); + + ScannerBuilder scanner_builder(dataset); + ASSERT_OK(scanner_builder.UseAsync(true)); + ASSERT_OK(scanner_builder.Filter(greater(field_ref("a"), literal(4)))); + ASSERT_OK_AND_ASSIGN(auto scan, 
scanner_builder.MakeScanNode(plan.get())); + auto sink_gen = MakeSinkNode(scan, "sink"); + ASSERT_OK(plan->Validate()); + ASSERT_OK(plan->StartProducing()); + + auto got_batches_fut = CollectAsyncGenerator(sink_gen); + ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + + // no filtering is performed by ScanNode: all batches will be yielded whole + ASSERT_EQ(got_batches.size(), batches.size() * 2); + for (size_t i = 0; i < got_batches.size(); ++i) { + const compute::ExecBatch& actual = *got_batches[i]; + const RecordBatch& expected = *batches[i % 2]; + AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); + AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); + AssertDatumsEqual(Datum(std::make_shared()), actual[2], + /*verbose=*/true); + + AssertDatumsEqual(Datum(int(i / 2)), actual[3], /*verbose=*/true); + AssertDatumsEqual(Datum(int(i % 2)), actual[4], /*verbose=*/true); + + EXPECT_EQ(actual.guarantee, equal(field_ref("c"), literal(i / 2 ? 47 : 23))) << i; + } +} + +TEST(ScanNode, ProjectPhysicalColumn) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + const auto dataset_schema = ::arrow::schema({ + field("a", int32()), + field("b", boolean()), + field("c", int32()), + }); + const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"}); + RecordBatchVector batches{ + RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true}, + {"a": 4, "b": false}])"), + RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null}, + {"a": 6, "b": false}, + {"a": 7, "b": false}])"), + }; + + auto dataset = std::make_shared( + dataset_schema, + FragmentVector{ + std::make_shared(physical_schema, batches, + equal(field_ref("c"), literal(23))), + std::make_shared(physical_schema, batches, + equal(field_ref("c"), literal(47))), + }); + + ScannerBuilder scanner_builder(dataset); + ASSERT_OK(scanner_builder.UseAsync(true)); + ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); + auto project = compute::MakeProjectNode( + scan, "project", {field_ref("c").Bind(*dataset_schema).ValueOrDie()}); + auto sink_gen = MakeSinkNode(project, "sink"); + ASSERT_OK(plan->Validate()); + ASSERT_OK(plan->StartProducing()); + + auto got_batches_fut = CollectAsyncGenerator(sink_gen); + ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + + // no filtering is performed by ScanNode: all batches will be yielded whole + ASSERT_EQ(got_batches.size(), batches.size() * 2); + for (size_t i = 0; i < got_batches.size(); ++i) { + const compute::ExecBatch& actual = *got_batches[i]; + Datum expected(i / 2 ? 47 : 23); + AssertDatumsEqual(expected, actual[0], /*verbose=*/true); + EXPECT_EQ(actual.guarantee, equal(field_ref("c"), literal(expected))) << i; + } +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index e1d85d23342..aab97b9bb49 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -122,25 +122,6 @@ void EnsureRecordBatchReaderDrained(RecordBatchReader* reader) { EXPECT_EQ(batch, nullptr); } -/// Test dataset that returns one or more fragments. 
-class FragmentDataset : public Dataset { - public: - FragmentDataset(std::shared_ptr schema, FragmentVector fragments) - : Dataset(std::move(schema)), fragments_(std::move(fragments)) {} - - std::string type_name() const override { return "fragment"; } - - Result> ReplaceSchema(std::shared_ptr) const override { - return Status::NotImplemented(""); - } - - protected: - Result GetFragmentsImpl(compute::Expression predicate) override { - return MakeVectorIterator(fragments_); - } - FragmentVector fragments_; -}; - class DatasetFixtureMixin : public ::testing::Test { public: /// \brief Ensure that record batches found in reader are equals to the diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h index 6b68d70ef5c..ee79077335e 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -478,6 +478,11 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { /// /// WARNING: ARROW_ASSIGN_OR_RAISE `std::move`s its right operand. If you have /// an lvalue Result which you *don't* want to move out of cast appropriately. +/// +/// WARNING: ARROW_ASSIGN_OR_RAISE is not a single expression; it will not +/// maintain lifetimes of all temporaries in `rexpr` (e.g. +/// `ARROW_ASSIGN_OR_RAISE(auto x, MakeTemp().GetResultRef());` +/// will most likely segfault)! #define ARROW_ASSIGN_OR_RAISE(lhs, rexpr) \ ARROW_ASSIGN_OR_RAISE_IMPL(ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \ lhs, rexpr); From e9468cdf83eec39f9fca9e6e742380f0fa32d668 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 27 May 2021 15:57:29 -0400 Subject: [PATCH 07/28] add fast path for FieldRef.Name lookup in Schema --- cpp/src/arrow/compute/exec/expression.cc | 35 ++++++++++++++---------- cpp/src/arrow/compute/exec/expression.h | 5 +--- cpp/src/arrow/type.cc | 4 +++ 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 19a311a5d39..0e99c614d3f 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -378,43 +378,50 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ return Expression(std::move(call)); } -} // namespace - -Result Expression::Bind(ValueDescr in, - compute::ExecContext* exec_context) const { +template +Result BindImpl(Expression expr, const TypeOrSchema& in, + ValueDescr::Shape shape, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; - return Bind(std::move(in), &exec_context); + return BindImpl(std::move(expr), in, shape, &exec_context); } - if (literal()) return *this; + if (expr.literal()) return expr; - if (auto ref = field_ref()) { + if (auto ref = expr.field_ref()) { if (ref->IsNested()) { return Status::NotImplemented("nested field references"); } - ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(*in.type)); + ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(in)); - auto bound = *parameter(); + auto bound = *expr.parameter(); bound.index = path[0]; - ARROW_ASSIGN_OR_RAISE(auto field, path.Get(*in.type)); + ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in)); bound.descr.type = field->type(); - bound.descr.shape = in.shape; + bound.descr.shape = shape; return Expression{std::move(bound)}; } - auto call = *CallNotNull(*this); + auto call = *CallNotNull(expr); for (auto& argument : call.arguments) { - ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); + ARROW_ASSIGN_OR_RAISE(argument, + BindImpl(std::move(argument), in, shape, exec_context)); 
} return BindNonRecursive(std::move(call), /*insert_implicit_casts=*/true, exec_context); } +} // namespace + +Result Expression::Bind(const ValueDescr& in, + compute::ExecContext* exec_context) const { + return BindImpl(*this, *in.type, in.shape, exec_context); +} + Result Expression::Bind(const Schema& in_schema, compute::ExecContext* exec_context) const { - return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); + return BindImpl(*this, in_schema, ValueDescr::ARRAY, exec_context); } Result MakeExecBatch(const Schema& full_schema, const Datum& partial) { diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index eb1dcfd3091..4b842ef38aa 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -64,12 +64,9 @@ class ARROW_EXPORT Expression { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - Result Bind(ValueDescr in, ExecContext* = NULLPTR) const; + Result Bind(const ValueDescr& in, ExecContext* = NULLPTR) const; Result Bind(const Schema& in_schema, ExecContext* = NULLPTR) const; - Result BindFlattened(ValueDescr in, ExecContext* = NULLPTR) const; - Result BindFlattened(const Schema& in_schema, ExecContext* = NULLPTR) const; - // XXX someday // Clone all KernelState in this bound expression. If any function referenced by this // expression has mutable KernelState, it is not safe to execute or apply simplification diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 65c783ce847..41914f43663 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1195,6 +1195,10 @@ std::string FieldRef::ToString() const { } std::vector FieldRef::FindAll(const Schema& schema) const { + if (auto name = this->name()) { + return internal::MapVector([](int i) { return FieldPath{i}; }, + schema.GetAllFieldIndices(*name)); + } return FindAll(schema.fields()); } From 57264ee3b8cad5c048f75e08edeb18c400fd8ca5 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 2 Jun 2021 12:58:20 -0400 Subject: [PATCH 08/28] remove seq reordering from SinkNode --- cpp/src/arrow/compute/exec/exec_plan.cc | 64 ++++++------------------- cpp/src/arrow/compute/exec/exec_plan.h | 6 ++- cpp/src/arrow/compute/exec/plan_test.cc | 24 ++++++++++ cpp/src/arrow/dataset/scanner_test.cc | 49 +++++++++---------- cpp/src/arrow/util/async_generator.h | 4 +- cpp/src/arrow/util/iterator.h | 6 --- 6 files changed, 69 insertions(+), 84 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 085eccb6161..1eef9183155 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -347,7 +347,6 @@ struct FilterNode : ExecNode { void InputFinished(ExecNode* input, int seq) override { DCHECK_EQ(input, inputs_[0]); outputs_[0]->InputFinished(this, seq); - inputs_[0]->StopProducing(this); } Status StartProducing() override { @@ -419,7 +418,6 @@ struct ProjectNode : ExecNode { void InputFinished(ExecNode* input, int seq) override { DCHECK_EQ(input, inputs_[0]); outputs_[0]->InputFinished(this, seq); - inputs_[0]->StopProducing(this); } Status StartProducing() override { @@ -450,14 +448,14 @@ ExecNode* MakeProjectNode(ExecNode* input, std::string label, struct SinkNode : ExecNode { SinkNode(ExecNode* input, std::string label, - 
AsyncGenerator>* generator) + AsyncGenerator>* generator) : ExecNode(input->plan(), std::move(label), {input}, {"collected"}, {}, /*num_outputs=*/0), producer_(MakeProducer(generator)) {} - static PushGenerator>::Producer MakeProducer( - AsyncGenerator>* out_gen) { - PushGenerator> gen; + static PushGenerator>::Producer MakeProducer( + AsyncGenerator>* out_gen) { + PushGenerator> gen; auto out = gen.producer(); *out_gen = std::move(gen); return out; @@ -478,54 +476,27 @@ struct SinkNode : ExecNode { StopProducingUnlocked(); } - void InputReceived(ExecNode* input, int seq_num, ExecBatch exec_batch) override { + void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + std::unique_lock lock(mutex_); if (stopped_) return; - // TODO would be nice to factor this out in a ReorderQueue - if (seq_num <= static_cast(received_batches_.size())) { - received_batches_.resize(seq_num + 1); - emitted_.resize(seq_num + 1, false); - } - received_batches_[seq_num] = std::move(exec_batch); ++num_received_; - - if (seq_num != num_emitted_) { - // Cannot emit yet as there is a hole at `num_emitted_` - DCHECK_GT(seq_num, num_emitted_); - return; - } - if (num_received_ == emit_stop_) { StopProducingUnlocked(); } - // Emit batches in order as far as possible - // First collect these batches, then unlock before producing. - const auto seq_start = seq_num; - while (seq_num < static_cast(emitted_.size()) && !emitted_[seq_num]) { - emitted_[seq_num] = true; - ++seq_num; + if (emit_stop_ != -1) { + DCHECK_LE(seq_num, emit_stop_); } - DCHECK_GT(seq_num, seq_start); - // By moving the values now, we make sure another thread won't emit the same values - // below - std::vector to_emit( - std::make_move_iterator(received_batches_.begin() + seq_start), - std::make_move_iterator(received_batches_.begin() + seq_num)); - lock.unlock(); - for (auto&& batch : to_emit) { - producer_.Push(std::move(batch)); - } - lock.lock(); - DCHECK_EQ(seq_start, num_emitted_); // num_emitted_ wasn't bumped in the meantime - num_emitted_ = seq_num; + producer_.Push(Enumerated{std::move(batch), seq_num, false}); } void ErrorReceived(ExecNode* input, Status error) override { - // XXX do we care about properly sequencing the error? 
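+    // The error is forwarded as soon as it arrives, without any sequencing;
+    // consumers which need ordering can still sort the collected batches by
+    // the index attached to each Enumerated<ExecBatch>.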
+ DCHECK_EQ(input, inputs_[0]); producer_.Push(std::move(error)); std::unique_lock lock(mutex_); StopProducingUnlocked(); @@ -533,11 +504,8 @@ struct SinkNode : ExecNode { void InputFinished(ExecNode* input, int seq_stop) override { std::unique_lock lock(mutex_); - DCHECK_GE(seq_stop, static_cast(received_batches_.size())); - received_batches_.reserve(seq_stop); emit_stop_ = seq_stop; if (emit_stop_ == num_received_) { - DCHECK_EQ(emit_stop_, num_emitted_); StopProducingUnlocked(); } } @@ -552,20 +520,16 @@ struct SinkNode : ExecNode { } std::mutex mutex_; - std::vector received_batches_; - std::vector emitted_; int num_received_ = 0; - int num_emitted_ = 0; int emit_stop_ = -1; bool stopped_ = false; - PushGenerator>::Producer producer_; + PushGenerator>::Producer producer_; }; -AsyncGenerator> MakeSinkNode(ExecNode* input, - std::string label) { - AsyncGenerator> out; +AsyncGenerator> MakeSinkNode(ExecNode* input, std::string label) { + AsyncGenerator> out; (void)input->plan()->EmplaceNode(input, std::move(label), &out); return out; } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index ff55c631647..933cc9b50b6 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -233,9 +233,11 @@ ExecNode* MakeSourceNode(ExecPlan*, std::string label, ExecNode::BatchDescr outp AsyncGenerator>); /// \brief Add a sink node which forwards to an AsyncGenerator +/// +/// Emitted batches will not be ordered; instead they will be tagged with the `seq` at +/// which they were received. ARROW_EXPORT -AsyncGenerator> MakeSinkNode(ExecNode* input, - std::string label); +AsyncGenerator> MakeSinkNode(ExecNode* input, std::string label); /// \brief Make a node which excludes some rows from batches passed through it /// diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 8668e281200..205c0d8ff46 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -62,6 +62,14 @@ void AssertBatchesEqual(const RecordBatchVector& expected, } } +void AssertBatchesEqual(const RecordBatchVector& expected, + const std::vector& actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + AssertBatchesEqual(ExecBatch(*expected[i]), actual[i]); + } +} + TEST(ExecPlanConstruction, Empty) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); @@ -268,6 +276,22 @@ class TestExecPlanExecution : public ::testing::Test { return CollectAsyncGenerator(gen).result(); } + Result> StartAndCollect( + ExecPlan* plan, AsyncGenerator> gen) { + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + + auto maybe_collected = CollectAsyncGenerator(gen).result(); + ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected); + + std::sort(collected.begin(), collected.end(), + [](const Enumerated& l, const Enumerated& r) { + return l.index < r.index; + }); + return internal::MapVector( + [](Enumerated batch) { return std::move(batch.value); }, collected); + } + ExecNode* MakeSource(ExecPlan* plan, std::shared_ptr reader, std::shared_ptr schema) { return MakeRecordBatchReaderNode(plan, "source", reader, io_executor_.get()); diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 2400d01092f..e1ae813bee4 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -34,6 +34,7 @@ #include "arrow/testing/gtest_util.h" #include 
"arrow/testing/util.h" #include "arrow/util/range.h" +#include "arrow/util/vector.h" using testing::ElementsAre; using testing::IsEmpty; @@ -1087,6 +1088,22 @@ TEST(ScanOptions, TestMaterializedFields) { EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i64", "i32")); } +static Result> StartAndCollect( + compute::ExecPlan* plan, AsyncGenerator> gen) { + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + + auto maybe_collected = CollectAsyncGenerator(gen).result(); + ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected); + + std::sort(collected.begin(), collected.end(), + [](const Enumerated& l, + const Enumerated& r) { return l.index < r.index; }); + return internal::MapVector( + [](Enumerated batch) { return std::move(batch.value); }, + collected); +} + TEST(ScanNode, Trivial) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); @@ -1105,16 +1122,12 @@ TEST(ScanNode, Trivial) { ASSERT_OK(scanner_builder.UseAsync(true)); ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); auto sink_gen = MakeSinkNode(scan, "sink"); - ASSERT_OK(plan->Validate()); - ASSERT_OK(plan->StartProducing()); - - auto got_batches_fut = CollectAsyncGenerator(sink_gen); - ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); ASSERT_EQ(got_batches.size(), batches.size()); for (size_t i = 0; i < batches.size(); ++i) { SCOPED_TRACE("Batch " + std::to_string(i)); - const compute::ExecBatch& actual = *got_batches[i]; + const compute::ExecBatch& actual = got_batches[i]; const RecordBatch& expected = *batches[i]; AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); @@ -1155,15 +1168,11 @@ TEST(ScanNode, FilteredOnVirtualColumn) { ASSERT_OK(scanner_builder.Filter(greater(field_ref("c"), literal(30)))); ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); auto sink_gen = MakeSinkNode(scan, "sink"); - ASSERT_OK(plan->Validate()); - ASSERT_OK(plan->StartProducing()); - - auto got_batches_fut = CollectAsyncGenerator(sink_gen); - ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); ASSERT_EQ(got_batches.size(), 2); for (size_t i = 0; i < batches.size(); ++i) { - const compute::ExecBatch& actual = *got_batches[i]; + const compute::ExecBatch& actual = got_batches[i]; const RecordBatch& expected = *batches[i]; AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); @@ -1211,16 +1220,12 @@ TEST(ScanNode, FilteredOnPhysicalColumn) { ASSERT_OK(scanner_builder.Filter(greater(field_ref("a"), literal(4)))); ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); auto sink_gen = MakeSinkNode(scan, "sink"); - ASSERT_OK(plan->Validate()); - ASSERT_OK(plan->StartProducing()); - - auto got_batches_fut = CollectAsyncGenerator(sink_gen); - ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); // no filtering is performed by ScanNode: all batches will be yielded whole ASSERT_EQ(got_batches.size(), batches.size() * 2); for (size_t i = 0; i < got_batches.size(); ++i) { - const compute::ExecBatch& actual = *got_batches[i]; + const compute::ExecBatch& actual = 
got_batches[i]; const RecordBatch& expected = *batches[i % 2]; AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); @@ -1266,16 +1271,12 @@ TEST(ScanNode, ProjectPhysicalColumn) { auto project = compute::MakeProjectNode( scan, "project", {field_ref("c").Bind(*dataset_schema).ValueOrDie()}); auto sink_gen = MakeSinkNode(project, "sink"); - ASSERT_OK(plan->Validate()); - ASSERT_OK(plan->StartProducing()); - - auto got_batches_fut = CollectAsyncGenerator(sink_gen); - ASSERT_OK_AND_ASSIGN(auto got_batches, got_batches_fut.result()); + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); // no filtering is performed by ScanNode: all batches will be yielded whole ASSERT_EQ(got_batches.size(), batches.size() * 2); for (size_t i = 0; i < got_batches.size(); ++i) { - const compute::ExecBatch& actual = *got_batches[i]; + const compute::ExecBatch& actual = got_batches[i]; Datum expected(i / 2 ? 47 : 23); AssertDatumsEqual(expected, actual[0], /*verbose=*/true); EXPECT_EQ(actual.guarantee, equal(field_ref("c"), literal(expected))) << i; diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index 29635ec5bd2..ca257aaddad 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -1114,7 +1114,7 @@ struct Enumerated { template struct IterationTraits> { - static Enumerated End() { return Enumerated{IterationEnd(), -1, false}; } + static Enumerated End() { return Enumerated{T{}, -1, false}; } static bool IsEnd(const Enumerated& val) { return val.index < 0; } }; @@ -1383,7 +1383,7 @@ class BackgroundGenerator { break; } - if (IsIterationEnd(next)) { + if (!next.ok() || IsIterationEnd(*next)) { // Terminal item. 
Mark finished to true, send this last item, and quit
          state->finished = true;
          if (!next.ok()) {
diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h
index c84cdb21e03..b82021e4b21 100644
--- a/cpp/src/arrow/util/iterator.h
+++ b/cpp/src/arrow/util/iterator.h
@@ -66,12 +66,6 @@ bool IsIterationEnd(const T& val) {
   return IterationTraits<T>::IsEnd(val);
 }
 
-template <typename T>
-bool IsIterationEnd(const Result<T>& maybe_val) {
-  if (!maybe_val.ok()) return true;
-  return IsIterationEnd(*maybe_val);
-}
-
 template <typename T>
 struct IterationTraits<util::optional<T>> {
   /// \brief by default when iterating through a sequence of optional,

From 1f8e93b8363379978bef90c64a1640add7a193e9 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 2 Jun 2021 16:44:38 -0400
Subject: [PATCH 09/28] minor review comments

---
 cpp/src/arrow/compute/exec/exec_plan.h   |  9 +++++++--
 cpp/src/arrow/compute/exec/expression.cc | 16 ----------------
 cpp/src/arrow/compute/exec/expression.h  | 10 +++-------
 cpp/src/arrow/util/async_generator.h     | 12 +-----------
 cpp/src/arrow/util/future.h              | 24 ++++++++++++++++++++++
 5 files changed, 35 insertions(+), 36 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 933cc9b50b6..d31228fdcf8 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -212,7 +212,7 @@ class ARROW_EXPORT ExecNode {
   virtual void StopProducing() = 0;
 
  protected:
-  ExecNode(ExecPlan*, std::string label, NodeVector inputs,
+  ExecNode(ExecPlan* plan, std::string label, NodeVector inputs,
           std::vector<std::string> input_labels, BatchDescr output_descr,
           int num_outputs);
 
@@ -241,12 +241,17 @@ AsyncGenerator<Enumerated<ExecBatch>> MakeSinkNode(ExecNode* input, std::string
 
 /// \brief Make a node which excludes some rows from batches passed through it
 ///
-/// filter Expression must be bound; no field references will be looked up by name
+/// The filter Expression will be evaluated against each batch which is pushed to
+/// this node. Any rows for which the filter does not evaluate to `true` will be excluded
+/// from the batch emitted by this node. The filter Expression must be bound; no field
+/// references will be looked up by name
 ARROW_EXPORT
 ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter);
 
 /// \brief Make a node which executes expressions on input batches, producing new batches.
 ///
+/// Each expression will be evaluated against each batch which is pushed to
+/// this node to produce a corresponding output column.
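+/// For example (as in the SourceProjectSink test above), a bound
+/// `call("add", {field_ref("a"), literal(1)})` yields an output column
+/// named "a + 1".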
/// Expressions must be bound; no field references will be looked up by name ARROW_EXPORT ExecNode* MakeProjectNode(ExecNode* input, std::string label, diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 0e99c614d3f..7bfe1bddc58 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -588,22 +588,6 @@ std::vector FieldsInExpression(const Expression& expr) { return fields; } -std::vector ParametersInExpression(const Expression& expr) { - if (expr.literal()) return {}; - - if (auto parameter = expr.parameter()) { - return {parameter->index}; - } - - std::vector indices; - for (const Expression& arg : CallNotNull(expr)->arguments) { - auto argument_indices = ParametersInExpression(arg); - std::move(argument_indices.begin(), argument_indices.end(), - std::back_inserter(indices)); - } - return indices; -} - bool ExpressionHasFieldRefs(const Expression& expr) { if (expr.literal()) return false; diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index 4b842ef38aa..1d576a23112 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -161,10 +161,6 @@ Expression call(std::string function, std::vector arguments, ARROW_EXPORT std::vector FieldsInExpression(const Expression&); -/// Assemble parameter indices referenced by an Expression at any depth. -ARROW_EXPORT -std::vector ParametersInExpression(const Expression&); - /// Check if the expression references any fields. ARROW_EXPORT bool ExpressionHasFieldRefs(const Expression&); @@ -215,9 +211,9 @@ Result SimplifyWithGuarantee(Expression, // Execution -/// Ensure that a RecordBatch (which may have missing or incorrectly ordered columns) -/// precisely matches the schema. This is necessary when executing Expressions -/// since we look up fields by index. Missing fields will be replaced with null scalars. +/// Create an ExecBatch suitable for passing to ExecuteScalarExpression() from a +/// RecordBatch which may have missing or incorrectly ordered columns. +/// Missing fields will be replaced with null scalars. 
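+/// For example, if `full_schema` is (a: int32, b: boolean) and `partial` is a
+/// RecordBatch holding only column "b", the resulting ExecBatch will contain a
+/// null int32 scalar for "a" followed by "b"'s array.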
ARROW_EXPORT Result MakeExecBatch(const Schema& full_schema, const Datum& partial); diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index ca257aaddad..217d9e5073a 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -266,17 +266,7 @@ AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, MapFn struct MapCallback { MapFn map_; - Future operator()(const T& val) { return EnsureFuture(map_(val)); } - - Future EnsureFuture(V mapped) { - return Future::MakeFinished(std::move(mapped)); - } - - Future EnsureFuture(Result mapped) { - return Future::MakeFinished(std::move(mapped)); - } - - Future EnsureFuture(Future mapped) { return mapped; } + Future operator()(const T& val) { return ToFuture(map_(val)); } }; return MappingGenerator(std::move(source_generator), MapCallback{std::move(map)}); diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index 741550b3e64..c2e754911eb 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -994,4 +994,28 @@ struct EnsureFuture> { using type = Future; }; +template <> +struct EnsureFuture { + using type = Future<>; +}; + +inline Future<> ToFuture(Status status) { + return Future<>::MakeFinished(std::move(status)); +} + +template +Future ToFuture(T value) { + return Future::MakeFinished(std::move(value)); +} + +template +Future ToFuture(Result maybe_value) { + return Future::MakeFinished(std::move(maybe_value)); +} + +template +Future ToFuture(Future fut) { + return std::move(fut); +} + } // namespace arrow From 2e2612a6b1788e11307ccd3ed10eb15064679788 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 7 Jun 2021 16:48:19 -0400 Subject: [PATCH 10/28] use compute/type_fwd.h --- cpp/src/arrow/compute/type_fwd.h | 2 ++ cpp/src/arrow/dataset/scanner.h | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 8a0d6de7f25..eebc8c1b678 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -41,6 +41,8 @@ struct VectorKernel; struct KernelState; class Expression; +class ExecNode; +class ExecPlan; } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index d899cd82df0..99075a649fb 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -25,8 +25,8 @@ #include #include -#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" +#include "arrow/compute/type_fwd.h" #include "arrow/dataset/dataset.h" #include "arrow/dataset/projector.h" #include "arrow/dataset/type_fwd.h" From d7b4534d64d1dd3f2d2feb54c010c62122ddfeb7 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 8 Jun 2021 19:30:31 -0400 Subject: [PATCH 11/28] Add (very) basic ExecNode doc --- cpp/src/arrow/compute/exec/doc/exec_node.md | 42 +++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 cpp/src/arrow/compute/exec/doc/exec_node.md diff --git a/cpp/src/arrow/compute/exec/doc/exec_node.md b/cpp/src/arrow/compute/exec/doc/exec_node.md new file mode 100644 index 00000000000..d0b010c2558 --- /dev/null +++ b/cpp/src/arrow/compute/exec/doc/exec_node.md @@ -0,0 +1,42 @@ + + +# ExecNodes + +`ExecNode`s are intended to implement individual logical operators +in a streaming execution graph. 
Each node receives batches from
+upstream nodes (inputs), processes them in some way, then pushes
+results to downstream nodes (outputs). `ExecNode`s are owned and
+(to an extent) coordinated by an `ExecPlan`.
+
+For example: for a simple dataset scan with only a filter and a
+projection, we'll have a pretty trivial graph with a scan node
+which loads batches from disk and pushes to a filter node. The
+filter node excludes some rows based on an `Expression` then
+pushes filtered batches to a project node. The project node
+materializes new columns based on `Expressions` then pushes those
+batches to a table collection node. The table collection node
+assembles these batches into a `Table` which is handed off as the
+result of the `ExecPlan`.
+
+Note that the execution graph is orthogonal to parallelism; any
+node may push to any other node from any thread. In most cases,
+a batch will arrive on a thread from a scan node and will
+pass through each node in the graph on that same thread.
+

From 1019c3704eae26d9936c9383cf5aa748379398a2 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 9 Jun 2021 13:17:35 -0400
Subject: [PATCH 12/28] Append to ExecNode doc

---
 cpp/src/arrow/compute/exec/doc/exec_node.md | 125 ++++++++++++++++++--
 1 file changed, 115 insertions(+), 10 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/doc/exec_node.md b/cpp/src/arrow/compute/exec/doc/exec_node.md
index d0b010c2558..797cc87d90a 100644
--- a/cpp/src/arrow/compute/exec/doc/exec_node.md
+++ b/cpp/src/arrow/compute/exec/doc/exec_node.md
@@ -17,7 +17,7 @@
 under the License.
 -->
 
-# ExecNodes
+# ExecNodes and logical operators
 
 `ExecNode`s are intended to implement individual logical operators
 in a streaming execution graph. Each node receives batches from
@@ -25,18 +25,123 @@ upstream nodes (inputs), processes them in some way, then pushes
 results to downstream nodes (outputs). `ExecNode`s are owned and
 (to an extent) coordinated by an `ExecPlan`.
 
-For example: for a simple dataset scan with only a filter and a
-projection, we'll have a pretty trivial graph with a scan node
-which loads batches from disk and pushes to a filter node. The
-filter node excludes some rows based on an `Expression` then
+> Terminology: "operator" and "node" are mostly interchangeable, like
+> "Interface" and "Abstract Base Class" in C++ space. The latter is
+> a formal and specific bit of code which implements the abstract
+> concept.
+
+## Types of logical operators
+
+Each of these will have at least one corresponding concrete
+`ExecNode`. Where possible, compatible implementations of a
+logical operator will *not* be exposed as independent subclasses
+of `ExecNode`. Instead we prefer that they
+be encapsulated internally by a single subclass of `ExecNode`
+to permit switching between them during a query.
+
+- Scan: materializes in-memory batches from storage (e.g. Parquet
+  files, flight stream, ...)
+- Filter: evaluates an `Expression` on each input batch and outputs
+  a copy with any rows excluded for which the filter did not return
+  `true`.
+- Project: evaluates `Expression`s on each input batch to produce
+  the columns of an output batch.
+- Grouped Aggregate: identify groups based on one or more key columns
+  in each input batch, then update aggregates corresponding to those
+  groups. Note that this is a pipeline breaker; it will wait for its
+  inputs to complete before outputting any batches.
+- Union: merge two or more streams of batches into a single stream
+  of batches.
+- Write: write each batch to storage
+- ToTable: collect batches into a `Table` with stable row ordering where
+  possible.
+
+#### Not in scope for Arrow 5.0:
+
+- Join: perform an inner, left, outer, semi, or anti join given some
+  join predicates.
+- Sort: accumulate all input batches into a single table, reorder its
+  rows by some sorting condition, then stream the sorted table out as
+  batches.
+- Top-K: retrieve a limited subset of rows from a table as though it
+  were in sorted order.
+
+For example: a dataset scan with only a filter and a
+projection will correspond to a fairly trivial graph:
+
+```
+ScanNode -> FilterNode -> ProjectNode -> ToTableNode
+```
+
+A scan node loads batches from disk and pushes to a filter node.
+The filter node excludes some rows based on an `Expression` then
 pushes filtered batches to a project node. The project node
-materializes new columns based on `Expressions` then pushes those
+materializes new columns based on `Expression`s then pushes those
 batches to a table collection node. The table collection node
 assembles these batches into a `Table` which is handed off as the
 result of the `ExecPlan`.
 
-Note that the execution graph is orthogonal to parallelism; any
-node may push to any other node from any thread. In most cases,
-a batch will arrive on a thread from a scan node and will
-pass through each node in the graph on that same thread.
+## Parallelism, pipelines
+
+The execution graph is orthogonal to parallelism; any
+node may push to any other node from any thread. A scan node causes
+each batch to arrive on a thread, after which it will pass through
+each node in the example graph above, never leaving that thread
+(memory/other resource pressure permitting).
+
+The example graph above happens to be simple enough that processing
+of any batch by any node is independent of other nodes and other
+batches; it is a pipeline. Note that there is no explicit `Pipeline`
+class; pipelined execution is an emergent property of some
+subgraphs.
+
+Nodes which do not share this property (pipeline breakers) are
+responsible for deciding when they have received sufficient input,
+when they can start emitting output, etc. For example a `GroupByNode`
+will wait for its input to be exhausted before it begins pushing
+batches to its own outputs.
+
+Parallelism is "seeded" by `ScanNode` (or other source nodes): it
+owns a reference to the thread pool on which the graph is executing
+and fans out, pushing to its outputs across that pool. A subsequent
+`ProjectNode` will process the batch immediately after it is handed
+off by the `ScanNode`; no explicit scheduling is required.
+Eventually, individual nodes may internally
+parallelize processing of individual batches (for example, if a
+`FilterNode`'s filter expression is slow). This decision is also left
+up to each `ExecNode` implementation.
+
+# ExecNode interface and usage
+
+`ExecNode`s are constructed using one of the available factory
+functions, such as `arrow::compute::MakeFilterNode`
+or `arrow::dataset::MakeScanNode`. Any inputs to an `ExecNode`
+must be provided when the node is constructed, so the first
+nodes to be constructed are source nodes with no inputs
+such as `ScanNode` (assembly is sketched below).
+
+The batches yielded by an `ExecNode` always conform precisely
+to its output schema. NB: no by-name field lookups or type
+checks are performed during execution. The output schema
+is usually derived from the output schemas of inputs.
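+
+A rough sketch of assembling the example graph with these factory
+functions (in a function returning `Status`; a sink generator stands
+in for the table-collection node, and `descr`, `gen`, `filter_expr`,
+and the bound `exprs` are assumed to be prepared already):
+
+```
+ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make());
+
+// inputs must exist before their consumers, so the source comes first
+compute::ExecNode* source =
+    compute::MakeSourceNode(plan.get(), "source", descr, gen);
+compute::ExecNode* filter =
+    compute::MakeFilterNode(source, "filter", filter_expr);
+compute::ExecNode* project =
+    compute::MakeProjectNode(filter, "project", exprs);
+
+// the sink exposes the plan's output as an AsyncGenerator; batches
+// arrive unordered, tagged with the sequence index of their receipt
+auto sink_gen = compute::MakeSinkNode(project, "sink");
+
+RETURN_NOT_OK(plan->Validate());
+RETURN_NOT_OK(plan->StartProducing());
+```
+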
+For example, a `FilterNode`'s output schema is always identical to
+that of its input, since batches are only modified by exclusion
+of some rows.
+
+An `ExecNode` will begin producing batches when
+`node->StartProducing()` is invoked and will proceed until stopped
+with `node->StopProducing()`. Started nodes may not be destroyed
+until stopped. `ExecNode`s are not currently restartable.
+An `ExecNode` pushes batches to its outputs by passing each batch
+to `output->InputReceived()`. It signals exhaustion by invoking
+`output->InputFinished()`.
+
+Error recovery is permitted within a node. For example, if evaluation
+of an `Expression` runs out of memory, the governing node may
+try that evaluation again after some memory has been freed up.
+If a node experiences an error from which it cannot recover (for
+example, an IO error while parsing a CSV file), then it reports this
+with `output->ErrorReceived()`. An error which escapes the scope of
+a single node should not be considered recoverable (no `FilterNode`
+should `try/catch` the IO error above).

From b63e585c519f879339172823371302631e218c79 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Fri, 11 Jun 2021 21:03:48 -0400
Subject: [PATCH 13/28] add Result and Status matchers

---
 cpp/src/arrow/compute/exec.cc              |  33 ++
 cpp/src/arrow/compute/exec.h               |   7 +
 cpp/src/arrow/compute/exec/exec_plan.cc    |  51 +--
 cpp/src/arrow/compute/exec/exec_plan.h     |  15 +-
 cpp/src/arrow/compute/exec/plan_test.cc    | 342 +++++++++------------
 cpp/src/arrow/dataset/scanner.cc           |   1 +
 cpp/src/arrow/dataset/scanner_test.cc      | 124 ++++----
 cpp/src/arrow/pretty_print.cc              |  88 ++++--
 cpp/src/arrow/pretty_print.h               |   6 +-
 cpp/src/arrow/result.h                     |   2 +-
 cpp/src/arrow/result_test.cc               |  70 +++++
 cpp/src/arrow/status.h                     |   7 +-
 cpp/src/arrow/status_test.cc               |  81 +++++
 cpp/src/arrow/testing/gtest_util.h         | 121 ++++++++
 cpp/src/arrow/util/async_generator_test.cc |  10 +-
 cpp/src/arrow/util/vector.h                |  42 ++-
 16 files changed, 674 insertions(+), 326 deletions(-)

diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index add6188ab48..78f3d753711 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -36,6 +36,7 @@
 #include "arrow/compute/registry.h"
 #include "arrow/compute/util_internal.h"
 #include "arrow/datum.h"
+#include "arrow/pretty_print.h"
 #include "arrow/record_batch.h"
 #include "arrow/scalar.h"
 #include "arrow/status.h"
@@ -69,6 +70,38 @@ ExecBatch::ExecBatch(const RecordBatch& batch)
   std::move(columns.begin(), columns.end(), values.begin());
 }
 
+bool ExecBatch::Equals(const ExecBatch& other) const {
+  return guarantee == other.guarantee && values == other.values;
+}
+
+void PrintTo(const ExecBatch& batch, std::ostream* os) {
+  *os << "ExecBatch\n";
+
+  static const std::string indent = "  ";
+
+  *os << indent << "# Rows: " << batch.length << "\n";
+  if (batch.guarantee != literal(true)) {
+    *os << indent << "Guarantee: " << batch.guarantee.ToString() << "\n";
+  }
+
+  int i = 0;
+  for (const Datum& value : batch.values) {
+    *os << indent << "" << i++ << ": ";
+
+    if (value.is_scalar()) {
+      *os << "Scalar[" << value.scalar()->ToString() << "]\n";
+      continue;
+    }
+
+    auto array = value.make_array();
+    PrettyPrintOptions options;
+    options.skip_new_lines = true;
+    *os << "Array";
+    ARROW_CHECK_OK(PrettyPrint(*array, options, os));
+    *os << "\n";
+  }
+}
+
 ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
   ExecBatch out = *this;
   for (auto& value : out.values) {
diff --git a/cpp/src/arrow/compute/exec.h
b/cpp/src/arrow/compute/exec.h index 49484383e0f..e7015814d2a 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -207,6 +207,8 @@ struct ARROW_EXPORT ExecBatch { return values[i]; } + bool Equals(const ExecBatch& other) const; + /// \brief A convenience for the number of values / arguments. int num_values() const { return static_cast(values.size()); } @@ -221,8 +223,13 @@ struct ARROW_EXPORT ExecBatch { } return result; } + + ARROW_EXPORT friend void PrintTo(const ExecBatch&, std::ostream*); }; +inline bool operator==(const ExecBatch& l, const ExecBatch& r) { return l.Equals(r); } +inline bool operator!=(const ExecBatch& l, const ExecBatch& r) { return !l.Equals(r); } + /// \defgroup compute-call-function One-shot calls to compute functions /// /// @{ diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 1eef9183155..a30ed2ca352 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -25,6 +25,7 @@ #include "arrow/compute/exec/expression.h" #include "arrow/datum.h" #include "arrow/result.h" +#include "arrow/util/async_generator.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" @@ -220,11 +221,8 @@ struct SourceNode : ExecNode { return Status::Invalid("Restarted SourceNode '", label(), "'"); } - auto gen = std::move(generator_); - - /// XXX should we wait on this future anywhere? In StopProducing() maybe? - auto done_fut = - Loop([gen, this] { + finished_fut_ = + Loop([this] { std::unique_lock lock(mutex_); int seq = next_batch_index_++; if (finished_) { @@ -232,7 +230,7 @@ struct SourceNode : ExecNode { } lock.unlock(); - return gen().Then( + return generator_().Then( [=](const util::optional& batch) -> ControlFlow { std::unique_lock lock(mutex_); if (!batch || finished_) { @@ -250,8 +248,8 @@ struct SourceNode : ExecNode { finished_ = true; lock.unlock(); // unless we were already finished, push the error to our output - // XXX is this correct? Is it reasonable for a consumer to ignore errors - // from a finished producer? + // XXX is this correct? Is it reasonable for a consumer to + // ignore errors from a finished producer? 
outputs_[0]->ErrorReceived(this, error); } return Break(seq); @@ -271,8 +269,12 @@ struct SourceNode : ExecNode { void StopProducing(ExecNode* output) override { DCHECK_EQ(output, outputs_[0]); - std::unique_lock lock(mutex_); - finished_ = true; + { + std::unique_lock lock(mutex_); + finished_ = true; + } + DCHECK(finished_fut_.is_valid()); + finished_fut_.Wait(); } void StopProducing() override { StopProducing(outputs_[0]); } @@ -281,6 +283,7 @@ struct SourceNode : ExecNode { std::mutex mutex_; bool finished_{false}; int next_batch_index_{0}; + Future<> finished_fut_; AsyncGenerator> generator_; }; @@ -448,14 +451,14 @@ ExecNode* MakeProjectNode(ExecNode* input, std::string label, struct SinkNode : ExecNode { SinkNode(ExecNode* input, std::string label, - AsyncGenerator>* generator) + AsyncGenerator>* generator) : ExecNode(input->plan(), std::move(label), {input}, {"collected"}, {}, /*num_outputs=*/0), producer_(MakeProducer(generator)) {} - static PushGenerator>::Producer MakeProducer( - AsyncGenerator>* out_gen) { - PushGenerator> gen; + static PushGenerator>::Producer MakeProducer( + AsyncGenerator>* out_gen) { + PushGenerator> gen; auto out = gen.producer(); *out_gen = std::move(gen); return out; @@ -473,7 +476,7 @@ struct SinkNode : ExecNode { void StopProducing() override { std::unique_lock lock(mutex_); - StopProducingUnlocked(); + InputFinishedUnlocked(); } void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override { @@ -484,7 +487,7 @@ struct SinkNode : ExecNode { ++num_received_; if (num_received_ == emit_stop_) { - StopProducingUnlocked(); + InputFinishedUnlocked(); } if (emit_stop_ != -1) { @@ -492,30 +495,29 @@ struct SinkNode : ExecNode { } lock.unlock(); - producer_.Push(Enumerated{std::move(batch), seq_num, false}); + producer_.Push(std::move(batch)); } void ErrorReceived(ExecNode* input, Status error) override { DCHECK_EQ(input, inputs_[0]); producer_.Push(std::move(error)); std::unique_lock lock(mutex_); - StopProducingUnlocked(); + InputFinishedUnlocked(); } void InputFinished(ExecNode* input, int seq_stop) override { std::unique_lock lock(mutex_); emit_stop_ = seq_stop; if (emit_stop_ == num_received_) { - StopProducingUnlocked(); + InputFinishedUnlocked(); } } private: - void StopProducingUnlocked() { + void InputFinishedUnlocked() { if (!stopped_) { stopped_ = true; producer_.Close(); - inputs_[0]->StopProducing(this); } } @@ -525,11 +527,12 @@ struct SinkNode : ExecNode { int emit_stop_ = -1; bool stopped_ = false; - PushGenerator>::Producer producer_; + PushGenerator>::Producer producer_; }; -AsyncGenerator> MakeSinkNode(ExecNode* input, std::string label) { - AsyncGenerator> out; +AsyncGenerator> MakeSinkNode(ExecNode* input, + std::string label) { + AsyncGenerator> out; (void)input->plan()->EmplaceNode(input, std::move(label), &out); return out; } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index d31228fdcf8..b3e3823c76d 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -17,14 +17,15 @@ #pragma once +#include #include #include #include #include "arrow/compute/type_fwd.h" #include "arrow/type_fwd.h" -#include "arrow/util/async_generator.h" #include "arrow/util/macros.h" +#include "arrow/util/optional.h" #include "arrow/util/visibility.h" // NOTES: @@ -93,8 +94,6 @@ class ARROW_EXPORT ExecNode { const NodeVector& inputs() const { return inputs_; } /// \brief Labels identifying the function of each input. 
- /// - /// For example, FilterNode accepts "target" and "filter" inputs. const std::vector& input_labels() const { return input_labels_; } /// This node's successors in the exec plan @@ -209,6 +208,8 @@ class ARROW_EXPORT ExecNode { virtual void StopProducing(ExecNode* output) = 0; /// \brief Stop producing definitively + /// + /// XXX maybe this should return a Future<>? virtual void StopProducing() = 0; protected: @@ -228,16 +229,20 @@ class ARROW_EXPORT ExecNode { }; /// \brief Adapt an AsyncGenerator as a source node +/// +/// TODO this should accept an Executor and explicitly handle batches +/// as they are generated on each of the Executor's threads. ARROW_EXPORT ExecNode* MakeSourceNode(ExecPlan*, std::string label, ExecNode::BatchDescr output_descr, - AsyncGenerator>); + std::function>()>); /// \brief Add a sink node which forwards to an AsyncGenerator /// /// Emitted batches will not be ordered; instead they will be tagged with the `seq` at /// which they were received. ARROW_EXPORT -AsyncGenerator> MakeSinkNode(ExecNode* input, std::string label); +std::function>()> MakeSinkNode(ExecNode* input, + std::string label); /// \brief Make a node which excludes some rows from batches passed through it /// diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 205c0d8ff46..8cb1c4c1f7d 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -34,40 +34,30 @@ namespace arrow { -using internal::Executor; +using testing::UnorderedElementsAreArray; namespace compute { -void AssertBatchesEqual(const RecordBatchVector& expected, - const RecordBatchVector& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertBatchesEqual(*expected[i], *actual[i]); - } -} - -void AssertBatchesEqual(const std::vector>& expected, - const std::vector>& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertBatchesEqual(*expected[i], *actual[i]); - } -} - -void AssertBatchesEqual(const RecordBatchVector& expected, - const std::vector>& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertBatchesEqual(ExecBatch(*expected[i]), *actual[i]); +ExecBatch ExecBatchFromJSON(const std::vector& descrs, + util::string_view json) { + auto fields = internal::MapVector( + [](const ValueDescr& descr) { return field("", descr.type); }, descrs); + + ExecBatch batch{*RecordBatchFromJSON(schema(std::move(fields)), json)}; + + auto value_it = batch.values.begin(); + for (const auto& descr : descrs) { + if (descr.shape == ValueDescr::SCALAR) { + if (batch.length == 0) { + *value_it = MakeNullScalar(value_it->type()); + } else { + *value_it = value_it->make_array()->GetScalar(0).ValueOrDie(); + } + } + ++value_it; } -} -void AssertBatchesEqual(const RecordBatchVector& expected, - const std::vector& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertBatchesEqual(ExecBatch(*expected[i]), actual[i]); - } + return batch; } TEST(ExecPlanConstruction, Empty) { @@ -214,175 +204,151 @@ TEST(ExecPlan, DummyStartProducingError) { ASSERT_THAT(t.stopped, ::testing::ElementsAre("process2", "process3", "sink")); } -// TODO move this to gtest_util.h? 
- -class SlowRecordBatchReader : public RecordBatchReader { - public: - explicit SlowRecordBatchReader(std::shared_ptr reader) - : reader_(std::move(reader)) {} - - std::shared_ptr schema() const override { return reader_->schema(); } - - Status ReadNext(std::shared_ptr* batch) override { - SleepABit(); - return reader_->ReadNext(batch); +static Result MakeTestSourceNode(ExecPlan* plan, std::string label, + std::vector batches, bool parallel, + bool slow) { + DCHECK_GT(batches.size(), 0); + auto out_descr = batches.back().GetDescriptors(); + + auto opt_batches = internal::MapVector( + [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, + std::move(batches)); + + AsyncGenerator> gen; + + if (parallel) { + // emulate batches completing initial decode-after-scan on a cpu thread + ARROW_ASSIGN_OR_RAISE( + gen, MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)), + internal::GetCpuThreadPool())); + } else { + gen = MakeVectorGenerator(std::move(opt_batches)); } - static Result> Make( - RecordBatchVector batches, std::shared_ptr schema = nullptr) { - ARROW_ASSIGN_OR_RAISE(auto reader, - RecordBatchReader::Make(std::move(batches), std::move(schema))); - return std::make_shared(std::move(reader)); + if (slow) { + gen = MakeMappedGenerator(std::move(gen), [](const util::optional& batch) { + SleepABit(); + return batch; + }); } - protected: - std::shared_ptr reader_; -}; - -static Result MakeSlowRecordBatchGenerator( - RecordBatchVector batches, std::shared_ptr schema) { - // TODO move this into testing/async_generator_util.h? - auto delayed_gen = - MakeMappedGenerator(MakeVectorGenerator(std::move(batches)), - [](const std::shared_ptr& batch) { - return SleepABitAsync().Then([=] { return batch; }); - }); - // Adding readahead implicitly adds parallelism by pulling reentrantly from - // the delayed generator - return MakeReadaheadGenerator(std::move(delayed_gen), /*max_readahead=*/64); + return MakeSourceNode(plan, label, out_descr, std::move(gen)); } -class TestExecPlanExecution : public ::testing::Test { - public: - void SetUp() override { - ASSERT_OK_AND_ASSIGN(io_executor_, internal::ThreadPool::Make(8)); - } +static Result> StartAndCollect( + ExecPlan* plan, AsyncGenerator> gen) { + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); - RecordBatchVector MakeRandomBatches(const std::shared_ptr& schema, - int num_batches = 10, int batch_size = 4) { - random::RandomArrayGenerator rng(42); - RecordBatchVector batches; - batches.reserve(num_batches); - for (int i = 0; i < num_batches; ++i) { - batches.push_back(rng.BatchOf(schema->fields(), batch_size)); - } - return batches; - } + auto maybe_collected = CollectAsyncGenerator(gen).result(); + ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected); - Result>> StartAndCollect( - ExecPlan* plan, AsyncGenerator> gen) { - RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); - return CollectAsyncGenerator(gen).result(); - } - - Result> StartAndCollect( - ExecPlan* plan, AsyncGenerator> gen) { - RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); + // RETURN_NOT_OK(plan->StopProducing()); - auto maybe_collected = CollectAsyncGenerator(gen).result(); - ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected); + return internal::MapVector( + [](util::optional batch) { return std::move(*batch); }, collected); +} - std::sort(collected.begin(), collected.end(), - [](const Enumerated& l, const Enumerated& r) { - return l.index < r.index; - }); - return 
internal::MapVector( - [](Enumerated batch) { return std::move(batch.value); }, collected); - } +static std::vector MakeBasicBatches() { + return {ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"), + ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")}; +} - ExecNode* MakeSource(ExecPlan* plan, std::shared_ptr reader, - std::shared_ptr schema) { - return MakeRecordBatchReaderNode(plan, "source", reader, io_executor_.get()); - } +static std::vector MakeRandomBatches(const std::shared_ptr& schema, + int num_batches = 10, + int batch_size = 4) { + random::RandomArrayGenerator rng(42); + std::vector batches(num_batches); - ExecNode* MakeSource(ExecPlan* plan, RecordBatchGenerator generator, - std::shared_ptr schema) { - return MakeRecordBatchReaderNode(plan, "source", schema, generator, - io_executor_.get()); + for (int i = 0; i < num_batches; ++i) { + batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size)); + // add a tag scalar to ensure the batches are unique + batches[i].values.emplace_back(i); } + return batches; +} - template - void TestSourceSink(RecordBatchReaderFactory factory) { - auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - RecordBatchVector batches{ - RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, - {"a": 4, "b": false}])"), - RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, - {"a": 6, "b": false}, - {"a": 7, "b": false}])"), - }; +TEST(ExecPlanExecution, SourceSink) { + for (bool slow : {false, true}) { + SCOPED_TRACE(slow ? "slowed" : "unslowed"); - ASSERT_OK_AND_ASSIGN(auto reader_or_gen, factory(batches, schema)); + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel" : "single threaded"); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source = MakeSource(plan.get(), reader_or_gen, schema); - auto sink_gen = MakeSinkNode(source, "sink"); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - AssertBatchesEqual(batches, got_batches); - } + auto batches = MakeBasicBatches(); - template - void TestStressSourceSink(int num_batches, RecordBatchReaderFactory factory) { - auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - auto batches = MakeRandomBatches(schema, num_batches); + ASSERT_OK_AND_ASSIGN( + auto source, MakeTestSourceNode(plan.get(), "source", batches, parallel, slow)); - ASSERT_OK_AND_ASSIGN(auto reader_or_gen, factory(batches, schema)); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source = MakeSource(plan.get(), reader_or_gen, schema); - auto sink_gen = MakeSinkNode(source, "sink"); + auto sink_gen = MakeSinkNode(source, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - AssertBatchesEqual(batches, got_batches); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(batches))); + } } +} - protected: - std::shared_ptr io_executor_; -}; +TEST(ExecPlanExecution, SourceSinkError) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); -// FIXME Test "collecting" an error + auto batches = MakeBasicBatches(); + auto it = batches.begin(); + AsyncGenerator> gen = + [&]() -> Result> { + if (it == batches.end()) { + return Status::Invalid("Artificial error"); + } + return util::make_optional(*it++); + }; -TEST_F(TestExecPlanExecution, SourceSink) { TestSourceSink(RecordBatchReader::Make); } + auto source = MakeSourceNode(plan.get(), "source", {}, 
gen); + auto sink_gen = MakeSinkNode(source, "sink"); -TEST_F(TestExecPlanExecution, SlowSourceSink) { - TestSourceSink(SlowRecordBatchReader::Make); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Raises(StatusCode::Invalid, testing::HasSubstr("Artificial"))); } -TEST_F(TestExecPlanExecution, SlowSourceSinkParallel) { - TestSourceSink(MakeSlowRecordBatchGenerator); -} +TEST(ExecPlanExecution, StressSourceSink) { + for (bool slow : {false, true}) { + SCOPED_TRACE(slow ? "slowed" : "unslowed"); -TEST_F(TestExecPlanExecution, StressSourceSink) { - TestStressSourceSink(/*num_batches=*/200, RecordBatchReader::Make); -} + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel" : "single threaded"); -TEST_F(TestExecPlanExecution, StressSlowSourceSink) { - // This doesn't create parallelism as the RecordBatchReader is iterated serially. - TestStressSourceSink(/*num_batches=*/30, SlowRecordBatchReader::Make); -} + int num_batches = slow && !parallel ? 30 : 300; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + auto batches = MakeRandomBatches( + schema({field("a", int32()), field("b", boolean())}), num_batches); + + ASSERT_OK_AND_ASSIGN( + auto source, MakeTestSourceNode(plan.get(), "source", batches, parallel, slow)); + + auto sink_gen = MakeSinkNode(source, "sink"); -TEST_F(TestExecPlanExecution, StressSlowSourceSinkParallel) { - TestStressSourceSink(/*num_batches=*/300, MakeSlowRecordBatchGenerator); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(batches))); + } + } } -TEST_F(TestExecPlanExecution, SourceFilterSink) { +TEST(ExecPlanExecution, SourceFilterSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - const auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - RecordBatchVector batches{ - RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, - {"a": 4, "b": false}])"), - RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, - {"a": 6, "b": false}, - {"a": 7, "b": false}])"), - }; + auto batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(std::move(batches), schema)); + ASSERT_OK_AND_ASSIGN(auto source, + MakeTestSourceNode(plan.get(), "source", batches, + /*parallel=*/false, /*slow=*/false)); - auto source = - MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); + const auto schema = ::arrow::schema({ + field("a", int32()), + field("b", boolean()), + field("__tag", int32()), + }); ASSERT_OK_AND_ASSIGN(auto predicate, equal(field_ref("a"), literal(6)).Bind(*schema)); @@ -390,33 +356,25 @@ TEST_F(TestExecPlanExecution, SourceFilterSink) { auto sink_gen = MakeSinkNode(filter, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - - ASSERT_EQ(got_batches.size(), 2); - AssertBatchesEqual( - { - RecordBatchFromJSON(schema, R"([])"), - RecordBatchFromJSON(schema, R"([{"a": 6, "b": false}])"), - }, - got_batches); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray( + {ExecBatchFromJSON({int32(), boolean()}, "[]"), + ExecBatchFromJSON({int32(), boolean()}, "[[6, false]]")}))); } -TEST_F(TestExecPlanExecution, SourceProjectSink) { +TEST(ExecPlanExecution, SourceProjectSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - const auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - RecordBatchVector batches{ - RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, - {"a": 4, "b": false}])"), - 
RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, - {"a": 6, "b": false}, - {"a": 7, "b": false}])"), - }; + auto batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(std::move(batches), schema)); + ASSERT_OK_AND_ASSIGN(auto source, + MakeTestSourceNode(plan.get(), "source", batches, + /*parallel=*/false, /*slow=*/false)); - auto source = - MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); + const auto schema = ::arrow::schema({ + field("a", int32()), + field("b", boolean()), + }); std::vector exprs{ not_(field_ref("b")), @@ -430,19 +388,11 @@ TEST_F(TestExecPlanExecution, SourceProjectSink) { auto sink_gen = MakeSinkNode(projection, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - - auto out_schema = ::arrow::schema({field("!b", boolean()), field("a + 1", int32())}); - ASSERT_EQ(got_batches.size(), 2); - AssertBatchesEqual( - { - RecordBatchFromJSON(out_schema, R"([{"!b": false, "a + 1": null}, - {"!b": true, "a + 1": 5}])"), - RecordBatchFromJSON(out_schema, R"([{"!b": null, "a + 1": 6}, - {"!b": true, "a + 1": 7}, - {"!b": true, "a + 1": 8}])"), - }, - got_batches); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray( + {ExecBatchFromJSON({boolean(), int32()}, "[[false, null], [true, 5]]"), + ExecBatchFromJSON({boolean(), int32()}, + "[[null, 6], [true, 7], [true, 8]]")}))); } } // namespace compute diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index b39eefb29df..a2e70c0bfe8 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -27,6 +27,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/scanner_internal.h" diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index e1ae813bee4..8805eace9a5 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -18,6 +18,7 @@ #include "arrow/dataset/scanner.h" #include +#include #include @@ -25,6 +26,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" @@ -38,6 +40,7 @@ using testing::ElementsAre; using testing::IsEmpty; +using testing::UnorderedElementsAreArray; namespace arrow { namespace dataset { @@ -923,7 +926,7 @@ TEST_F(TestReordering, ScanBatchesUnordered) { struct BatchConsumer { explicit BatchConsumer(EnumeratedRecordBatchGenerator generator) - : generator(generator), next() {} + : generator(std::move(generator)), next() {} void AssertCanConsume() { if (!next.is_valid()) { @@ -1089,18 +1092,17 @@ TEST(ScanOptions, TestMaterializedFields) { } static Result> StartAndCollect( - compute::ExecPlan* plan, AsyncGenerator> gen) { + compute::ExecPlan* plan, AsyncGenerator> gen) { RETURN_NOT_OK(plan->Validate()); RETURN_NOT_OK(plan->StartProducing()); auto maybe_collected = CollectAsyncGenerator(gen).result(); ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected); - std::sort(collected.begin(), collected.end(), - [](const Enumerated& l, - const Enumerated& r) { return l.index < r.index; }); + // RETURN_NOT_OK(plan->StopProducing()); + return 
internal::MapVector( - [](Enumerated batch) { return std::move(batch.value); }, + [](util::optional batch) { return std::move(*batch); }, collected); } @@ -1122,19 +1124,19 @@ TEST(ScanNode, Trivial) { ASSERT_OK(scanner_builder.UseAsync(true)); ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); auto sink_gen = MakeSinkNode(scan, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - - ASSERT_EQ(got_batches.size(), batches.size()); - for (size_t i = 0; i < batches.size(); ++i) { - SCOPED_TRACE("Batch " + std::to_string(i)); - const compute::ExecBatch& actual = got_batches[i]; - const RecordBatch& expected = *batches[i]; - AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); - AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); - // InMemoryDataset(RecordBatchVector) produces a fragment wrapping each batch - AssertDatumsEqual(Datum(int(i)), actual[2], /*verbose=*/true); - AssertDatumsEqual(Datum(0), actual[3], /*verbose=*/true); + + std::vector expected; + // InMemoryDataset(RecordBatchVector) produces a fragment wrapping each batch + const int batch_index = 0; + int fragment_index = 0; + for (const auto& batch : batches) { + expected.emplace_back(*batch); + expected.back().values.emplace_back(fragment_index++); + expected.back().values.emplace_back(batch_index); } + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); } TEST(ScanNode, FilteredOnVirtualColumn) { @@ -1168,28 +1170,28 @@ TEST(ScanNode, FilteredOnVirtualColumn) { ASSERT_OK(scanner_builder.Filter(greater(field_ref("c"), literal(30)))); ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); auto sink_gen = MakeSinkNode(scan, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - ASSERT_EQ(got_batches.size(), 2); - for (size_t i = 0; i < batches.size(); ++i) { - const compute::ExecBatch& actual = got_batches[i]; - const RecordBatch& expected = *batches[i]; - AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); - AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); + std::vector expected; + const int fragment_index = 0; // only the second fragment will make it past the filter, + // and its index in the scan is 0 + int batch_index = 0; + for (const auto& batch : batches) { + expected.emplace_back(*batch); - // Note: placeholder for partition field "c" - AssertDatumsEqual(Datum(std::make_shared()), actual[2], - /*verbose=*/true); + expected.back().guarantee = equal(field_ref("c"), literal(47)); - // Only one fragment in this scan, its index is 0 - AssertDatumsEqual(Datum(0), actual[3], /*verbose=*/true); - AssertDatumsEqual(Datum(int(i)), actual[4], /*verbose=*/true); + // Note: placeholder for partition field "c" + expected.back().values.emplace_back(std::make_shared()); - EXPECT_EQ(actual.guarantee, equal(field_ref("c"), literal(47))); + expected.back().values.emplace_back(fragment_index); + expected.back().values.emplace_back(batch_index++); } + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); } -TEST(ScanNode, FilteredOnPhysicalColumn) { +TEST(ScanNode, DeferredFilterOnPhysicalColumn) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); const auto dataset_schema = ::arrow::schema({ @@ -1220,26 +1222,29 @@ TEST(ScanNode, FilteredOnPhysicalColumn) { 
ASSERT_OK(scanner_builder.Filter(greater(field_ref("a"), literal(4)))); ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); auto sink_gen = MakeSinkNode(scan, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); // no filtering is performed by ScanNode: all batches will be yielded whole - ASSERT_EQ(got_batches.size(), batches.size() * 2); - for (size_t i = 0; i < got_batches.size(); ++i) { - const compute::ExecBatch& actual = got_batches[i]; - const RecordBatch& expected = *batches[i % 2]; - AssertDatumsEqual(expected.GetColumnByName("a"), actual[0], /*verbose=*/true); - AssertDatumsEqual(expected.GetColumnByName("b"), actual[1], /*verbose=*/true); - AssertDatumsEqual(Datum(std::make_shared()), actual[2], - /*verbose=*/true); - - AssertDatumsEqual(Datum(int(i / 2)), actual[3], /*verbose=*/true); - AssertDatumsEqual(Datum(int(i % 2)), actual[4], /*verbose=*/true); - - EXPECT_EQ(actual.guarantee, equal(field_ref("c"), literal(i / 2 ? 47 : 23))) << i; + std::vector expected; + for (int fragment_index = 0; fragment_index < 2; ++fragment_index) { + for (int batch_index = 0; batch_index < 2; ++batch_index) { + expected.emplace_back(*batches[batch_index]); + + expected.back().guarantee = + equal(field_ref("c"), literal(fragment_index == 0 ? 23 : 47)); + + // Note: placeholder for partition field "c" + expected.back().values.emplace_back(std::make_shared()); + + expected.back().values.emplace_back(fragment_index); + expected.back().values.emplace_back(batch_index); + } } + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); } -TEST(ScanNode, ProjectPhysicalColumn) { +TEST(ScanNode, MaterializationOfVirtualColumn) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); const auto dataset_schema = ::arrow::schema({ @@ -1271,16 +1276,23 @@ TEST(ScanNode, ProjectPhysicalColumn) { auto project = compute::MakeProjectNode( scan, "project", {field_ref("c").Bind(*dataset_schema).ValueOrDie()}); auto sink_gen = MakeSinkNode(project, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(plan.get(), sink_gen)); - // no filtering is performed by ScanNode: all batches will be yielded whole - ASSERT_EQ(got_batches.size(), batches.size() * 2); - for (size_t i = 0; i < got_batches.size(); ++i) { - const compute::ExecBatch& actual = got_batches[i]; - Datum expected(i / 2 ? 47 : 23); - AssertDatumsEqual(expected, actual[0], /*verbose=*/true); - EXPECT_EQ(actual.guarantee, equal(field_ref("c"), literal(expected))) << i; + std::vector expected; + for (int fragment_index = 0; fragment_index < 2; ++fragment_index) { + for (int batch_index = 0; batch_index < 2; ++batch_index) { + auto c_value = fragment_index == 0 ? 
23 : 47; + + expected.push_back(compute::ExecBatch{{}, batches[batch_index]->num_rows()}); + + // NB: ProjectNode overwrites "c" placeholder with value from guarantee + expected.back().values.emplace_back(c_value); + + expected.back().guarantee = equal(field_ref("c"), literal(c_value)); + } } + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); } } // namespace dataset diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index 8c2ac376d1e..8d1c16e0ed6 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -69,10 +69,12 @@ class PrettyPrinter { }; void PrettyPrinter::OpenArray(const Array& array) { - Indent(); + if (!options_.skip_new_lines) { + Indent(); + } (*sink_) << "["; if (array.length() > 0) { - (*sink_) << "\n"; + Newline(); indent_ += options_.indent_size; } } @@ -103,7 +105,6 @@ void PrettyPrinter::Newline() { return; } (*sink_) << "\n"; - Indent(); } void PrettyPrinter::Indent() { @@ -124,11 +125,15 @@ class ArrayPrinter : public PrettyPrinter { if (skip_comma) { skip_comma = false; } else { - (*sink_) << ",\n"; + (*sink_) << ","; + Newline(); + } + if (!options_.skip_new_lines) { + Indent(); } - Indent(); if ((i >= options_.window) && (i < (array.length() - options_.window))) { - (*sink_) << "...\n"; + (*sink_) << "..."; + Newline(); i = array.length() - options_.window - 1; skip_comma = true; } else if (array.IsNull(i)) { @@ -137,7 +142,7 @@ class ArrayPrinter : public PrettyPrinter { func(i); } } - (*sink_) << "\n"; + Newline(); } Status WriteDataValues(const BooleanArray& array) { @@ -239,11 +244,13 @@ class ArrayPrinter : public PrettyPrinter { if (skip_comma) { skip_comma = false; } else { - (*sink_) << ",\n"; + (*sink_) << ","; + Newline(); } if ((i >= options_.window) && (i < (array.length() - options_.window))) { Indent(); - (*sink_) << "...\n"; + (*sink_) << "..."; + Newline(); i = array.length() - options_.window - 1; skip_comma = true; } else if (array.IsNull(i)) { @@ -252,10 +259,11 @@ class ArrayPrinter : public PrettyPrinter { } else { std::shared_ptr slice = array.values()->Slice(array.value_offset(i), array.value_length(i)); - RETURN_NOT_OK(PrettyPrint(*slice, {indent_, options_.window}, sink_)); + RETURN_NOT_OK( + PrettyPrint(*slice, PrettyPrintOptions{indent_, options_.window}, sink_)); } } - (*sink_) << "\n"; + Newline(); return Status::OK(); } @@ -265,28 +273,36 @@ class ArrayPrinter : public PrettyPrinter { if (skip_comma) { skip_comma = false; } else { - (*sink_) << ",\n"; + (*sink_) << ","; + Newline(); } - if ((i >= options_.window) && (i < (array.length() - options_.window))) { + + if (!options_.skip_new_lines) { Indent(); - (*sink_) << "...\n"; + } + + if ((i >= options_.window) && (i < (array.length() - options_.window))) { + (*sink_) << "..."; + Newline(); i = array.length() - options_.window - 1; skip_comma = true; } else if (array.IsNull(i)) { - Indent(); (*sink_) << options_.null_rep; } else { - Indent(); - (*sink_) << "keys:\n"; + (*sink_) << "keys:"; + Newline(); auto keys_slice = array.keys()->Slice(array.value_offset(i), array.value_length(i)); - RETURN_NOT_OK(PrettyPrint(*keys_slice, {indent_, options_.window}, sink_)); - (*sink_) << "\n"; + RETURN_NOT_OK(PrettyPrint(*keys_slice, + PrettyPrintOptions{indent_, options_.window}, sink_)); + Newline(); Indent(); - (*sink_) << "values:\n"; + (*sink_) << "values:"; + Newline(); auto values_slice = array.items()->Slice(array.value_offset(i), array.value_length(i)); - 
RETURN_NOT_OK(PrettyPrint(*values_slice, {indent_, options_.window}, sink_)); + RETURN_NOT_OK(PrettyPrint(*values_slice, + PrettyPrintOptions{indent_, options_.window}, sink_)); } } (*sink_) << "\n"; @@ -325,6 +341,7 @@ class ArrayPrinter : public PrettyPrinter { int64_t length) { for (size_t i = 0; i < fields.size(); ++i) { Newline(); + Indent(); std::stringstream ss; ss << "-- child " << i << " type: " << fields[i]->type()->ToString() << "\n"; Write(ss.str()); @@ -352,12 +369,14 @@ class ArrayPrinter : public PrettyPrinter { RETURN_NOT_OK(WriteValidityBitmap(array)); Newline(); + Indent(); Write("-- type_ids: "); UInt8Array type_codes(array.length(), array.type_codes(), nullptr, 0, array.offset()); RETURN_NOT_OK(PrettyPrint(type_codes, indent_ + options_.indent_size, sink_)); if (array.mode() == UnionMode::DENSE) { Newline(); + Indent(); Write("-- value_offsets: "); Int32Array value_offsets( array.length(), checked_cast(array).value_offsets(), @@ -376,11 +395,13 @@ class ArrayPrinter : public PrettyPrinter { Status Visit(const DictionaryArray& array) { Newline(); + Indent(); Write("-- dictionary:\n"); RETURN_NOT_OK( PrettyPrint(*array.dictionary(), indent_ + options_.indent_size, sink_)); Newline(); + Indent(); Write("-- indices:\n"); return PrettyPrint(*array.indices(), indent_ + options_.indent_size, sink_); } @@ -431,6 +452,7 @@ Status ArrayPrinter::WriteValidityBitmap(const Array& array) { if (array.null_count() > 0) { Newline(); + Indent(); BooleanArray is_valid(array.length(), array.null_bitmap(), nullptr, 0, array.offset()); return PrettyPrint(is_valid, indent_ + options_.indent_size, sink_); @@ -470,19 +492,28 @@ Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& op for (int i = 0; i < indent; ++i) { (*sink) << " "; } - (*sink) << "[\n"; + (*sink) << "["; + if (!options.skip_new_lines) { + *sink << "\n"; + } bool skip_comma = true; for (int i = 0; i < num_chunks; ++i) { if (skip_comma) { skip_comma = false; } else { - (*sink) << ",\n"; + (*sink) << ","; + if (!options.skip_new_lines) { + *sink << "\n"; + } } if ((i >= window) && (i < (num_chunks - window))) { for (int i = 0; i < indent; ++i) { (*sink) << " "; } - (*sink) << "...\n"; + (*sink) << "..."; + if (!options.skip_new_lines) { + *sink << "\n"; + } i = num_chunks - window - 1; skip_comma = true; } else { @@ -492,7 +523,9 @@ Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& op RETURN_NOT_OK(printer.Print(*chunked_arr.chunk(i))); } } - (*sink) << "\n"; + if (!options.skip_new_lines) { + *sink << "\n"; + } for (int i = 0; i < indent; ++i) { (*sink) << " "; @@ -572,6 +605,7 @@ class SchemaPrinter : public PrettyPrinter { void PrintVerboseMetadata(const KeyValueMetadata& metadata) { for (int64_t i = 0; i < metadata.size(); ++i) { Newline(); + Indent(); Write(metadata.key(i) + ": '" + metadata.value(i) + "'"); } } @@ -579,6 +613,7 @@ class SchemaPrinter : public PrettyPrinter { void PrintTruncatedMetadata(const KeyValueMetadata& metadata) { for (int64_t i = 0; i < metadata.size(); ++i) { Newline(); + Indent(); size_t size = metadata.value(i).size(); size_t truncated_size = std::max(10, 70 - metadata.key(i).size() - indent_); if (size <= truncated_size) { @@ -594,6 +629,7 @@ class SchemaPrinter : public PrettyPrinter { void PrintMetadata(const std::string& metadata_type, const KeyValueMetadata& metadata) { if (metadata.size() > 0) { Newline(); + Indent(); Write(metadata_type); if (options_.truncate_metadata) { PrintTruncatedMetadata(metadata); @@ -607,6 +643,7 @@ class 
SchemaPrinter : public PrettyPrinter { for (int i = 0; i < schema_.num_fields(); ++i) { if (i > 0) { Newline(); + Indent(); } else { Indent(); } @@ -631,6 +668,7 @@ Status SchemaPrinter::PrintType(const DataType& type, bool nullable) { } for (int i = 0; i < type.num_fields(); ++i) { Newline(); + Indent(); std::stringstream ss; ss << "child " << i << ", "; diff --git a/cpp/src/arrow/pretty_print.h b/cpp/src/arrow/pretty_print.h index 9d2c72c7186..1bc086a6889 100644 --- a/cpp/src/arrow/pretty_print.h +++ b/cpp/src/arrow/pretty_print.h @@ -19,6 +19,7 @@ #include #include +#include #include "arrow/util/visibility.h" @@ -34,13 +35,14 @@ class Table; struct PrettyPrintOptions { PrettyPrintOptions() = default; - PrettyPrintOptions(int indent_arg, int window_arg = 10, int indent_size_arg = 2, + PrettyPrintOptions(int indent_arg, // NOLINT runtime/explicit + int window_arg = 10, int indent_size_arg = 2, std::string null_rep_arg = "null", bool skip_new_lines_arg = false, bool truncate_metadata_arg = true) : indent(indent_arg), indent_size(indent_size_arg), window(window_arg), - null_rep(null_rep_arg), + null_rep(std::move(null_rep_arg)), skip_new_lines(skip_new_lines_arg), truncate_metadata(truncate_metadata_arg) {} diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h index ee79077335e..cb7437cd242 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -490,7 +490,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { namespace internal { template -inline Status GenericToStatus(const Result& res) { +inline const Status& GenericToStatus(const Result& res) { return res.status(); } diff --git a/cpp/src/arrow/result_test.cc b/cpp/src/arrow/result_test.cc index b71af9d8531..b814e3f3ea1 100644 --- a/cpp/src/arrow/result_test.cc +++ b/cpp/src/arrow/result_test.cc @@ -26,6 +26,7 @@ #include #include "arrow/testing/gtest_compat.h" +#include "arrow/testing/gtest_util.h" namespace arrow { @@ -724,5 +725,74 @@ TEST(ResultTest, ViewAsStatus) { EXPECT_EQ(ViewAsStatus(&err), &err.status()); } +TEST(ResultTest, MatcherExamples) { + EXPECT_THAT(Result(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid)); + + EXPECT_THAT(Result(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))); + + // message doesn't match, so no match + EXPECT_THAT( + Result(Status::Invalid("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))); + + // different error code, so no match + EXPECT_THAT(Result(Status::TypeError("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid))); + + // not an error, so no match + EXPECT_THAT(Result(333), testing::Not(Raises(StatusCode::Invalid))); + + EXPECT_THAT(Result("hello world"), + ResultWith(testing::HasSubstr("hello"))); + + EXPECT_THAT(Result(Status::Invalid("XXX")), + testing::Not(ResultWith(testing::HasSubstr("hello")))); + + // holds a value, but that value doesn't match the given pattern + EXPECT_THAT(Result("foo bar"), + testing::Not(ResultWith(testing::HasSubstr("hello")))); +} + +TEST(ResultTest, MatcherDescriptions) { + testing::Matcher> matcher = ResultWith(testing::HasSubstr("hello")); + + { + std::stringstream ss; + matcher.DescribeTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("value has substring \"hello\"")); + } + + { + std::stringstream ss; + matcher.DescribeNegationTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("value has no substring \"hello\"")); + } +} + +TEST(ResultTest, MatcherExplanations) { + testing::Matcher> matcher = 
ResultWith(testing::HasSubstr("hello")); + + { + testing::StringMatchResultListener listener; + EXPECT_TRUE(matcher.MatchAndExplain(Result("hello world"), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"hello world\" matches")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Result("foo bar"), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"foo bar\" doesn't match")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener)); + EXPECT_THAT(listener.str(), + testing::StrEq("whose error \"Type error: XXX\" doesn't match")); + } +} + } // namespace } // namespace arrow diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h index 43879e6c6a3..056d60d6f32 100644 --- a/cpp/src/arrow/status.h +++ b/cpp/src/arrow/status.h @@ -312,7 +312,10 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable< StatusCode code() const { return ok() ? StatusCode::OK : state_->code; } /// \brief Return the specific error message attached to this status. - std::string message() const { return ok() ? "" : state_->msg; } + const std::string& message() const { + static const std::string no_message = ""; + return ok() ? no_message : state_->msg; + } /// \brief Return the status detail attached to this message. const std::shared_ptr& detail() const { @@ -440,7 +443,7 @@ namespace internal { // Extract Status from Status or Result // Useful for the status check macros such as RETURN_NOT_OK. -inline Status GenericToStatus(const Status& st) { return st; } +inline const Status& GenericToStatus(const Status& st) { return st; } inline Status GenericToStatus(Status&& st) { return std::move(st); } } // namespace internal diff --git a/cpp/src/arrow/status_test.cc b/cpp/src/arrow/status_test.cc index fc5a7ec45cf..b4e5c288d1b 100644 --- a/cpp/src/arrow/status_test.cc +++ b/cpp/src/arrow/status_test.cc @@ -17,9 +17,11 @@ #include +#include #include #include "arrow/status.h" +#include "arrow/testing/gtest_util.h" namespace arrow { @@ -114,6 +116,85 @@ TEST(StatusTest, TestEquality) { ASSERT_NE(Status::Invalid("error"), Status::Invalid("other error")); } +TEST(StatusTest, MatcherExamples) { + EXPECT_THAT(Status::Invalid("arbitrary error"), Raises(StatusCode::Invalid)); + + EXPECT_THAT(Status::Invalid("arbitrary error"), + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))); + + // message doesn't match, so no match + EXPECT_THAT( + Status::Invalid("arbitrary error"), + testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))); + + // different error code, so no match + EXPECT_THAT(Status::TypeError("arbitrary error"), + testing::Not(Raises(StatusCode::Invalid))); + + // not an error, so no match + EXPECT_THAT(Status::OK(), testing::Not(Raises(StatusCode::Invalid))); +} + +TEST(StatusTest, MatcherDescriptions) { + testing::Matcher matcher = Raises(StatusCode::Invalid); + + { + std::stringstream ss; + matcher.DescribeTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("raises StatusCode::Invalid")); + } + + { + std::stringstream ss; + matcher.DescribeNegationTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("does not raise StatusCode::Invalid")); + } +} + +TEST(StatusTest, MessageMatcherDescriptions) { + testing::Matcher matcher = + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary")); + + { + std::stringstream ss; + matcher.DescribeTo(&ss); + EXPECT_THAT( + ss.str(), + testing::StrEq( + "raises 
StatusCode::Invalid and message has substring \"arbitrary\"")); + } + + { + std::stringstream ss; + matcher.DescribeNegationTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("does not raise StatusCode::Invalid or message " + "has no substring \"arbitrary\"")); + } +} + +TEST(StatusTest, MatcherExplanations) { + testing::Matcher matcher = Raises(StatusCode::Invalid); + + { + testing::StringMatchResultListener listener; + EXPECT_TRUE(matcher.MatchAndExplain(Status::Invalid("XXX"), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"Invalid: XXX\" matches")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Status::OK(), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"OK\" doesn't match")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener)); + EXPECT_THAT(listener.str(), + testing::StrEq("whose value \"Type error: XXX\" doesn't match")); + } +} + TEST(StatusTest, TestDetailEquality) { const auto status_with_detail = arrow::Status(StatusCode::IOError, "", std::make_shared()); diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 591745151da..499f6adf4a6 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -28,6 +28,7 @@ #include #include +#include #include #include "arrow/array/builder_binary.h" @@ -284,6 +285,126 @@ ARROW_TESTING_EXPORT void AssertZeroPadded(const Array& array); ARROW_TESTING_EXPORT void TestInitialized(const ArrayData& array); ARROW_TESTING_EXPORT void TestInitialized(const Array& array); +template +class ResultMatcher { + public: + explicit ResultMatcher(ValueMatcher value_matcher) + : value_matcher_(std::move(value_matcher)) {} + + template ::type::ValueType> + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { + explicit Impl(const ValueMatcher& value_matcher) + : value_matcher_(testing::MatcherCast(value_matcher)) {} + + void DescribeTo(::std::ostream* os) const override { + *os << "value "; + value_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(::std::ostream* os) const override { + *os << "value "; + value_matcher_.DescribeNegationTo(os); + } + + bool MatchAndExplain(const Res& maybe_value, + testing::MatchResultListener* listener) const override { + if (!maybe_value.ok()) { + *listener << "whose error " + << testing::PrintToString(maybe_value.status().ToString()) + << " doesn't match"; + return false; + } + const ValueType& value = maybe_value.ValueOrDie(); + testing::StringMatchResultListener value_listener; + const bool match = value_matcher_.MatchAndExplain(value, &value_listener); + *listener << "whose value " << testing::PrintToString(value) + << (match ? 
" matches" : " doesn't match"); + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); + return match; + } + + const testing::Matcher value_matcher_; + }; + + return testing::Matcher(new Impl(value_matcher_)); + } + + private: + const ValueMatcher value_matcher_; +}; + +class StatusMatcher { + public: + explicit StatusMatcher(StatusCode code, + util::optional> message_matcher) + : code_(code), message_matcher_(std::move(message_matcher)) {} + + template + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { + explicit Impl(StatusCode code, + util::optional> message_matcher) + : code_(code), message_matcher_(std::move(message_matcher)) {} + + void DescribeTo(::std::ostream* os) const override { + *os << "raises StatusCode::" << Status::CodeAsString(code_); + if (message_matcher_) { + *os << " and message "; + message_matcher_->DescribeTo(os); + } + } + + void DescribeNegationTo(::std::ostream* os) const override { + *os << "does not raise StatusCode::" << Status::CodeAsString(code_); + if (message_matcher_) { + *os << " or message "; + message_matcher_->DescribeNegationTo(os); + } + } + + bool MatchAndExplain(const ResultOrStatus& result_or_status, + testing::MatchResultListener* listener) const override { + const Status& status = internal::GenericToStatus(result_or_status); + testing::StringMatchResultListener value_listener; + + bool match = status.code() == code_; + if (message_matcher_) { + match = match && + message_matcher_->MatchAndExplain(status.message(), &value_listener); + } + + *listener << "whose value " << testing::PrintToString(status.ToString()) + << (match ? " matches" : " doesn't match"); + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); + return match; + } + + const StatusCode code_; + const util::optional> message_matcher_; + }; + + return testing::Matcher(new Impl(code_, message_matcher_)); + } + + const StatusCode code_; + const util::optional> message_matcher_; +}; + +template +ResultMatcher ResultWith(const ValueMatcher& value_matcher) { + return ResultMatcher(value_matcher); +} + +inline StatusMatcher Raises(StatusCode code) { + return StatusMatcher(code, util::nullopt); +} + +template +StatusMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { + return StatusMatcher(code, testing::MatcherCast(message_matcher)); +} + template void FinishAndCheckPadding(BuilderType* builder, std::shared_ptr* out) { ASSERT_OK_AND_ASSIGN(*out, builder->Finish()); diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index 61efb0043be..29c8d73ab6c 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" @@ -88,8 +89,7 @@ std::function()> BackgroundAsyncVectorIt( auto slow_iterator = PossiblySlowVectorIt(v, sleep); EXPECT_OK_AND_ASSIGN( auto background, - MakeBackgroundGenerator(std::move(slow_iterator), - internal::GetCpuThreadPool(), max_q, q_restart)); + MakeBackgroundGenerator(std::move(slow_iterator), pool, max_q, q_restart)); return MakeTransferredGenerator(background, pool); } @@ -106,8 +106,7 @@ std::function()> NewBackgroundAsyncVectorIt(std::vector }); EXPECT_OK_AND_ASSIGN(auto background, - MakeBackgroundGenerator(std::move(slow_iterator), - internal::GetCpuThreadPool())); + MakeBackgroundGenerator(std::move(slow_iterator), 
pool)); return MakeTransferredGenerator(background, pool); } @@ -176,7 +175,8 @@ class ReentrantChecker { template class ReentrantCheckerGuard { public: - explicit ReentrantCheckerGuard(ReentrantChecker checker) : checker_(checker) {} + explicit ReentrantCheckerGuard(ReentrantChecker checker) + : checker_(std::move(checker)) {} ARROW_DISALLOW_COPY_AND_ASSIGN(ReentrantCheckerGuard); ReentrantCheckerGuard(ReentrantCheckerGuard&& other) : checker_(other.checker_) { diff --git a/cpp/src/arrow/util/vector.h b/cpp/src/arrow/util/vector.h index 3ef0074aa9d..041bdb424a7 100644 --- a/cpp/src/arrow/util/vector.h +++ b/cpp/src/arrow/util/vector.h @@ -84,27 +84,49 @@ std::vector FilterVector(std::vector values, Predicate&& predicate) { return values; } -/// \brief Like MapVector, but where the function can fail. -template , - typename To = typename internal::call_traits::return_type::ValueType> -Result> MaybeMapVector(Fn&& map, const std::vector& src) { +template ()(std::declval()))> +std::vector MapVector(Fn&& map, const std::vector& source) { std::vector out; - out.reserve(src.size()); - ARROW_RETURN_NOT_OK(MaybeTransform(src.begin(), src.end(), std::back_inserter(out), - std::forward(map))); - return std::move(out); + out.reserve(source.size()); + std::transform(source.begin(), source.end(), std::back_inserter(out), + std::forward(map)); + return out; } template ()(std::declval()))> -std::vector MapVector(Fn&& map, const std::vector& source) { +std::vector MapVector(Fn&& map, std::vector&& source) { std::vector out; out.reserve(source.size()); - std::transform(source.begin(), source.end(), std::back_inserter(out), + std::transform(std::make_move_iterator(source.begin()), + std::make_move_iterator(source.end()), std::back_inserter(out), std::forward(map)); return out; } +/// \brief Like MapVector, but where the function can fail. 
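+///
+/// Note: like MapVector above, this is provided in both const-reference
+/// and rvalue overloads; the rvalue overloads move each element of the
+/// source vector into the mapping function.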
+template , + typename To = typename internal::call_traits::return_type::ValueType> +Result> MaybeMapVector(Fn&& map, const std::vector& source) { + std::vector out; + out.reserve(source.size()); + ARROW_RETURN_NOT_OK(MaybeTransform(source.begin(), source.end(), + std::back_inserter(out), std::forward(map))); + return std::move(out); +} + +template , + typename To = typename internal::call_traits::return_type::ValueType> +Result> MaybeMapVector(Fn&& map, std::vector&& source) { + std::vector out; + out.reserve(source.size()); + ARROW_RETURN_NOT_OK(MaybeTransform(std::make_move_iterator(source.begin()), + std::make_move_iterator(source.end()), + std::back_inserter(out), std::forward(map))); + return std::move(out); +} + template std::vector FlattenVectors(const std::vector>& vecs) { std::size_t sum = 0; From d0c9eac78d9327c52d41bff70ff52784d6f5c39f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 14 Jun 2021 11:40:37 -0400 Subject: [PATCH 14/28] replace output_descr with output_schema for named fields --- cpp/src/arrow/compute/exec/exec_plan.cc | 58 ++++++----- cpp/src/arrow/compute/exec/exec_plan.h | 20 ++-- cpp/src/arrow/compute/exec/plan_test.cc | 132 ++++++++++++------------ cpp/src/arrow/compute/exec/test_util.cc | 49 +-------- cpp/src/arrow/compute/exec/test_util.h | 18 ---- cpp/src/arrow/dataset/scanner.cc | 7 +- cpp/src/arrow/dataset/scanner_test.cc | 4 +- 7 files changed, 120 insertions(+), 168 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index a30ed2ca352..5f857b97441 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -167,13 +167,13 @@ Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } ExecNode::ExecNode(ExecPlan* plan, std::string label, NodeVector inputs, - std::vector input_labels, BatchDescr output_descr, - int num_outputs) + std::vector input_labels, + std::shared_ptr output_schema, int num_outputs) : plan_(plan), label_(std::move(label)), inputs_(std::move(inputs)), input_labels_(std::move(input_labels)), - output_descr_(std::move(output_descr)), + output_schema_(std::move(output_schema)), num_outputs_(num_outputs) { for (auto input : inputs_) { input->outputs_.push_back(this); @@ -203,9 +203,9 @@ Status ExecNode::Validate() const { } struct SourceNode : ExecNode { - SourceNode(ExecPlan* plan, std::string label, ExecNode::BatchDescr output_descr, + SourceNode(ExecPlan* plan, std::string label, std::shared_ptr output_schema, AsyncGenerator> generator) - : ExecNode(plan, std::move(label), {}, {}, std::move(output_descr), + : ExecNode(plan, std::move(label), {}, {}, std::move(output_schema), /*num_outputs=*/1), generator_(std::move(generator)) {} @@ -288,16 +288,16 @@ struct SourceNode : ExecNode { }; ExecNode* MakeSourceNode(ExecPlan* plan, std::string label, - ExecNode::BatchDescr output_descr, + std::shared_ptr output_schema, AsyncGenerator> generator) { - return plan->EmplaceNode(plan, std::move(label), std::move(output_descr), + return plan->EmplaceNode(plan, std::move(label), std::move(output_schema), std::move(generator)); } struct FilterNode : ExecNode { FilterNode(ExecNode* input, std::string label, Expression filter) : ExecNode(input->plan(), std::move(label), {input}, {"target"}, - /*output_descr=*/{input->output_descr()}, + /*output_schema=*/input->output_schema(), /*num_outputs=*/1), filter_(std::move(filter)) {} @@ -352,10 +352,7 @@ struct 
FilterNode : ExecNode { outputs_[0]->InputFinished(this, seq); } - Status StartProducing() override { - // XXX validate inputs_[0]->output_descr() against filter_ - return Status::OK(); - } + Status StartProducing() override { return Status::OK(); } void PauseProducing(ExecNode* output) override {} @@ -372,15 +369,24 @@ struct FilterNode : ExecNode { Expression filter_; }; -ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter) { +Result MakeFilterNode(ExecNode* input, std::string label, Expression filter) { + ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(*input->output_schema())); + + if (filter.type()->id() != Type::BOOL) { + return Status::TypeError("Filter expression must evaluate to bool, but ", + filter.ToString(), " evaluates to ", + filter.type()->ToString()); + } + return input->plan()->EmplaceNode(input, std::move(label), std::move(filter)); } struct ProjectNode : ExecNode { - ProjectNode(ExecNode* input, std::string label, std::vector exprs) + ProjectNode(ExecNode* input, std::string label, std::shared_ptr output_schema, + std::vector exprs) : ExecNode(input->plan(), std::move(label), {input}, {"target"}, - /*output_descr=*/{input->output_descr()}, + /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), exprs_(std::move(exprs)) {} @@ -423,10 +429,7 @@ struct ProjectNode : ExecNode { outputs_[0]->InputFinished(this, seq); } - Status StartProducing() override { - // XXX validate inputs_[0]->output_descr() against filter_ - return Status::OK(); - } + Status StartProducing() override { return Status::OK(); } void PauseProducing(ExecNode* output) override {} @@ -443,10 +446,19 @@ struct ProjectNode : ExecNode { std::vector exprs_; }; -ExecNode* MakeProjectNode(ExecNode* input, std::string label, - std::vector exprs) { - return input->plan()->EmplaceNode(input, std::move(label), - std::move(exprs)); +Result MakeProjectNode(ExecNode* input, std::string label, + std::vector exprs) { + FieldVector fields(exprs.size()); + + int i = 0; + for (auto& expr : exprs) { + ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*input->output_schema())); + fields[i] = field(expr.ToString(), expr.type()); + ++i; + } + + return input->plan()->EmplaceNode( + input, std::move(label), schema(std::move(fields)), std::move(exprs)); } struct SinkNode : ExecNode { diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index b3e3823c76d..44da87d2c7d 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -80,7 +80,6 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { class ARROW_EXPORT ExecNode { public: using NodeVector = std::vector; - using BatchDescr = std::vector; virtual ~ExecNode() = default; @@ -100,7 +99,7 @@ class ARROW_EXPORT ExecNode { const NodeVector& outputs() const { return outputs_; } /// The datatypes for batches produced by this node - const BatchDescr& output_descr() const { return output_descr_; } + const std::shared_ptr& output_schema() const { return output_schema_; } /// This node's exec plan ExecPlan* plan() { return plan_; } @@ -214,7 +213,7 @@ class ARROW_EXPORT ExecNode { protected: ExecNode(ExecPlan* plan, std::string label, NodeVector inputs, - std::vector input_labels, BatchDescr output_descr, + std::vector input_labels, std::shared_ptr output_schema, int num_outputs); ExecPlan* plan_; @@ -223,7 +222,7 @@ class ARROW_EXPORT ExecNode { NodeVector inputs_; std::vector input_labels_; - BatchDescr output_descr_; + std::shared_ptr output_schema_; int num_outputs_; 
NodeVector outputs_; }; @@ -233,7 +232,8 @@ class ARROW_EXPORT ExecNode { /// TODO this should accept an Executor and explicitly handle batches /// as they are generated on each of the Executor's threads. ARROW_EXPORT -ExecNode* MakeSourceNode(ExecPlan*, std::string label, ExecNode::BatchDescr output_descr, +ExecNode* MakeSourceNode(ExecPlan*, std::string label, + std::shared_ptr output_schema, std::function>()>); /// \brief Add a sink node which forwards to an AsyncGenerator @@ -248,19 +248,17 @@ std::function>()> MakeSinkNode(ExecNode* input, /// /// The filter Expression will be evaluated against each batch which is pushed to /// this node. Any rows for which the filter does not evaluate to `true` will be excluded -/// in the batch emitted by this node. Filter Expression must be bound; no field -/// references will be looked up by name +/// in the batch emitted by this node. ARROW_EXPORT -ExecNode* MakeFilterNode(ExecNode* input, std::string label, Expression filter); +Result MakeFilterNode(ExecNode* input, std::string label, Expression filter); /// \brief Make a node which executes expressions on input batches, producing new batches. /// /// Each expression will be evaluated against each batch which is pushed to /// this node to produce a corresponding output column. -/// Expressions must be bound; no field references will be looked up by name ARROW_EXPORT -ExecNode* MakeProjectNode(ExecNode* input, std::string label, - std::vector exprs); +Result MakeProjectNode(ExecNode* input, std::string label, + std::vector exprs); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 8cb1c4c1f7d..dc5845977da 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -32,10 +32,12 @@ #include "arrow/util/thread_pool.h" #include "arrow/util/vector.h" -namespace arrow { - +using testing::ElementsAre; +using testing::HasSubstr; using testing::UnorderedElementsAreArray; +namespace arrow { + namespace compute { ExecBatch ExecBatchFromJSON(const std::vector& descrs, @@ -70,8 +72,8 @@ TEST(ExecPlanConstruction, SingleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); auto node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/0); ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ::testing::ElementsAre(node)); - ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(node)); + ASSERT_THAT(plan->sources(), ElementsAre(node)); + ASSERT_THAT(plan->sinks(), ElementsAre(node)); ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/1); @@ -85,8 +87,8 @@ TEST(ExecPlanConstruction, SourceSink) { auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{source}, /*num_outputs=*/0); ASSERT_OK(plan->Validate()); - EXPECT_THAT(plan->sources(), ::testing::ElementsAre(source)); - EXPECT_THAT(plan->sinks(), ::testing::ElementsAre(sink)); + EXPECT_THAT(plan->sources(), ElementsAre(source)); + EXPECT_THAT(plan->sinks(), ElementsAre(sink)); } TEST(ExecPlanConstruction, MultipleNode) { @@ -109,8 +111,8 @@ TEST(ExecPlanConstruction, MultipleNode) { auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0); ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ::testing::ElementsAre(source1, source2)); - ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(sink)); + ASSERT_THAT(plan->sources(), ElementsAre(source1, source2)); + ASSERT_THAT(plan->sinks(), 
ElementsAre(sink)); } struct StartStopTracker { @@ -160,8 +162,8 @@ TEST(ExecPlan, DummyStartProducing) { ASSERT_OK(plan->StartProducing()); // Note that any correct reverse topological order may do - ASSERT_THAT(t.started, ::testing::ElementsAre("sink", "process3", "process2", - "process1", "source2", "source1")); + ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1", + "source2", "source1")); ASSERT_EQ(t.stopped.size(), 0); } @@ -198,21 +200,26 @@ TEST(ExecPlan, DummyStartProducingError) { // `process1` raises IOError ASSERT_RAISES(IOError, plan->StartProducing()); - ASSERT_THAT(t.started, - ::testing::ElementsAre("sink", "process3", "process2", "process1")); + ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1")); // Nodes that started successfully were stopped in reverse order - ASSERT_THAT(t.stopped, ::testing::ElementsAre("process2", "process3", "sink")); + ASSERT_THAT(t.stopped, ElementsAre("process2", "process3", "sink")); } -static Result MakeTestSourceNode(ExecPlan* plan, std::string label, - std::vector batches, bool parallel, - bool slow) { - DCHECK_GT(batches.size(), 0); - auto out_descr = batches.back().GetDescriptors(); +namespace { + +struct BatchesWithSchema { + std::vector batches; + std::shared_ptr schema; +}; + +Result MakeTestSourceNode(ExecPlan* plan, std::string label, + BatchesWithSchema batches_with_schema, bool parallel, + bool slow) { + DCHECK_GT(batches_with_schema.batches.size(), 0); auto opt_batches = internal::MapVector( [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, - std::move(batches)); + std::move(batches_with_schema.batches)); AsyncGenerator> gen; @@ -232,10 +239,11 @@ static Result MakeTestSourceNode(ExecPlan* plan, std::string label, }); } - return MakeSourceNode(plan, label, out_descr, std::move(gen)); + return MakeSourceNode(plan, label, std::move(batches_with_schema.schema), + std::move(gen)); } -static Result> StartAndCollect( +Result> StartAndCollect( ExecPlan* plan, AsyncGenerator> gen) { RETURN_NOT_OK(plan->Validate()); RETURN_NOT_OK(plan->StartProducing()); @@ -249,24 +257,30 @@ static Result> StartAndCollect( [](util::optional batch) { return std::move(*batch); }, collected); } -static std::vector MakeBasicBatches() { - return {ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"), - ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")}; +BatchesWithSchema MakeBasicBatches() { + BatchesWithSchema out; + out.batches = { + ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"), + ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")}; + out.schema = schema({field("i32", int32()), field("bool", boolean())}); + return out; } -static std::vector MakeRandomBatches(const std::shared_ptr& schema, - int num_batches = 10, - int batch_size = 4) { +BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, + int num_batches = 10, int batch_size = 4) { + BatchesWithSchema out; + random::RandomArrayGenerator rng(42); - std::vector batches(num_batches); + out.batches.resize(num_batches); for (int i = 0; i < num_batches; ++i) { - batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size)); + out.batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size)); // add a tag scalar to ensure the batches are unique - batches[i].values.emplace_back(i); + out.batches[i].values.emplace_back(i); } - return batches; + return out; } +} // namespace TEST(ExecPlanExecution, SourceSink) { for (bool 
slow : {false, true}) { @@ -277,15 +291,15 @@ TEST(ExecPlanExecution, SourceSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto batches = MakeBasicBatches(); + auto basic_data = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN( - auto source, MakeTestSourceNode(plan.get(), "source", batches, parallel, slow)); + ASSERT_OK_AND_ASSIGN(auto source, MakeTestSourceNode(plan.get(), "source", + basic_data, parallel, slow)); auto sink_gen = MakeSinkNode(source, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - ResultWith(UnorderedElementsAreArray(batches))); + ResultWith(UnorderedElementsAreArray(basic_data.batches))); } } } @@ -293,11 +307,11 @@ TEST(ExecPlanExecution, SourceSink) { TEST(ExecPlanExecution, SourceSinkError) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto batches = MakeBasicBatches(); - auto it = batches.begin(); + auto basic_data = MakeBasicBatches(); + auto it = basic_data.batches.begin(); AsyncGenerator> gen = [&]() -> Result> { - if (it == batches.end()) { + if (it == basic_data.batches.end()) { return Status::Invalid("Artificial error"); } return util::make_optional(*it++); @@ -307,7 +321,7 @@ TEST(ExecPlanExecution, SourceSinkError) { auto sink_gen = MakeSinkNode(source, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - Raises(StatusCode::Invalid, testing::HasSubstr("Artificial"))); + Raises(StatusCode::Invalid, HasSubstr("Artificial"))); } TEST(ExecPlanExecution, StressSourceSink) { @@ -321,16 +335,16 @@ TEST(ExecPlanExecution, StressSourceSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto batches = MakeRandomBatches( + auto random_data = MakeRandomBatches( schema({field("a", int32()), field("b", boolean())}), num_batches); - ASSERT_OK_AND_ASSIGN( - auto source, MakeTestSourceNode(plan.get(), "source", batches, parallel, slow)); + ASSERT_OK_AND_ASSIGN(auto source, MakeTestSourceNode(plan.get(), "source", + random_data, parallel, slow)); auto sink_gen = MakeSinkNode(source, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - ResultWith(UnorderedElementsAreArray(batches))); + ResultWith(UnorderedElementsAreArray(random_data.batches))); } } } @@ -338,21 +352,16 @@ TEST(ExecPlanExecution, StressSourceSink) { TEST(ExecPlanExecution, SourceFilterSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto batches = MakeBasicBatches(); + auto basic_data = MakeBasicBatches(); ASSERT_OK_AND_ASSIGN(auto source, - MakeTestSourceNode(plan.get(), "source", batches, + MakeTestSourceNode(plan.get(), "source", basic_data, /*parallel=*/false, /*slow=*/false)); - const auto schema = ::arrow::schema({ - field("a", int32()), - field("b", boolean()), - field("__tag", int32()), - }); + ASSERT_OK_AND_ASSIGN(auto predicate, + equal(field_ref("i32"), literal(6)).Bind(*basic_data.schema)); - ASSERT_OK_AND_ASSIGN(auto predicate, equal(field_ref("a"), literal(6)).Bind(*schema)); - - auto filter = MakeFilterNode(source, "filter", predicate); + ASSERT_OK_AND_ASSIGN(auto filter, MakeFilterNode(source, "filter", predicate)); auto sink_gen = MakeSinkNode(filter, "sink"); @@ -365,26 +374,21 @@ TEST(ExecPlanExecution, SourceFilterSink) { TEST(ExecPlanExecution, SourceProjectSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto batches = MakeBasicBatches(); + auto basic_data = MakeBasicBatches(); ASSERT_OK_AND_ASSIGN(auto source, - MakeTestSourceNode(plan.get(), "source", batches, + MakeTestSourceNode(plan.get(), "source", basic_data, /*parallel=*/false, /*slow=*/false)); - const auto schema = ::arrow::schema({ - field("a", 
int32()), - field("b", boolean()), - }); - std::vector exprs{ - not_(field_ref("b")), - call("add", {field_ref("a"), literal(1)}), + not_(field_ref("bool")), + call("add", {field_ref("i32"), literal(1)}), }; for (auto& expr : exprs) { - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*schema)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*basic_data.schema)); } - auto projection = MakeProjectNode(source, "project", exprs); + ASSERT_OK_AND_ASSIGN(auto projection, MakeProjectNode(source, "project", exprs)); auto sink_gen = MakeSinkNode(projection, "sink"); diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index fda39a66d28..e5bf61a1808 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -45,27 +45,13 @@ namespace arrow { using internal::Executor; namespace compute { - -void AssertBatchesEqual(const ExecBatch& expected, const ExecBatch& actual) { - ASSERT_THAT(actual.values, testing::ElementsAreArray(expected.values)); -} - namespace { -// TODO expose this as `static ValueDescr::FromSchemaColumns`? -std::vector DescrFromSchemaColumns(const Schema& schema) { - std::vector descr(schema.num_fields()); - std::transform(schema.fields().begin(), schema.fields().end(), descr.begin(), - [](const std::shared_ptr& field) { - return ValueDescr::Array(field->type()); - }); - return descr; -} - struct DummyNode : ExecNode { DummyNode(ExecPlan* plan, std::string label, NodeVector inputs, int num_outputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) - : ExecNode(plan, std::move(label), std::move(inputs), {}, descr(), num_outputs), + : ExecNode(plan, std::move(label), std::move(inputs), {}, dummy_schema(), + num_outputs), start_producing_(std::move(start_producing)), stop_producing_(std::move(stop_producing)) { input_labels_.resize(inputs_.size()); @@ -123,42 +109,17 @@ struct DummyNode : ExecNode { ASSERT_NE(std::find(outputs_.begin(), outputs_.end(), output), outputs_.end()); } - BatchDescr descr() const { return std::vector{ValueDescr(null())}; } + std::shared_ptr dummy_schema() const { + return schema({field("dummy", null())}); + } StartProducingFunc start_producing_; StopProducingFunc stop_producing_; bool started_ = false; }; -AsyncGenerator> Wrap(RecordBatchGenerator gen, - ::arrow::internal::Executor* io_executor) { - return MakeMappedGenerator( - MakeTransferredGenerator(std::move(gen), io_executor), - [](const std::shared_ptr& batch) -> util::optional { - return ExecBatch(*batch); - }); -} - } // namespace -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - const std::shared_ptr& schema, - RecordBatchGenerator generator, - ::arrow::internal::Executor* io_executor) { - return MakeSourceNode(plan, std::move(label), DescrFromSchemaColumns(*schema), - Wrap(std::move(generator), io_executor)); -} - -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - const std::shared_ptr& reader, - Executor* io_executor) { - auto gen = - MakeBackgroundGenerator(MakeIteratorFromReader(reader), io_executor).ValueOrDie(); - - return MakeRecordBatchReaderNode(plan, std::move(label), reader->schema(), - std::move(gen), io_executor); -} - ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, int num_outputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) { diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 543df257353..60423548614 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ 
b/cpp/src/arrow/compute/exec/test_util.h @@ -40,23 +40,5 @@ ARROW_TESTING_EXPORT ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); -using RecordBatchGenerator = AsyncGenerator>; - -// Make a source node (no inputs) that produces record batches by reading in the -// background from a RecordBatchReader. -ARROW_TESTING_EXPORT -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - const std::shared_ptr& reader, - ::arrow::internal::Executor* io_executor); - -ARROW_TESTING_EXPORT -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - const std::shared_ptr& schema, - RecordBatchGenerator generator, - ::arrow::internal::Executor* io_executor); - -ARROW_TESTING_EXPORT void AssertBatchesEqual(const ExecBatch& expected, - const ExecBatch& actual); - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index a2e70c0bfe8..a622595c8f5 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -1136,12 +1136,7 @@ Result ScannerBuilder::MakeScanNode(compute::ExecPlan* plan) return batch; }); - std::vector output_descr; - for (const auto& field : schema->fields()) { - output_descr.push_back(ValueDescr::Array(field->type())); - } - - return MakeSourceNode(plan, "dataset_scan", std::move(output_descr), std::move(gen)); + return MakeSourceNode(plan, "dataset_scan", schema, std::move(gen)); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 8805eace9a5..f527e86f57a 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1273,8 +1273,8 @@ TEST(ScanNode, MaterializationOfVirtualColumn) { ScannerBuilder scanner_builder(dataset); ASSERT_OK(scanner_builder.UseAsync(true)); ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); - auto project = compute::MakeProjectNode( - scan, "project", {field_ref("c").Bind(*dataset_schema).ValueOrDie()}); + ASSERT_OK_AND_ASSIGN(auto project, + compute::MakeProjectNode(scan, "project", {field_ref("c")})); auto sink_gen = MakeSinkNode(project, "sink"); std::vector expected; From 86cfce50982c3c6d0c043caba6fae006ab338699 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 14 Jun 2021 18:08:02 -0400 Subject: [PATCH 15/28] repair r/src/dataset.cpp --- r/src/dataset.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp index 24c1a1343ea..7bb1e639e05 100644 --- a/r/src/dataset.cpp +++ b/r/src/dataset.cpp @@ -70,9 +70,9 @@ const char* r6_class_name::get( // [[dataset::export]] std::shared_ptr dataset___Dataset__NewScan( const std::shared_ptr& ds) { - auto options = std::make_shared(); - options->pool = gc_memory_pool(); - return ValueOrStop(ds->NewScan(std::move(options))); + auto builder = ValueOrStop(ds->NewScan()); + StopIfNotOk(builder->Pool(gc_memory_pool())); + return builder; } // [[dataset::export]] From 55df44d19b4ea0d44db4c7c51a166dac3c0ce366 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 15 Jun 2021 16:09:40 -0400 Subject: [PATCH 16/28] add accessor to check for thread membership --- cpp/src/arrow/util/thread_pool.cc | 9 ++++++++- cpp/src/arrow/util/thread_pool.h | 6 ++++++ cpp/src/arrow/util/thread_pool_test.cc | 16 ++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/util/thread_pool.cc 
b/cpp/src/arrow/util/thread_pool.cc index 672839b67d5..758295d01ed 100644 --- a/cpp/src/arrow/util/thread_pool.cc +++ b/cpp/src/arrow/util/thread_pool.cc @@ -321,13 +321,20 @@ void ThreadPool::CollectFinishedWorkersUnlocked() { state_->finished_workers_.clear(); } +thread_local ThreadPool* current_thread_pool_ = nullptr; + +bool ThreadPool::OwnsThisThread() { return current_thread_pool_ == this; } + void ThreadPool::LaunchWorkersUnlocked(int threads) { std::shared_ptr state = sp_state_; for (int i = 0; i < threads; i++) { state_->workers_.emplace_back(); auto it = --(state_->workers_.end()); - *it = std::thread([state, it] { WorkerLoop(state, it); }); + *it = std::thread([this, state, it] { + current_thread_pool_ = this; + WorkerLoop(state, it); + }); } } diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index d012aa02010..febbc997852 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -179,6 +179,10 @@ class ARROW_EXPORT Executor { // concurrently). This may be an approximate number. virtual int GetCapacity() = 0; + // Return true if the thread from which this function is called is owned by this + // Executor. Returns false if this Executor does not support this property. + virtual bool OwnsThisThread() { return false; } + protected: ARROW_DISALLOW_COPY_AND_ASSIGN(Executor); @@ -298,6 +302,8 @@ class ARROW_EXPORT ThreadPool : public Executor { // match this value. int GetCapacity() override; + bool OwnsThisThread() override; + // Return the number of tasks either running or in the queue. int GetNumTasks(); diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc index 2cfb4c62613..92f6a8ac00a 100644 --- a/cpp/src/arrow/util/thread_pool_test.cc +++ b/cpp/src/arrow/util/thread_pool_test.cc @@ -395,6 +395,22 @@ TEST_F(TestThreadPool, StressSpawn) { SpawnAdds(pool.get(), 1000, task_add); } +TEST_F(TestThreadPool, OwnsCurrentThread) { + auto pool = this->MakeThreadPool(30); + std::atomic one_failed{false}; + + for (int i = 0; i < 1000; ++i) { + ASSERT_OK(pool->Spawn([&] { + if (pool->OwnsThisThread()) return; + + one_failed = true; + })); + } + + ASSERT_OK(pool->Shutdown()); + ASSERT_FALSE(one_failed); +} + TEST_F(TestThreadPool, StressSpawnThreaded) { auto pool = this->MakeThreadPool(30); SpawnAddsThreaded(pool.get(), 20, 100, task_add); From 9410b1036df01d047b970229fce27b4544a3f68f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 15 Jun 2021 18:19:32 -0400 Subject: [PATCH 17/28] add support for Future<> to ResultWith, Raises --- cpp/src/arrow/testing/gtest_util.h | 48 ++++++++++++++++++++++++------ cpp/src/arrow/util/future_test.cc | 40 +++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 499f6adf4a6..b1a4151d6c5 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -291,7 +291,8 @@ class ResultMatcher { explicit ResultMatcher(ValueMatcher value_matcher) : value_matcher_(std::move(value_matcher)) {} - template ::type::ValueType> + template ::type::ValueType> operator testing::Matcher() const { // NOLINT runtime/explicit struct Impl : testing::MatcherInterface { explicit Impl(const ValueMatcher& value_matcher) @@ -309,13 +310,13 @@ class ResultMatcher { bool MatchAndExplain(const Res& maybe_value, testing::MatchResultListener* listener) const override { - if (!maybe_value.ok()) { + if (!maybe_value.status().ok()) { *listener << 
"whose error " << testing::PrintToString(maybe_value.status().ToString()) << " doesn't match"; return false; } - const ValueType& value = maybe_value.ValueOrDie(); + const ValueType& value = GetValue(maybe_value); testing::StringMatchResultListener value_listener; const bool match = value_matcher_.MatchAndExplain(value, &value_listener); *listener << "whose value " << testing::PrintToString(value) @@ -331,6 +332,16 @@ class ResultMatcher { } private: + template + static const T& GetValue(const Result& maybe_value) { + return maybe_value.ValueOrDie(); + } + + template + static const T& GetValue(const Future& value_fut) { + return GetValue(value_fut.result()); + } + const ValueMatcher value_matcher_; }; @@ -340,9 +351,9 @@ class StatusMatcher { util::optional> message_matcher) : code_(code), message_matcher_(std::move(message_matcher)) {} - template - operator testing::Matcher() const { // NOLINT runtime/explicit - struct Impl : testing::MatcherInterface { + template + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { explicit Impl(StatusCode code, util::optional> message_matcher) : code_(code), message_matcher_(std::move(message_matcher)) {} @@ -363,9 +374,9 @@ class StatusMatcher { } } - bool MatchAndExplain(const ResultOrStatus& result_or_status, + bool MatchAndExplain(const Res& maybe_value, testing::MatchResultListener* listener) const override { - const Status& status = internal::GenericToStatus(result_or_status); + const Status& status = GetStatus(maybe_value); testing::StringMatchResultListener value_listener; bool match = status.code() == code_; @@ -384,22 +395,41 @@ class StatusMatcher { const util::optional> message_matcher_; }; - return testing::Matcher(new Impl(code_, message_matcher_)); + return testing::Matcher(new Impl(code_, message_matcher_)); + } + + private: + static const Status& GetStatus(const Status& status) { return status; } + + template + static const Status& GetStatus(const Result& maybe_value) { + return maybe_value.status(); + } + + template + static const Status& GetStatus(const Future& value_fut) { + return value_fut.status(); } const StatusCode code_; const util::optional> message_matcher_; }; +// Returns a matcher that matches the value of a successful Result or Future. +// (Future will be waited upon to acquire its result for matching.) template ResultMatcher ResultWith(const ValueMatcher& value_matcher) { return ResultMatcher(value_matcher); } +// Returns a matcher that matches the StatusCode of a Status, Result, or Future. +// (Future will be waited upon to acquire its result for matching.) inline StatusMatcher Raises(StatusCode code) { return StatusMatcher(code, util::nullopt); } +// Returns a matcher that matches the StatusCode and message of a Status, Result, or +// Future. (Future will be waited upon to acquire its result for matching.) 
template StatusMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { return StatusMatcher(code, testing::MatcherCast(message_matcher)); diff --git a/cpp/src/arrow/util/future_test.cc b/cpp/src/arrow/util/future_test.cc index 33796a05bb1..773ac09e359 100644 --- a/cpp/src/arrow/util/future_test.cc +++ b/cpp/src/arrow/util/future_test.cc @@ -1704,5 +1704,45 @@ TEST(FnOnceTest, MoveOnlyDataType) { ASSERT_EQ(i0.moves, 0); ASSERT_EQ(i1.moves, 0); } + +TEST(FutureTest, MatcherExamples) { + EXPECT_THAT(Future::MakeFinished(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid)); + + EXPECT_THAT(Future::MakeFinished(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))); + + // message doesn't match, so no match + EXPECT_THAT( + Future::MakeFinished(Status::Invalid("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))); + + // different error code, so no match + EXPECT_THAT(Future::MakeFinished(Status::TypeError("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid))); + + // not an error, so no match + EXPECT_THAT(Future::MakeFinished(333), testing::Not(Raises(StatusCode::Invalid))); + + EXPECT_THAT(Future::MakeFinished("hello world"), + ResultWith(testing::HasSubstr("hello"))); + + // Matcher waits on Futures + auto string_fut = Future::Make(); + auto finisher = std::thread([&] { + SleepABit(); + string_fut.MarkFinished("hello world"); + }); + EXPECT_THAT(string_fut, ResultWith(testing::HasSubstr("hello"))); + finisher.join(); + + EXPECT_THAT(Future::MakeFinished(Status::Invalid("XXX")), + testing::Not(ResultWith(testing::HasSubstr("hello")))); + + // holds a value, but that value doesn't match the given pattern + EXPECT_THAT(Future::MakeFinished("foo bar"), + testing::Not(ResultWith(testing::HasSubstr("hello")))); +} + } // namespace internal } // namespace arrow From 1de1a1471807e14ba3224b518beb9194cbe62b69 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 16 Jun 2021 16:49:29 -0400 Subject: [PATCH 18/28] replaced ScanBatchesUnorderedAsync but it hangs --- cpp/src/arrow/compute/exec/exec_plan.cc | 15 +- cpp/src/arrow/compute/exec/exec_plan.h | 4 + cpp/src/arrow/dataset/file_csv.cc | 4 +- cpp/src/arrow/dataset/file_parquet_test.cc | 27 +++ cpp/src/arrow/dataset/scanner.cc | 82 ++++++- cpp/src/arrow/dataset/scanner.h | 14 +- cpp/src/arrow/dataset/scanner_test.cc | 256 ++++++++++----------- cpp/src/arrow/dataset/test_util.h | 23 +- 8 files changed, 260 insertions(+), 165 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 5f857b97441..7438af78b8f 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -239,6 +239,10 @@ struct SourceNode : ExecNode { } lock.unlock(); + // TODO check if we are on the desired Executor and transfer if not. + // This can happen for in-memory scans where batches didn't require + // any CPU work to decode. 
Otherwise, parsing etc. should have already + // placed us on the thread pool outputs_[0]->InputReceived(this, seq, *batch); return Continue(); }, @@ -273,7 +277,6 @@ struct SourceNode : ExecNode { std::unique_lock lock(mutex_); finished_ = true; } - DCHECK(finished_fut_.is_valid()); finished_fut_.Wait(); } @@ -283,7 +286,7 @@ struct SourceNode : ExecNode { std::mutex mutex_; bool finished_{false}; int next_batch_index_{0}; - Future<> finished_fut_; + Future<> finished_fut_ = Future<>::MakeFinished(); AsyncGenerator> generator_; }; @@ -370,7 +373,9 @@ struct FilterNode : ExecNode { }; Result MakeFilterNode(ExecNode* input, std::string label, Expression filter) { - ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(*input->output_schema())); + if (!filter.IsBound()) { + ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(*input->output_schema())); + } if (filter.type()->id() != Type::BOOL) { return Status::TypeError("Filter expression must evaluate to bool, but ", @@ -452,7 +457,9 @@ Result MakeProjectNode(ExecNode* input, std::string label, int i = 0; for (auto& expr : exprs) { - ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*input->output_schema())); + if (!expr.IsBound()) { + ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*input->output_schema())); + } fields[i] = field(expr.ToString(), expr.type()); ++i; } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h index 44da87d2c7d..42f9e6527c0 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -249,6 +249,8 @@ std::function>()> MakeSinkNode(ExecNode* input, /// The filter Expression will be evaluated against each batch which is pushed to /// this node. Any rows for which the filter does not evaluate to `true` will be excluded /// in the batch emitted by this node. +/// +/// If the filter is not already bound, it will be bound against the input's schema. ARROW_EXPORT Result MakeFilterNode(ExecNode* input, std::string label, Expression filter); @@ -256,6 +258,8 @@ Result MakeFilterNode(ExecNode* input, std::string label, Expression /// /// Each expression will be evaluated against each batch which is pushed to /// this node to produce a corresponding output column. +/// +/// If exprs are not already bound, they will be bound against the input's schema. 
ARROW_EXPORT Result MakeProjectNode(ExecNode* input, std::string label, std::vector exprs); diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index fd96fe8f50e..3f42ab44a39 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -162,8 +162,8 @@ static inline Future> OpenReaderAsync( })); return reader_fut.Then( // Adds the filename to the error - [](const std::shared_ptr& maybe_reader) - -> Result> { return maybe_reader; }, + [](const std::shared_ptr& reader) + -> Result> { return reader; }, [source](const Status& err) -> Result> { return err.WithMessage("Could not open CSV input source '", source.path(), "': ", err); diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 04c86b1f16f..ffa64e8ec10 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -25,6 +25,7 @@ #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" +#include "arrow/io/util_internal.h" #include "arrow/record_batch.h" #include "arrow/table.h" #include "arrow/testing/gtest_util.h" @@ -283,6 +284,32 @@ TEST_F(TestParquetFileFormat, CountRowsPredicatePushdown) { } } +TEST_F(TestParquetFileFormat, MultithreadedScan) { + constexpr int64_t kNumRowGroups = 16; + + // See PredicatePushdown test below for a description of the generated data + auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups); + auto source = GetFileSource(reader.get()); + auto options = std::make_shared(); + + auto fragment = MakeFragment(*source); + + FragmentDataset dataset(ArithmeticDatasetFixture::schema(), {fragment}); + ScannerBuilder builder({&dataset, [](...) {}}); + + ASSERT_OK(builder.UseAsync(true)); + ASSERT_OK(builder.UseThreads(true)); + ASSERT_OK(builder.Project({call("add", {field_ref("i64"), literal(3)})}, {""})); + ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); + + ASSERT_OK_AND_ASSIGN(auto gen, scanner->ScanBatchesUnorderedAsync()); + + auto collect_fut = CollectAsyncGenerator(gen); + ASSERT_OK_AND_ASSIGN(auto batches, collect_fut.result()); + + ASSERT_EQ(batches.size(), kNumRowGroups); +} + class TestParquetFileSystemDataset : public WriteFileSystemDatasetMixin, public testing::Test { public: diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index a622595c8f5..7c455fd89f2 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -606,6 +606,62 @@ Result AsyncScanner::ScanBatchesUnorderedAsync() Result AsyncScanner::ScanBatchesUnorderedAsync( internal::Executor* cpu_executor) { + if (false) { + // causing multithreaded scans to hang + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make()); + + ARROW_ASSIGN_OR_RAISE(auto scan, MakeScanNode(plan.get(), dataset_, scan_options_)); + + ARROW_ASSIGN_OR_RAISE(auto filter, + compute::MakeFilterNode(scan, "filter", scan_options_->filter)); + + auto exprs = scan_options_->projection.call()->arguments; + exprs.push_back(compute::field_ref("__fragment_index")); + exprs.push_back(compute::field_ref("__batch_index")); + exprs.push_back(compute::field_ref("__last_in_fragment")); + ARROW_ASSIGN_OR_RAISE(auto project, + compute::MakeProjectNode(filter, "project", exprs)); + + AsyncGenerator> sink_gen = + compute::MakeSinkNode(project, "sink"); + auto scan_options = scan_options_; + + RETURN_NOT_OK(plan->StartProducing()); + + return MakeMappedGenerator( + sink_gen, + [plan, 
scan_options](const util::optional& batch) + -> Result { + int num_fields = scan_options->projected_schema->num_fields(); + + ArrayVector columns(num_fields); + for (size_t i = 0; i < columns.size(); ++i) { + const Datum& value = batch->values[i]; + if (value.is_array()) { + columns[i] = value.make_array(); + continue; + } + ARROW_ASSIGN_OR_RAISE( + columns[i], + MakeArrayFromScalar(*value.scalar(), batch->length, scan_options->pool)); + } + + EnumeratedRecordBatch out; + out.fragment.value = nullptr; // hope nobody needed this... + out.fragment.index = batch->values[num_fields].scalar_as().value; + out.fragment.last = false; // ignored during reordering + + out.record_batch.value = RecordBatch::Make(scan_options->projected_schema, + batch->length, std::move(columns)); + out.record_batch.index = + batch->values[num_fields + 1].scalar_as().value; + out.record_batch.last = + batch->values[num_fields + 2].scalar_as().value; + + return out; + }); + } + ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments()); return ScanBatchesUnorderedAsyncImpl(scan_options_, std::move(fragment_gen), cpu_executor); @@ -1101,30 +1157,32 @@ Result SyncScanner::CountRows() { return count; } -Result ScannerBuilder::MakeScanNode(compute::ExecPlan* plan) { - if (!scan_options_->use_async) { +Result MakeScanNode(compute::ExecPlan* plan, + std::shared_ptr dataset, + std::shared_ptr scan_options) { + if (!scan_options->use_async) { return Status::NotImplemented("ScanNodes without asynchrony"); } // using a generator for speculative forward compatibility with async fragment discovery - ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter)); + ARROW_ASSIGN_OR_RAISE(scan_options->filter, + scan_options->filter.Bind(*dataset->schema())); + ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset->GetFragments(scan_options->filter)); ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector()); auto fragments_gen = MakeVectorGenerator(std::move(fragments_vec)); ARROW_ASSIGN_OR_RAISE(auto batch_gen, ScanBatchesUnorderedAsyncImpl( - scan_options_, std::move(fragments_gen), + scan_options, std::move(fragments_gen), internal::GetCpuThreadPool(), /*filter_and_project=*/false)); - const auto& schema = dataset_->schema(); - auto gen = MakeMappedGenerator( std::move(batch_gen), - [schema](const EnumeratedRecordBatch& partial) + [dataset](const EnumeratedRecordBatch& partial) -> Result> { ARROW_ASSIGN_OR_RAISE( util::optional batch, - compute::MakeExecBatch(*schema, partial.record_batch.value)); + compute::MakeExecBatch(*dataset->schema(), partial.record_batch.value)); // TODO fragments may be able to attach more guarantees to batches than this, // for example parquet's row group stats. 
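For orientation: the positional columns tagged onto each batch can be read back off an ExecBatch as plain scalars, as the disabled path above does. A minimal sketch, not part of the patch; `batch` and `dataset_schema` are hypothetical stand-ins for a batch produced by the "dataset_scan" source and the dataset's schema:

// Sketch only: recover the positional columns appended by the scan node.
// The last three values of `batch` are __fragment_index, __batch_index and
// __last_in_fragment, in that order, after the dataset's own columns.
const int num_fields = dataset_schema->num_fields();
const int fragment_index =
    batch.values[num_fields].scalar_as<Int32Scalar>().value;
const int batch_index =
    batch.values[num_fields + 1].scalar_as<Int32Scalar>().value;
const bool last_in_fragment =
    batch.values[num_fields + 2].scalar_as<BooleanScalar>().value;
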
@@ -1133,10 +1191,16 @@ Result ScannerBuilder::MakeScanNode(compute::ExecPlan* plan) // tag rows with fragment- and batch-of-origin batch->values.emplace_back(partial.fragment.index); batch->values.emplace_back(partial.record_batch.index); + batch->values.emplace_back(partial.record_batch.last); return batch; }); - return MakeSourceNode(plan, "dataset_scan", schema, std::move(gen)); + auto augmented_fields = dataset->schema()->fields(); + augmented_fields.push_back(field("__fragment_index", int32())); + augmented_fields.push_back(field("__batch_index", int32())); + augmented_fields.push_back(field("__last_in_fragment", boolean())); + return compute::MakeSourceNode(plan, "dataset_scan", + schema(std::move(augmented_fields)), std::move(gen)); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 99075a649fb..c803cde1978 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -210,7 +210,7 @@ struct IterationTraits { IterationEnd>>()}; } static bool IsEnd(const dataset::EnumeratedRecordBatch& val) { - return val.fragment.value == NULLPTR; + return IsIterationEnd(val.fragment); } }; @@ -399,11 +399,6 @@ class ARROW_DS_EXPORT ScannerBuilder { /// \brief Return the constructed now-immutable Scanner object Result> Finish(); - /// \brief Construct a source ExecNode which yields batches from a dataset scan. - /// - /// Does not construct associated filter or project nodes - Result MakeScanNode(compute::ExecPlan*); - const std::shared_ptr& schema() const; const std::shared_ptr& projected_schema() const; @@ -412,6 +407,13 @@ class ARROW_DS_EXPORT ScannerBuilder { std::shared_ptr scan_options_ = std::make_shared(); }; +/// \brief Construct a source ExecNode which yields batches from a dataset scan. +/// +/// Does not construct associated filter or project nodes +ARROW_DS_EXPORT Result MakeScanNode(compute::ExecPlan*, + std::shared_ptr, + std::shared_ptr); + /// @} /// \brief A trivial ScanTask that yields the RecordBatch of an array. 
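With MakeScanNode exposed as a free function, a dataset scan becomes an ordinary source in an ExecPlan. A sketch of the intended wiring, mirroring the disabled code path in ScanBatchesUnorderedAsync above; `dataset`, `scan_options` and `exprs` stand in for the caller's values, while the factory functions are the ones declared in this patch series:

// Sketch: compose scan -> filter -> project -> sink inside an ExecPlan.
ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make());
ARROW_ASSIGN_OR_RAISE(auto scan, MakeScanNode(plan.get(), dataset, scan_options));
ARROW_ASSIGN_OR_RAISE(
    auto filter, compute::MakeFilterNode(scan, "filter", scan_options->filter));
ARROW_ASSIGN_OR_RAISE(
    auto project, compute::MakeProjectNode(filter, "project", exprs));
auto sink_gen = compute::MakeSinkNode(project, "sink");
RETURN_NOT_OK(plan->StartProducing());
// Batches can now be pulled from sink_gen as the plan produces them.
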
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index f527e86f57a..4a6e89beaad 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1091,6 +1091,8 @@ TEST(ScanOptions, TestMaterializedFields) { EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i64", "i32")); } +namespace { + static Result> StartAndCollect( compute::ExecPlan* plan, AsyncGenerator> gen) { RETURN_NOT_OK(plan->Validate()); @@ -1106,49 +1108,25 @@ static Result> StartAndCollect( collected); } -TEST(ScanNode, Trivial) { - ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); - - const auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - RecordBatchVector batches{ - RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, - {"a": 4, "b": false}])"), - RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, - {"a": 6, "b": false}, - {"a": 7, "b": false}])"), - }; - - auto dataset = std::make_shared(schema, batches); - - ScannerBuilder scanner_builder(dataset); - ASSERT_OK(scanner_builder.UseAsync(true)); - ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); - auto sink_gen = MakeSinkNode(scan, "sink"); - - std::vector expected; - // InMemoryDataset(RecordBatchVector) produces a fragment wrapping each batch - const int batch_index = 0; - int fragment_index = 0; - for (const auto& batch : batches) { - expected.emplace_back(*batch); - expected.back().values.emplace_back(fragment_index++); - expected.back().values.emplace_back(batch_index); - } - - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - ResultWith(UnorderedElementsAreArray(expected))); -} - -TEST(ScanNode, FilteredOnVirtualColumn) { - ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); +struct DatasetAndBatches { + std::shared_ptr dataset; + std::vector batches; +}; +DatasetAndBatches MakeBasicDataset() { const auto dataset_schema = ::arrow::schema({ field("a", int32()), field("b", boolean()), field("c", int32()), }); + const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"}); - RecordBatchVector batches{ + + RecordBatchVector record_batches{ + RecordBatchFromJSON(physical_schema, R"([{"a": 1, "b": null}, + {"a": 2, "b": true}])"), + RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true}, + {"a": 3, "b": false}])"), RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true}, {"a": 4, "b": false}])"), RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null}, @@ -1159,136 +1137,152 @@ TEST(ScanNode, FilteredOnVirtualColumn) { auto dataset = std::make_shared( dataset_schema, FragmentVector{ - std::make_shared(physical_schema, batches, - equal(field_ref("c"), literal(23))), - std::make_shared(physical_schema, batches, - equal(field_ref("c"), literal(47))), + std::make_shared( + physical_schema, RecordBatchVector{record_batches[0], record_batches[1]}, + equal(field_ref("c"), literal(23))), + std::make_shared( + physical_schema, RecordBatchVector{record_batches[2], record_batches[3]}, + equal(field_ref("c"), literal(47))), }); - ScannerBuilder scanner_builder(dataset); - ASSERT_OK(scanner_builder.UseAsync(true)); - ASSERT_OK(scanner_builder.Filter(greater(field_ref("c"), literal(30)))); - ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); - auto sink_gen = MakeSinkNode(scan, "sink"); + std::vector batches; + + auto batch_it = record_batches.begin(); + for (int fragment_index = 0; fragment_index < 2; ++fragment_index) { + for (int batch_index = 0; batch_index < 2; 
++batch_index) { + const auto& batch = *batch_it++; + + // the scanned ExecBatches will begin with physical columns + batches.emplace_back(*batch); - std::vector expected; - const int fragment_index = 0; // only the second fragment will make it past the filter, - // and its index in the scan is 0 - int batch_index = 0; - for (const auto& batch : batches) { - expected.emplace_back(*batch); + // a placeholder will be inserted for partition field "c" + batches.back().values.emplace_back(std::make_shared()); - expected.back().guarantee = equal(field_ref("c"), literal(47)); + // scanned batches will be augmented with fragment and batch indices + batches.back().values.emplace_back(fragment_index); + batches.back().values.emplace_back(batch_index); - // Note: placeholder for partition field "c" - expected.back().values.emplace_back(std::make_shared()); + // ... and with the last-in-fragment flag + batches.back().values.emplace_back(batch_index == 1); - expected.back().values.emplace_back(fragment_index); - expected.back().values.emplace_back(batch_index++); + // each batch carries a guarantee inherited from its Fragment's partition expression + batches.back().guarantee = + equal(field_ref("c"), literal(fragment_index == 0 ? 23 : 47)); + } } + return {dataset, batches}; +} +} // namespace + +TEST(ScanNode, Schema) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + auto fields = basic.dataset->schema()->fields(); + fields.push_back(field("__fragment_index", int32())); + fields.push_back(field("__batch_index", int32())); + fields.push_back(field("__last_in_fragment", boolean())); + AssertSchemaEqual(Schema(fields), *scan->output_schema()); +} + +TEST(ScanNode, Trivial) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + auto sink_gen = MakeSinkNode(scan, "sink"); + + // trivial scan: the batches are returned unmodified + auto expected = basic.batches; ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), ResultWith(UnorderedElementsAreArray(expected))); } -TEST(ScanNode, DeferredFilterOnPhysicalColumn) { +TEST(ScanNode, FilteredOnVirtualColumn) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); - const auto dataset_schema = ::arrow::schema({ - field("a", int32()), - field("b", boolean()), - field("c", int32()), - }); - const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"}); - RecordBatchVector batches{ - RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true}, - {"a": 4, "b": false}])"), - RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null}, - {"a": 6, "b": false}, - {"a": 7, "b": false}])"), - }; + auto basic = MakeBasicDataset(); - auto dataset = std::make_shared( - dataset_schema, - FragmentVector{ - std::make_shared(physical_schema, batches, - equal(field_ref("c"), literal(23))), - std::make_shared(physical_schema, batches, - equal(field_ref("c"), literal(47))), - }); + auto options = std::make_shared(); + options->use_async = true; + options->filter = less(field_ref("c"), literal(30)); + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); - ScannerBuilder scanner_builder(dataset); - 
ASSERT_OK(scanner_builder.UseAsync(true)); - ASSERT_OK(scanner_builder.Filter(greater(field_ref("a"), literal(4)))); - ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); auto sink_gen = MakeSinkNode(scan, "sink"); - // no filtering is performed by ScanNode: all batches will be yielded whole - std::vector expected; - for (int fragment_index = 0; fragment_index < 2; ++fragment_index) { - for (int batch_index = 0; batch_index < 2; ++batch_index) { - expected.emplace_back(*batches[batch_index]); + auto expected = basic.batches; - expected.back().guarantee = - equal(field_ref("c"), literal(fragment_index == 0 ? 23 : 47)); + // only the first fragment will make it past the filter + expected.pop_back(); + expected.pop_back(); - // Note: placeholder for partition field "c" - expected.back().values.emplace_back(std::make_shared()); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); +} - expected.back().values.emplace_back(fragment_index); - expected.back().values.emplace_back(batch_index); - } - } +TEST(ScanNode, DeferredFilterOnPhysicalColumn) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + options->filter = greater(field_ref("a"), literal(4)); + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + auto sink_gen = MakeSinkNode(scan, "sink"); + + // No post filtering is performed by ScanNode: all batches will be yielded whole. + // To filter out rows from individual batches, construct a FilterNode. + auto expected = basic.batches; ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), ResultWith(UnorderedElementsAreArray(expected))); } +TEST(ScanNode, ProjectionPushdown) { + // ensure non-projected columns are dropped +} + TEST(ScanNode, MaterializationOfVirtualColumn) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); - const auto dataset_schema = ::arrow::schema({ - field("a", int32()), - field("b", boolean()), - field("c", int32()), - }); - const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"}); - RecordBatchVector batches{ - RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true}, - {"a": 4, "b": false}])"), - RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null}, - {"a": 6, "b": false}, - {"a": 7, "b": false}])"), - }; + auto basic = MakeBasicDataset(); - auto dataset = std::make_shared( - dataset_schema, - FragmentVector{ - std::make_shared(physical_schema, batches, - equal(field_ref("c"), literal(23))), - std::make_shared(physical_schema, batches, - equal(field_ref("c"), literal(47))), - }); + auto options = std::make_shared(); + options->use_async = true; + options->filter = greater(field_ref("a"), literal(4)); - ScannerBuilder scanner_builder(dataset); - ASSERT_OK(scanner_builder.UseAsync(true)); - ASSERT_OK_AND_ASSIGN(auto scan, scanner_builder.MakeScanNode(plan.get())); - ASSERT_OK_AND_ASSIGN(auto project, - compute::MakeProjectNode(scan, "project", {field_ref("c")})); - auto sink_gen = MakeSinkNode(project, "sink"); + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); - std::vector expected; - for (int fragment_index = 0; fragment_index < 2; ++fragment_index) { - for (int batch_index = 0; batch_index < 2; ++batch_index) { - auto c_value = fragment_index == 0 ? 
23 : 47; + ASSERT_OK_AND_ASSIGN( + auto project, + compute::MakeProjectNode( + scan, "project", + {field_ref("a"), field_ref("b"), field_ref("c"), field_ref("__fragment_index"), + field_ref("__batch_index"), field_ref("__last_in_fragment")})); - expected.push_back(compute::ExecBatch{{}, batches[batch_index]->num_rows()}); + auto sink_gen = MakeSinkNode(project, "sink"); - // NB: ProjectNode overwrites "c" placeholder with value from guarantee - expected.back().values.emplace_back(c_value); + auto expected = basic.batches; - expected.back().guarantee = equal(field_ref("c"), literal(c_value)); - } + for (auto& batch : expected) { + // ProjectNode overwrites "c" placeholder with non-null drawn from guarantee + const auto& value = *batch.guarantee.call()->arguments[1].literal(); + batch.values[project->output_schema()->GetFieldIndex("c")] = value; } ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index aab97b9bb49..201fc7e55b2 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -890,13 +890,10 @@ struct ArithmeticDatasetFixture { static std::shared_ptr schema() { return ::arrow::schema({ field("i64", int64()), - // ARROW-1644: Parquet can't write complex level - // field("struct", struct_({ - // // ARROW-2587: Parquet can't write struct with more - // // than one field. - // // field("i32", int32()), - // field("str", utf8()), - // })), + field("struct", struct_({ + field("i32", int32()), + field("str", utf8()), + })), field("u8", uint8()), field("list", list(int32())), field("bool", boolean()), @@ -913,12 +910,12 @@ struct ArithmeticDatasetFixture { ss << "{"; ss << "\"i64\": " << n << ", "; - // ss << "\"struct\": {"; - // { - // // ss << "\"i32\": " << n_i32 << ", "; - // ss << "\"str\": \"" << std::to_string(n) << "\""; - // } - // ss << "}, "; + ss << "\"struct\": {"; + { + ss << "\"i32\": " << n_i32 << ", "; + ss << R"("str": ")" << std::to_string(n) << "\""; + } + ss << "}, "; ss << "\"u8\": " << static_cast(n) << ", "; ss << "\"list\": [" << n_i32 << ", " << n_i32 << "], "; ss << "\"bool\": " << (static_cast(n % 2) ? 
"true" : "false"); From ceac80bc0250f528fe58c08d8b5644bff1eb1608 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 17 Jun 2021 15:46:24 -0400 Subject: [PATCH 19/28] gcc4.8: more explicit construction --- cpp/src/arrow/compute/exec/expression.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 7bfe1bddc58..043b7d9ecba 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -53,7 +53,7 @@ Expression::Expression(Parameter parameter) Expression literal(Datum lit) { return Expression(std::move(lit)); } Expression field_ref(FieldRef ref) { - return Expression(Expression::Parameter{std::move(ref), {}, {}}); + return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, -1}); } Expression call(std::string function, std::vector arguments, From 84c2182e0c7dbde80cb6c0f48da6d3f2927bf86c Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 17 Jun 2021 15:47:23 -0400 Subject: [PATCH 20/28] paranoid reversion in async_generator.h --- cpp/src/arrow/util/async_generator.h | 125 ++++++++++++++------------- 1 file changed, 65 insertions(+), 60 deletions(-) diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index 217d9e5073a..1ac10ad7ce8 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -282,28 +282,30 @@ class SequencingGenerator { std::move(is_next), std::move(initial_value))) {} Future operator()() { - auto guard = state_->mutex.Lock(); - // We can send a result immediately if the top of the queue is either an - // error or the next item - if (!state_->queue.empty() && - (!state_->queue.top().ok() || - state_->is_next(state_->previous_value, *state_->queue.top()))) { - auto result = std::move(state_->queue.top()); - if (result.ok()) { - state_->previous_value = *result; + { + auto guard = state_->mutex.Lock(); + // We can send a result immediately if the top of the queue is either an + // error or the next item + if (!state_->queue.empty() && + (!state_->queue.top().ok() || + state_->is_next(state_->previous_value, *state_->queue.top()))) { + auto result = std::move(state_->queue.top()); + if (result.ok()) { + state_->previous_value = *result; + } + state_->queue.pop(); + return Future::MakeFinished(result); } - state_->queue.pop(); - return Future::MakeFinished(result); - } - if (state_->finished) { - return AsyncGeneratorEnd(); + if (state_->finished) { + return AsyncGeneratorEnd(); + } + // The next item is not in the queue so we will need to wait + auto new_waiting_fut = Future::Make(); + state_->waiting_future = new_waiting_fut; + guard.Unlock(); + state_->source().AddCallback(Callback{state_}); + return new_waiting_fut; } - // The next item is not in the queue so we will need to wait - auto new_waiting_fut = Future::Make(); - state_->waiting_future = new_waiting_fut; - guard.Unlock(); - state_->source().AddCallback(Callback{state_}); - return new_waiting_fut; } private: @@ -337,8 +339,11 @@ class SequencingGenerator { util::Mutex mutex; }; - struct Callback { - void operator()(const Result& result) { + class Callback { + public: + explicit Callback(std::shared_ptr state) : state_(std::move(state)) {} + + void operator()(const Result result) { Future to_deliver; bool finished; { @@ -381,6 +386,7 @@ class SequencingGenerator { } } + private: const std::shared_ptr state_; }; @@ -1104,7 +1110,7 @@ struct Enumerated { template struct 
IterationTraits> { - static Enumerated End() { return Enumerated{T{}, -1, false}; } + static Enumerated End() { return Enumerated{IterationEnd(), -1, false}; } static bool IsEnd(const Enumerated& val) { return val.index < 0; } }; @@ -1118,17 +1124,17 @@ class EnumeratingGenerator { Future> operator()() { if (state_->finished) { return AsyncGeneratorEnd>(); + } else { + auto state = state_; + return state->source().Then([state](const T& next) { + auto finished = IsIterationEnd(next); + auto prev = Enumerated{state->prev_value, state->prev_index, finished}; + state->prev_value = next; + state->prev_index++; + state->finished = finished; + return prev; + }); } - - auto state = state_; - return state->source().Then([state](const T& next) { - auto finished = IsIterationEnd(next); - auto prev = Enumerated{state->prev_value, state->prev_index, finished}; - state->prev_value = next; - state->prev_index++; - state->finished = finished; - return prev; - }); } private: @@ -1215,27 +1221,27 @@ class BackgroundGenerator { Future operator()() { auto guard = state_->mutex.Lock(); - if (!state_->queue.empty()) { + Future waiting_future; + if (state_->queue.empty()) { + if (state_->finished) { + return AsyncGeneratorEnd(); + } else { + waiting_future = Future::Make(); + state_->waiting_future = waiting_future; + } + } else { auto next = Future::MakeFinished(std::move(state_->queue.front())); state_->queue.pop(); - if (state_->NeedsRestart()) { return state_->RestartTask(state_, std::move(guard), std::move(next)); } return next; } - - if (state_->finished) { - return AsyncGeneratorEnd(); - } - - state_->waiting_future = Future::Make(); - // This should only trigger the very first time this method is called if (state_->NeedsRestart()) { - return state_->RestartTask(state_, std::move(guard), *state_->waiting_future); + return state_->RestartTask(state_, std::move(guard), std::move(waiting_future)); } - return *state_->waiting_future; + return waiting_future; } protected: @@ -1266,23 +1272,22 @@ class BackgroundGenerator { // task_finished future for it state->task_finished = Future<>::Make(); state->reading = true; - auto spawn_status = io_executor->Spawn( [state]() { BackgroundGenerator::WorkerTask(std::move(state)); }); - if (spawn_status.ok()) return; - - // If we can't spawn a new task then send an error to the consumer (either via a - // waiting future or the queue) and mark ourselves finished - state->finished = true; - state->task_finished = Future<>(); - if (waiting_future.has_value()) { - auto to_deliver = std::move(waiting_future.value()); - waiting_future.reset(); - guard.Unlock(); - to_deliver.MarkFinished(spawn_status); - } else { - ClearQueue(); - queue.push(spawn_status); + if (!spawn_status.ok()) { + // If we can't spawn a new task then send an error to the consumer (either via a + // waiting future or the queue) and mark ourselves finished + state->finished = true; + state->task_finished = Future<>(); + if (waiting_future.has_value()) { + auto to_deliver = std::move(waiting_future.value()); + waiting_future.reset(); + guard.Unlock(); + to_deliver.MarkFinished(spawn_status); + } else { + ClearQueue(); + queue.push(spawn_status); + } } } @@ -1400,7 +1405,7 @@ class BackgroundGenerator { // callbacks off of this thread so we can continue looping. 
Still, best not to // rely on that if (waiting_future.is_valid()) { - waiting_future.MarkFinished(std::move(next)); + waiting_future.MarkFinished(next); } } // Once we've sent our last item we can notify any waiters that we are done and so From aaaa353b5cf95146573e18994cc3bf2c62fec695 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 17 Jun 2021 15:47:47 -0400 Subject: [PATCH 21/28] move new scan path into a unit test for now --- cpp/src/arrow/dataset/scanner.cc | 56 --------------- cpp/src/arrow/dataset/scanner_test.cc | 100 +++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 57 deletions(-) diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 7c455fd89f2..fc8cb5e669d 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -606,62 +606,6 @@ Result AsyncScanner::ScanBatchesUnorderedAsync() Result AsyncScanner::ScanBatchesUnorderedAsync( internal::Executor* cpu_executor) { - if (false) { - // causing multithreaded scans to hang - ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make()); - - ARROW_ASSIGN_OR_RAISE(auto scan, MakeScanNode(plan.get(), dataset_, scan_options_)); - - ARROW_ASSIGN_OR_RAISE(auto filter, - compute::MakeFilterNode(scan, "filter", scan_options_->filter)); - - auto exprs = scan_options_->projection.call()->arguments; - exprs.push_back(compute::field_ref("__fragment_index")); - exprs.push_back(compute::field_ref("__batch_index")); - exprs.push_back(compute::field_ref("__last_in_fragment")); - ARROW_ASSIGN_OR_RAISE(auto project, - compute::MakeProjectNode(filter, "project", exprs)); - - AsyncGenerator> sink_gen = - compute::MakeSinkNode(project, "sink"); - auto scan_options = scan_options_; - - RETURN_NOT_OK(plan->StartProducing()); - - return MakeMappedGenerator( - sink_gen, - [plan, scan_options](const util::optional& batch) - -> Result { - int num_fields = scan_options->projected_schema->num_fields(); - - ArrayVector columns(num_fields); - for (size_t i = 0; i < columns.size(); ++i) { - const Datum& value = batch->values[i]; - if (value.is_array()) { - columns[i] = value.make_array(); - continue; - } - ARROW_ASSIGN_OR_RAISE( - columns[i], - MakeArrayFromScalar(*value.scalar(), batch->length, scan_options->pool)); - } - - EnumeratedRecordBatch out; - out.fragment.value = nullptr; // hope nobody needed this... 
- out.fragment.index = batch->values[num_fields].scalar_as().value; - out.fragment.last = false; // ignored during reordering - - out.record_batch.value = RecordBatch::Make(scan_options->projected_schema, - batch->length, std::move(columns)); - out.record_batch.index = - batch->values[num_fields + 1].scalar_as().value; - out.record_batch.last = - batch->values[num_fields + 2].scalar_as().value; - - return out; - }); - } - ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments()); return ScanBatchesUnorderedAsyncImpl(scan_options_, std::move(fragment_gen), cpu_executor); diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 4a6e89beaad..ef129db2c9a 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1264,7 +1264,6 @@ TEST(ScanNode, MaterializationOfVirtualColumn) { auto options = std::make_shared(); options->use_async = true; - options->filter = greater(field_ref("a"), literal(4)); ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); @@ -1289,5 +1288,104 @@ TEST(ScanNode, MaterializationOfVirtualColumn) { ResultWith(UnorderedElementsAreArray(expected))); } +TEST(ScanNode, CompareToScanner) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + ScannerBuilder builder(basic.dataset); + ASSERT_OK(builder.UseAsync(true)); + ASSERT_OK(builder.UseThreads(true)); + ASSERT_OK(builder.Filter(greater(field_ref("c"), literal(30)))); + ASSERT_OK(builder.Project( + {field_ref("c"), call("multiply", {field_ref("a"), literal(2)})}, {"c", "a * 2"})); + ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); + + ASSERT_OK_AND_ASSIGN(auto fragments_it, + basic.dataset->GetFragments(scanner->options()->filter)); + ASSERT_OK_AND_ASSIGN(auto fragments, fragments_it.ToVector()); + + auto options = scanner->options(); + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + ASSERT_OK_AND_ASSIGN(auto filter, + compute::MakeFilterNode(scan, "filter", options->filter)); + + auto exprs = options->projection.call()->arguments; + exprs.push_back(compute::field_ref("__fragment_index")); + exprs.push_back(compute::field_ref("__batch_index")); + exprs.push_back(compute::field_ref("__last_in_fragment")); + ASSERT_OK_AND_ASSIGN(auto project, compute::MakeProjectNode(filter, "project", exprs)); + + AsyncGenerator> sink_gen = + compute::MakeSinkNode(project, "sink"); + + ASSERT_OK(plan->StartProducing()); + + auto from_plan = + CollectAsyncGenerator( + MakeMappedGenerator( + sink_gen, + [&](const util::optional& batch) + -> Result { + int num_fields = options->projected_schema->num_fields(); + + ArrayVector columns(num_fields); + for (size_t i = 0; i < columns.size(); ++i) { + const Datum& value = batch->values[i]; + if (value.is_array()) { + columns[i] = value.make_array(); + continue; + } + ARROW_ASSIGN_OR_RAISE( + columns[i], + MakeArrayFromScalar(*value.scalar(), batch->length, options->pool)); + } + + EnumeratedRecordBatch out; + out.fragment.index = + batch->values[num_fields].scalar_as().value; + out.fragment.value = fragments[out.fragment.index]; + out.fragment.last = false; // ignored during reordering + + out.record_batch.index = + batch->values[num_fields + 1].scalar_as().value; + out.record_batch.value = RecordBatch::Make( + options->projected_schema, batch->length, std::move(columns)); + out.record_batch.last = + batch->values[num_fields + 2].scalar_as().value; + + return out; + })) + .result(); + + 
ASSERT_OK_AND_ASSIGN(auto from_scanner_gen, scanner->ScanBatchesUnorderedAsync()); + auto from_scanner = CollectAsyncGenerator(from_scanner_gen).result(); + + auto less = [](const EnumeratedRecordBatch& l, const EnumeratedRecordBatch& r) { + if (l.fragment.index < r.fragment.index) return true; + return l.record_batch.index < r.record_batch.index; + }; + + ASSERT_OK(from_plan); + std::sort(from_plan->begin(), from_plan->end(), less); + + ASSERT_OK(from_scanner); + std::sort(from_scanner->begin(), from_scanner->end(), less); + + ASSERT_EQ(from_plan->size(), from_scanner->size()); + for (size_t i = 0; i < from_plan->size(); ++i) { + const auto& p = from_plan->at(i); + const auto& s = from_scanner->at(i); + SCOPED_TRACE(i); + ASSERT_EQ(p.fragment.index, s.fragment.index); + ASSERT_EQ(p.fragment.value, s.fragment.value); + ASSERT_EQ(p.record_batch.last, s.record_batch.last); + ASSERT_EQ(p.record_batch.index, s.record_batch.index); + AssertBatchesEqual(*p.record_batch.value, *s.record_batch.value); + } +} + } // namespace dataset } // namespace arrow From 1da177586ad5ebd0cc1106dd93160126b5723458 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 18 Jun 2021 20:37:24 -0400 Subject: [PATCH 22/28] reduce #includes in expression.h --- cpp/src/arrow/compute/exec/expression.cc | 19 +++++++++---------- cpp/src/arrow/compute/exec/expression.h | 9 ++++----- .../arrow/compute/exec/expression_internal.h | 4 ++++ cpp/src/arrow/compute/exec/expression_test.cc | 14 +++++++------- cpp/src/arrow/dataset/partition.cc | 3 ++- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 043b7d9ecba..24816783851 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -716,17 +716,16 @@ Status ExtractKnownFieldValuesImpl( } // namespace -Result> ExtractKnownFieldValues( +Result ExtractKnownFieldValues( const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); - std::unordered_map known_values; - RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); + KnownFieldValues known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); return known_values; } -Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, - Expression expr) { +Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values, + Expression expr) { if (!expr.IsBound()) { return Status::Invalid( "ReplaceFieldsWithKnownValues called on an unbound Expression"); @@ -736,8 +735,8 @@ Result ReplaceFieldsWithKnownValues( std::move(expr), [&known_values](Expression expr) -> Result { if (auto ref = expr.field_ref()) { - auto it = known_values.find(*ref); - if (it != known_values.end()) { + auto it = known_values.map.find(*ref); + if (it != known_values.map.end()) { Datum lit = it->second; if (lit.descr() == expr.descr()) return literal(std::move(lit)); // type mismatch, try casting the known value to the correct type @@ -939,8 +938,8 @@ Result SimplifyWithGuarantee(Expression expr, const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); - std::unordered_map known_values; - RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); + KnownFieldValues known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); 
ARROW_ASSIGN_OR_RAISE(expr, ReplaceFieldsWithKnownValues(known_values, std::move(expr))); diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index 1d576a23112..d06a923bb32 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -19,10 +19,8 @@ #pragma once -#include #include #include -#include #include #include @@ -166,8 +164,9 @@ ARROW_EXPORT bool ExpressionHasFieldRefs(const Expression&); /// Assemble a mapping from field references to known values. +struct ARROW_EXPORT KnownFieldValues; ARROW_EXPORT -Result> ExtractKnownFieldValues( +Result ExtractKnownFieldValues( const Expression& guaranteed_true_predicate); /// \defgroup expression-passes Functions for modification of Expressions @@ -196,8 +195,8 @@ Result FoldConstants(Expression); /// Simplify Expressions by replacing with known values of the fields which it references. ARROW_EXPORT -Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, Expression); +Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values, + Expression); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is diff --git a/cpp/src/arrow/compute/exec/expression_internal.h b/cpp/src/arrow/compute/exec/expression_internal.h index b9165a5f0c2..51d242e8d66 100644 --- a/cpp/src/arrow/compute/exec/expression_internal.h +++ b/cpp/src/arrow/compute/exec/expression_internal.h @@ -34,6 +34,10 @@ using internal::checked_cast; namespace compute { +struct KnownFieldValues { + std::unordered_map map; +}; + inline const Expression::Call* CallNotNull(const Expression& expr) { auto call = expr.call(); DCHECK_NE(call, nullptr); diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index c4c2dd1c951..86909f4eb64 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -830,7 +830,7 @@ TEST(Expression, ExtractKnownFieldValues) { void operator()(Expression guarantee, std::unordered_map expected) { ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); - EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) + EXPECT_THAT(actual.map, UnorderedElementsAreArray(expected)) << " guarantee: " << guarantee.ToString(); } } ExpectKnown; @@ -882,8 +882,8 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto replaced, - ReplaceFieldsWithKnownValues(known_values, expr)); + ASSERT_OK_AND_ASSIGN(auto replaced, ReplaceFieldsWithKnownValues( + KnownFieldValues{known_values}, expr)); EXPECT_EQ(replaced, expected); ExpectIdenticalIfUnchanged(replaced, expr); @@ -943,13 +943,13 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { Datum dict_i32{ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(int32(), R"([3])"))}; // Unsupported cast dictionary(int32(), int32()) -> dictionary(int32(), utf8()) - ASSERT_RAISES(NotImplemented, - ReplaceFieldsWithKnownValues({{"dict_str", dict_i32}}, expr)); + ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( + KnownFieldValues{{{"dict_str", dict_i32}}}, expr)); // Unsupported cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8()) dict_str = Datum{ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(utf8(), 
R"(["a"])"))}; - ASSERT_RAISES(NotImplemented, - ReplaceFieldsWithKnownValues({{"dict_str", dict_str}}, expr)); + ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( + KnownFieldValues{{{"dict_str", dict_str}}}, expr)); } struct { diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 5c390b6b487..1ec47e3cee1 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -30,6 +30,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" +#include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/filesystem/path_util.h" #include "arrow/scalar.h" @@ -252,7 +253,7 @@ Result KeyValuePartitioning::Format(const compute::Expression& expr ScalarVector values{static_cast(schema_->num_fields()), nullptr}; ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); - for (const auto& ref_value : known_values) { + for (const auto& ref_value : known_values.map) { if (!ref_value.second.is_scalar()) { return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); } From f50a89c4e6ba1775d9559ed41725cdd068423a3f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 21 Jun 2021 12:31:19 -0400 Subject: [PATCH 23/28] repair python binding with inlined KnownFieldValues def --- python/pyarrow/_dataset.pyx | 2 +- python/pyarrow/includes/libarrow_dataset.pxd | 22 +++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index bd93da9cb18..e7e8341c9d4 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -3001,7 +3001,7 @@ def _get_partition_keys(Expression partition_expression): pair[CFieldRef, CDatum] ref_val out = {} - for ref_val in GetResultValue(CExtractKnownFieldValues(expr)): + for ref_val in GetResultValue(CExtractKnownFieldValues(expr)).map: assert ref_val.first.name() != nullptr assert ref_val.second.kind() == DatumType_SCALAR val = pyarrow_wrap_scalar(ref_val.second.scalar()) diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 8cab5536647..f9349f3a642 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -32,6 +32,26 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: pass +cdef extern from * namespace "arrow::compute": + # inlined from expression_internal.h to avoid + # proliferation of #include + """ + #include + + #include "arrow/type.h" + #include "arrow/datum.h" + + namespace arrow { + namespace compute { + struct KnownFieldValues { + std::unordered_map map; + }; + } // namespace compute + } // namespace arrow + """ + cdef struct CKnownFieldValues "arrow::compute::KnownFieldValues": + unordered_map[CFieldRef, CDatum, CFieldRefHash] map + cdef extern from "arrow/compute/exec/expression.h" \ namespace "arrow::compute" nogil: @@ -57,7 +77,7 @@ cdef extern from "arrow/compute/exec/expression.h" \ cdef CResult[CExpression] CDeserializeExpression \ "arrow::compute::Deserialize"(shared_ptr[CBuffer]) - cdef CResult[unordered_map[CFieldRef, CDatum, CFieldRefHash]] \ + cdef CResult[CKnownFieldValues] \ CExtractKnownFieldValues "arrow::compute::ExtractKnownFieldValues"( const CExpression& partition_expression) From 256c29c45850c3443bcf0d3135636ffe0ed3ef5e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 28 Jun 2021 17:08:00 -0400 Subject: [PATCH 24/28] 
From 256c29c45850c3443bcf0d3135636ffe0ed3ef5e Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Mon, 28 Jun 2021 17:08:00 -0400
Subject: [PATCH 24/28] review comments

---
 cpp/src/arrow/compute/exec/expression.cc |   3 +-
 cpp/src/arrow/compute/exec/plan_test.cc  |  28 +---
 cpp/src/arrow/compute/exec/test_util.cc  |  23 +++
 cpp/src/arrow/compute/exec/test_util.h   |   9 +-
 cpp/src/arrow/dataset/scanner_test.cc    |   1 +
 cpp/src/arrow/result_test.cc             |   1 +
 cpp/src/arrow/status_test.cc             |   1 +
 cpp/src/arrow/testing/gtest_util.h       | 151 -------------------
 cpp/src/arrow/testing/matchers.h         | 177 +++++++++++++++++++++
 cpp/src/arrow/util/future.h              |  25 +---
 cpp/src/arrow/util/future_test.cc        |   1 +
 cpp/src/arrow/util/thread_pool_test.cc   |   1 +
 12 files changed, 221 insertions(+), 200 deletions(-)
 create mode 100644 cpp/src/arrow/testing/matchers.h

diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index 24816783851..ac8c79db1ba 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -29,6 +29,7 @@
 #include "arrow/ipc/reader.h"
 #include "arrow/ipc/writer.h"
 #include "arrow/util/atomic_shared_ptr.h"
+#include "arrow/util/hash_util.h"
 #include "arrow/util/key_value_metadata.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/optional.h"
@@ -65,7 +66,7 @@ Expression call(std::string function, std::vector<Expression> arguments,
   call.hash = std::hash<std::string>{}(call.function_name);
   for (const auto& arg : call.arguments) {
-    call.hash ^= arg.hash();
+    arrow::internal::hash_combine(call.hash, arg.hash());
   }
   return Expression(std::move(call));
 }
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index dc5845977da..16e97f593b8 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -15,11 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
-#include - #include #include +#include + #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" @@ -27,7 +27,9 @@ #include "arrow/record_batch.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/testing/random.h" +#include "arrow/util/async_generator.h" #include "arrow/util/logging.h" #include "arrow/util/thread_pool.h" #include "arrow/util/vector.h" @@ -40,28 +42,6 @@ namespace arrow { namespace compute { -ExecBatch ExecBatchFromJSON(const std::vector& descrs, - util::string_view json) { - auto fields = internal::MapVector( - [](const ValueDescr& descr) { return field("", descr.type); }, descrs); - - ExecBatch batch{*RecordBatchFromJSON(schema(std::move(fields)), json)}; - - auto value_it = batch.values.begin(); - for (const auto& descr : descrs) { - if (descr.shape == ValueDescr::SCALAR) { - if (batch.length == 0) { - *value_it = MakeNullScalar(value_it->type()); - } else { - *value_it = value_it->make_array()->GetScalar(0).ValueOrDie(); - } - } - ++value_it; - } - - return batch; -} - TEST(ExecPlanConstruction, Empty) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index e5bf61a1808..6fbfa2a430c 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -39,6 +39,7 @@ #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" +#include "arrow/util/vector.h" namespace arrow { @@ -128,5 +129,27 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector& descrs, + util::string_view json) { + auto fields = internal::MapVector( + [](const ValueDescr& descr) { return field("", descr.type); }, descrs); + + ExecBatch batch{*RecordBatchFromJSON(schema(std::move(fields)), json)}; + + auto value_it = batch.values.begin(); + for (const auto& descr : descrs) { + if (descr.shape == ValueDescr::SCALAR) { + if (batch.length == 0) { + *value_it = MakeNullScalar(value_it->type()); + } else { + *value_it = value_it->make_array()->GetScalar(0).ValueOrDie(); + } + } + ++value_it; + } + + return batch; +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 60423548614..faa395bab78 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -18,16 +18,13 @@ #pragma once #include -#include #include #include +#include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" -#include "arrow/record_batch.h" #include "arrow/testing/visibility.h" -#include "arrow/util/async_generator.h" #include "arrow/util/string_view.h" -#include "arrow/util/type_fwd.h" namespace arrow { namespace compute { @@ -40,5 +37,9 @@ ARROW_TESTING_EXPORT ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); +ARROW_TESTING_EXPORT +ExecBatch ExecBatchFromJSON(const std::vector& descrs, + util::string_view json); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index ef129db2c9a..98b218512a4 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -34,6 +34,7 @@ #include "arrow/testing/future_util.h" #include "arrow/testing/generator.h" #include 
"arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/testing/util.h" #include "arrow/util/range.h" #include "arrow/util/vector.h" diff --git a/cpp/src/arrow/result_test.cc b/cpp/src/arrow/result_test.cc index b814e3f3ea1..cb645bc7402 100644 --- a/cpp/src/arrow/result_test.cc +++ b/cpp/src/arrow/result_test.cc @@ -27,6 +27,7 @@ #include "arrow/testing/gtest_compat.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" namespace arrow { diff --git a/cpp/src/arrow/status_test.cc b/cpp/src/arrow/status_test.cc index b4e5c288d1b..10a79d9b990 100644 --- a/cpp/src/arrow/status_test.cc +++ b/cpp/src/arrow/status_test.cc @@ -22,6 +22,7 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" namespace arrow { diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index b1a4151d6c5..591745151da 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -28,7 +28,6 @@ #include #include -#include #include #include "arrow/array/builder_binary.h" @@ -285,156 +284,6 @@ ARROW_TESTING_EXPORT void AssertZeroPadded(const Array& array); ARROW_TESTING_EXPORT void TestInitialized(const ArrayData& array); ARROW_TESTING_EXPORT void TestInitialized(const Array& array); -template -class ResultMatcher { - public: - explicit ResultMatcher(ValueMatcher value_matcher) - : value_matcher_(std::move(value_matcher)) {} - - template ::type::ValueType> - operator testing::Matcher() const { // NOLINT runtime/explicit - struct Impl : testing::MatcherInterface { - explicit Impl(const ValueMatcher& value_matcher) - : value_matcher_(testing::MatcherCast(value_matcher)) {} - - void DescribeTo(::std::ostream* os) const override { - *os << "value "; - value_matcher_.DescribeTo(os); - } - - void DescribeNegationTo(::std::ostream* os) const override { - *os << "value "; - value_matcher_.DescribeNegationTo(os); - } - - bool MatchAndExplain(const Res& maybe_value, - testing::MatchResultListener* listener) const override { - if (!maybe_value.status().ok()) { - *listener << "whose error " - << testing::PrintToString(maybe_value.status().ToString()) - << " doesn't match"; - return false; - } - const ValueType& value = GetValue(maybe_value); - testing::StringMatchResultListener value_listener; - const bool match = value_matcher_.MatchAndExplain(value, &value_listener); - *listener << "whose value " << testing::PrintToString(value) - << (match ? 
" matches" : " doesn't match"); - testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); - return match; - } - - const testing::Matcher value_matcher_; - }; - - return testing::Matcher(new Impl(value_matcher_)); - } - - private: - template - static const T& GetValue(const Result& maybe_value) { - return maybe_value.ValueOrDie(); - } - - template - static const T& GetValue(const Future& value_fut) { - return GetValue(value_fut.result()); - } - - const ValueMatcher value_matcher_; -}; - -class StatusMatcher { - public: - explicit StatusMatcher(StatusCode code, - util::optional> message_matcher) - : code_(code), message_matcher_(std::move(message_matcher)) {} - - template - operator testing::Matcher() const { // NOLINT runtime/explicit - struct Impl : testing::MatcherInterface { - explicit Impl(StatusCode code, - util::optional> message_matcher) - : code_(code), message_matcher_(std::move(message_matcher)) {} - - void DescribeTo(::std::ostream* os) const override { - *os << "raises StatusCode::" << Status::CodeAsString(code_); - if (message_matcher_) { - *os << " and message "; - message_matcher_->DescribeTo(os); - } - } - - void DescribeNegationTo(::std::ostream* os) const override { - *os << "does not raise StatusCode::" << Status::CodeAsString(code_); - if (message_matcher_) { - *os << " or message "; - message_matcher_->DescribeNegationTo(os); - } - } - - bool MatchAndExplain(const Res& maybe_value, - testing::MatchResultListener* listener) const override { - const Status& status = GetStatus(maybe_value); - testing::StringMatchResultListener value_listener; - - bool match = status.code() == code_; - if (message_matcher_) { - match = match && - message_matcher_->MatchAndExplain(status.message(), &value_listener); - } - - *listener << "whose value " << testing::PrintToString(status.ToString()) - << (match ? " matches" : " doesn't match"); - testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); - return match; - } - - const StatusCode code_; - const util::optional> message_matcher_; - }; - - return testing::Matcher(new Impl(code_, message_matcher_)); - } - - private: - static const Status& GetStatus(const Status& status) { return status; } - - template - static const Status& GetStatus(const Result& maybe_value) { - return maybe_value.status(); - } - - template - static const Status& GetStatus(const Future& value_fut) { - return value_fut.status(); - } - - const StatusCode code_; - const util::optional> message_matcher_; -}; - -// Returns a matcher that matches the value of a successful Result or Future. -// (Future will be waited upon to acquire its result for matching.) -template -ResultMatcher ResultWith(const ValueMatcher& value_matcher) { - return ResultMatcher(value_matcher); -} - -// Returns a matcher that matches the StatusCode of a Status, Result, or Future. -// (Future will be waited upon to acquire its result for matching.) -inline StatusMatcher Raises(StatusCode code) { - return StatusMatcher(code, util::nullopt); -} - -// Returns a matcher that matches the StatusCode and message of a Status, Result, or -// Future. (Future will be waited upon to acquire its result for matching.) 
-template -StatusMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { - return StatusMatcher(code, testing::MatcherCast(message_matcher)); -} - template void FinishAndCheckPadding(BuilderType* builder, std::shared_ptr* out) { ASSERT_OK_AND_ASSIGN(*out, builder->Finish()); diff --git a/cpp/src/arrow/testing/matchers.h b/cpp/src/arrow/testing/matchers.h new file mode 100644 index 00000000000..246f321e8fa --- /dev/null +++ b/cpp/src/arrow/testing/matchers.h @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { + +template +class ResultMatcher { + public: + explicit ResultMatcher(ValueMatcher value_matcher) + : value_matcher_(std::move(value_matcher)) {} + + template ::type::ValueType> + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { + explicit Impl(const ValueMatcher& value_matcher) + : value_matcher_(testing::MatcherCast(value_matcher)) {} + + void DescribeTo(::std::ostream* os) const override { + *os << "value "; + value_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(::std::ostream* os) const override { + *os << "value "; + value_matcher_.DescribeNegationTo(os); + } + + bool MatchAndExplain(const Res& maybe_value, + testing::MatchResultListener* listener) const override { + if (!maybe_value.status().ok()) { + *listener << "whose error " + << testing::PrintToString(maybe_value.status().ToString()) + << " doesn't match"; + return false; + } + const ValueType& value = GetValue(maybe_value); + testing::StringMatchResultListener value_listener; + const bool match = value_matcher_.MatchAndExplain(value, &value_listener); + *listener << "whose value " << testing::PrintToString(value) + << (match ? 
" matches" : " doesn't match"); + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); + return match; + } + + const testing::Matcher value_matcher_; + }; + + return testing::Matcher(new Impl(value_matcher_)); + } + + private: + template + static const T& GetValue(const Result& maybe_value) { + return maybe_value.ValueOrDie(); + } + + template + static const T& GetValue(const Future& value_fut) { + return GetValue(value_fut.result()); + } + + const ValueMatcher value_matcher_; +}; + +class StatusMatcher { + public: + explicit StatusMatcher(StatusCode code, + util::optional> message_matcher) + : code_(code), message_matcher_(std::move(message_matcher)) {} + + template + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { + explicit Impl(StatusCode code, + util::optional> message_matcher) + : code_(code), message_matcher_(std::move(message_matcher)) {} + + void DescribeTo(::std::ostream* os) const override { + *os << "raises StatusCode::" << Status::CodeAsString(code_); + if (message_matcher_) { + *os << " and message "; + message_matcher_->DescribeTo(os); + } + } + + void DescribeNegationTo(::std::ostream* os) const override { + *os << "does not raise StatusCode::" << Status::CodeAsString(code_); + if (message_matcher_) { + *os << " or message "; + message_matcher_->DescribeNegationTo(os); + } + } + + bool MatchAndExplain(const Res& maybe_value, + testing::MatchResultListener* listener) const override { + const Status& status = GetStatus(maybe_value); + testing::StringMatchResultListener value_listener; + + bool match = status.code() == code_; + if (message_matcher_) { + match = match && + message_matcher_->MatchAndExplain(status.message(), &value_listener); + } + + *listener << "whose value " << testing::PrintToString(status.ToString()) + << (match ? " matches" : " doesn't match"); + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); + return match; + } + + const StatusCode code_; + const util::optional> message_matcher_; + }; + + return testing::Matcher(new Impl(code_, message_matcher_)); + } + + private: + static const Status& GetStatus(const Status& status) { return status; } + + template + static const Status& GetStatus(const Result& maybe_value) { + return maybe_value.status(); + } + + template + static const Status& GetStatus(const Future& value_fut) { + return value_fut.status(); + } + + const StatusCode code_; + const util::optional> message_matcher_; +}; + +// Returns a matcher that matches the value of a successful Result or Future. +// (Future will be waited upon to acquire its result for matching.) +template +ResultMatcher ResultWith(const ValueMatcher& value_matcher) { + return ResultMatcher(value_matcher); +} + +// Returns a matcher that matches the StatusCode of a Status, Result, or Future. +// (Future will be waited upon to acquire its result for matching.) +inline StatusMatcher Raises(StatusCode code) { + return StatusMatcher(code, util::nullopt); +} + +// Returns a matcher that matches the StatusCode and message of a Status, Result, or +// Future. (Future will be waited upon to acquire its result for matching.) 
+template +StatusMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { + return StatusMatcher(code, testing::MatcherCast(message_matcher)); +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index c2e754911eb..c7c5ba802f9 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -979,26 +979,6 @@ Future Loop(Iterate iterate) { return break_fut; } -template -struct EnsureFuture { - using type = Future; -}; - -template -struct EnsureFuture> { - using type = Future; -}; - -template -struct EnsureFuture> { - using type = Future; -}; - -template <> -struct EnsureFuture { - using type = Future<>; -}; - inline Future<> ToFuture(Status status) { return Future<>::MakeFinished(std::move(status)); } @@ -1018,4 +998,9 @@ Future ToFuture(Future fut) { return std::move(fut); } +template +struct EnsureFuture { + using type = decltype(ToFuture(std::declval())); +}; + } // namespace arrow diff --git a/cpp/src/arrow/util/future_test.cc b/cpp/src/arrow/util/future_test.cc index 773ac09e359..b25d77c48cd 100644 --- a/cpp/src/arrow/util/future_test.cc +++ b/cpp/src/arrow/util/future_test.cc @@ -36,6 +36,7 @@ #include "arrow/testing/executor_util.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/util/logging.h" #include "arrow/util/thread_pool.h" diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc index 92f6a8ac00a..399c755a8f9 100644 --- a/cpp/src/arrow/util/thread_pool_test.cc +++ b/cpp/src/arrow/util/thread_pool_test.cc @@ -408,6 +408,7 @@ TEST_F(TestThreadPool, OwnsCurrentThread) { } ASSERT_OK(pool->Shutdown()); + ASSERT_FALSE(pool->OwnsThisThread()); ASSERT_FALSE(one_failed); } From 02c7eb481efc86a2b57cbc626554b4acde581135 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 29 Jun 2021 12:07:57 -0400 Subject: [PATCH 25/28] consistent end signaling for Enumerated --- cpp/src/arrow/dataset/scanner.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index fc8cb5e669d..58e96fdc113 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -618,13 +618,17 @@ Result AsyncScanner::ScanBatchesAsync() { Result AsyncScanner::ScanBatchesAsync( internal::Executor* cpu_executor) { ARROW_ASSIGN_OR_RAISE(auto unordered, ScanBatchesUnorderedAsync(cpu_executor)); - auto left_after_right = [](const EnumeratedRecordBatch& left, - const EnumeratedRecordBatch& right) { + // We need an initial value sentinel, so we use one with fragment.index < 0 + auto is_before_any = [](const EnumeratedRecordBatch& batch) { + return batch.fragment.index < 0; + }; + auto left_after_right = [&is_before_any](const EnumeratedRecordBatch& left, + const EnumeratedRecordBatch& right) { // Before any comes first - if (left.fragment.value == nullptr) { + if (is_before_any(left)) { return false; } - if (right.fragment.value == nullptr) { + if (is_before_any(right)) { return true; } // Compare batches if fragment is the same @@ -634,10 +638,10 @@ Result AsyncScanner::ScanBatchesAsync( // Otherwise compare fragment return left.fragment.index > right.fragment.index; }; - auto is_next = [](const EnumeratedRecordBatch& prev, - const EnumeratedRecordBatch& next) { + auto is_next = [is_before_any](const EnumeratedRecordBatch& prev, + const EnumeratedRecordBatch& next) { // Only true if next is the first batch - if 
(prev.fragment.value == nullptr) {
+    if (is_before_any(prev)) {
       return next.fragment.index == 0 && next.record_batch.index == 0;
     }
     // If same fragment, compare batch index
From 5133e7eed5b51650901c925db6a79c04ffbc9599 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 30 Jun 2021 12:24:55 -0400
Subject: [PATCH 26/28] transfer from background thread

---
 cpp/src/arrow/compute/exec/plan_test.cc | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 16e97f593b8..32de67a6670 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -208,6 +208,9 @@ Result<ExecNode*> MakeTestSourceNode(ExecPlan* plan, std::string label,
     ARROW_ASSIGN_OR_RAISE(
         gen, MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)),
                                      internal::GetCpuThreadPool()));
+
+    // ensure that callbacks are not executed immediately on a background thread
+    gen = MakeTransferredGenerator(std::move(gen), internal::GetCpuThreadPool());
   } else {
     gen = MakeVectorGenerator(std::move(opt_batches));
   }
From e91ef9fda5bfd0704b792e787b75e00d3066bdd5 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 30 Jun 2021 13:46:29 -0400
Subject: [PATCH 27/28] ensure that plans are stopped before they are destroyed

---
 cpp/src/arrow/compute/exec/exec_plan.cc | 83 ++++++++++++++-----------
 cpp/src/arrow/compute/exec/exec_plan.h  |  2 +-
 cpp/src/arrow/compute/exec/plan_test.cc |  2 +-
 cpp/src/arrow/dataset/scanner_test.cc   |  2 +-
 4 files changed, 49 insertions(+), 40 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 7438af78b8f..2dcbfb24724 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -41,7 +41,11 @@ namespace {
 struct ExecPlanImpl : public ExecPlan {
   ExecPlanImpl() = default;
-  ~ExecPlanImpl() override = default;
+  ~ExecPlanImpl() override {
+    if (started_ && !stopped_) {
+      StopProducing();
+    }
+  }
 
   ExecNode* AddNode(std::unique_ptr<ExecNode> node) {
     if (node->num_inputs() == 0) {
@@ -65,70 +69,73 @@ struct ExecPlanImpl : public ExecPlan {
   }
 
   Status StartProducing() {
-    ARROW_ASSIGN_OR_RAISE(auto sorted_nodes, ReverseTopoSort());
-    Status st;
-    auto it = sorted_nodes.begin();
-    while (it != sorted_nodes.end() && st.ok()) {
-      st &= (*it++)->StartProducing();
+    if (started_) {
+      return Status::Invalid("restarted ExecPlan");
     }
-    if (!st.ok()) {
+    started_ = true;
+
+    // producers precede consumers
+    sorted_nodes_ = TopoSort();
+
+    for (size_t i = 0, rev_i = sorted_nodes_.size() - 1; i < sorted_nodes_.size();
+         ++i, --rev_i) {
+      auto st = sorted_nodes_[rev_i]->StartProducing();
+      if (st.ok()) continue;
+
       // Stop nodes that successfully started, in reverse order
-      // (`it` now points after the node that failed starting, so need to rewind)
-      --it;
-      while (it != sorted_nodes.begin()) {
-        (*--it)->StopProducing();
+      for (; rev_i < sorted_nodes_.size(); ++rev_i) {
+        sorted_nodes_[rev_i]->StopProducing();
       }
+      return st;
+    }
+    return Status::OK();
+  }
+
+  void StopProducing() {
+    DCHECK(started_) << "stopped an ExecPlan which never started";
+    stopped_ = true;
+
+    for (const auto& node : sorted_nodes_) {
+      node->StopProducing();
     }
-    return st;
   }
 
-  Result<NodeVector> ReverseTopoSort() {
-    struct TopoSort {
+  NodeVector TopoSort() {
+    struct Impl {
      const std::vector<std::unique_ptr<ExecNode>>& nodes;
      std::unordered_set<ExecNode*> visited;
      NodeVector sorted;

-      explicit TopoSort(const std::vector<std::unique_ptr<ExecNode>>& nodes)
+      explicit Impl(const std::vector<std::unique_ptr<ExecNode>>& nodes)
          : nodes(nodes) {
           visited.reserve(nodes.size());
-          sorted.reserve(nodes.size());
-        }
+          sorted.resize(nodes.size());
 
-      Status Sort() {
         for (const auto& node : nodes) {
-          RETURN_NOT_OK(Visit(node.get()));
+          Visit(node.get());
         }
-        DCHECK_EQ(sorted.size(), nodes.size());
+
         DCHECK_EQ(visited.size(), nodes.size());
-        return Status::OK();
       }
 
-      Status Visit(ExecNode* node) {
-        if (visited.count(node) != 0) {
-          return Status::OK();
-        }
+      void Visit(ExecNode* node) {
+        if (visited.count(node) != 0) return;
 
         for (auto input : node->inputs()) {
           // Ensure that producers are inserted before this consumer
-          RETURN_NOT_OK(Visit(input));
+          Visit(input);
         }
 
+        sorted[visited.size()] = node;
         visited.insert(node);
-        sorted.push_back(node);
-        return Status::OK();
-      }
-
-      NodeVector Reverse() {
-        std::reverse(sorted.begin(), sorted.end());
-        return std::move(sorted);
       }
-    } topo_sort(nodes_);
+    };
 
-    RETURN_NOT_OK(topo_sort.Sort());
-    return topo_sort.Reverse();
+    return std::move(Impl{nodes_}.sorted);
   }
 
+  bool started_ = false, stopped_ = false;
   std::vector<std::unique_ptr<ExecNode>> nodes_;
+  NodeVector sorted_nodes_;
   NodeVector sources_, sinks_;
 };
@@ -166,6 +173,8 @@ Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }
 
 Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
 
+void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); }
+
 ExecNode::ExecNode(ExecPlan* plan, std::string label, NodeVector inputs,
                    std::vector<std::string> input_labels,
                    std::shared_ptr<Schema> output_schema, int num_outputs)
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 42f9e6527c0..21a757af5a1 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -71,7 +71,7 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
   /// is started before all of its inputs.
   Status StartProducing();
 
-  // XXX should we also have `void StopProducing()`?
+  void StopProducing();
 
  protected:
   ExecPlan() = default;
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 32de67a6670..75b71f97535 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -234,7 +234,7 @@ Result<std::vector<ExecBatch>> StartAndCollect(
   auto maybe_collected = CollectAsyncGenerator(gen).result();
   ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected);
 
-  // RETURN_NOT_OK(plan->StopProducing());
+  plan->StopProducing();
 
   return internal::MapVector(
       [](util::optional<ExecBatch> batch) { return std::move(*batch); }, collected);
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 98b218512a4..bed276b1bff 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -1102,7 +1102,7 @@ static Result<std::vector<compute::ExecBatch>> StartAndCollect(
   auto maybe_collected = CollectAsyncGenerator(gen).result();
   ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected);
 
-  // RETURN_NOT_OK(plan->StopProducing());
+  plan->StopProducing();
 
   return internal::MapVector(
       [](util::optional<compute::ExecBatch> batch) { return std::move(*batch); },
       collected);
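Before the final patch, it is worth spelling out the lifecycle discipline the patch above introduces: a plan records whether it has been started and stopped, its destructor stops a still-running plan, and start/stop walk the topologically sorted node list in opposite directions. The following is a simplified standalone sketch of that pattern; Pipeline and Node are illustrative stand-ins for ExecPlanImpl and ExecNode, not code from the patch:

#include <vector>

struct Node {
  void StartProducing() {}
  void StopProducing() {}
};

class Pipeline {
 public:
  ~Pipeline() {
    // Mirrors ~ExecPlanImpl: never destroy a pipeline that is still producing
    if (started_ && !stopped_) Stop();
  }

  void Start() {
    started_ = true;
    // Reverse order: start each consumer before the producers that feed it
    for (auto it = nodes_.rbegin(); it != nodes_.rend(); ++it) {
      it->StartProducing();
    }
  }

  void Stop() {
    stopped_ = true;
    // Forward order: silence sources first so no consumer sees new input
    for (Node& node : nodes_) {
      node.StopProducing();
    }
  }

 private:
  bool started_ = false;
  bool stopped_ = false;
  std::vector<Node> nodes_;  // topologically sorted: producers precede consumers
};

Starting in reverse topological order means every consumer is ready before its producers emit, while stopping in forward order quiesces sources before the nodes downstream of them shut down.
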
From 046e057e9f48384db73ff5f7eefec235ddb1a5b5 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Thu, 1 Jul 2021 08:14:18 -0400
Subject: [PATCH 28/28] move hash into Expression::Expression(Call) to ensure
 it's always initialized

---
 cpp/src/arrow/compute/exec/expression.cc | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index ac8c79db1ba..022584d5b39 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -43,7 +43,13 @@ using internal::checked_pointer_cast;
 
 namespace compute {
 
-Expression::Expression(Call call) : impl_(std::make_shared<Impl>(std::move(call))) {}
+Expression::Expression(Call call) {
+  call.hash = std::hash<std::string>{}(call.function_name);
+  for (const auto& arg : call.arguments) {
+    arrow::internal::hash_combine(call.hash, arg.hash());
+  }
+  impl_ = std::make_shared<Impl>(std::move(call));
+}
 
 Expression::Expression(Datum literal)
     : impl_(std::make_shared<Impl>(std::move(literal))) {}
@@ -63,11 +69,6 @@ Expression call(std::string function, std::vector<Expression> arguments,
   call.function_name = std::move(function);
   call.arguments = std::move(arguments);
   call.options = std::move(options);
-
-  call.hash = std::hash<std::string>{}(call.function_name);
-  for (const auto& arg : call.arguments) {
-    arrow::internal::hash_combine(call.hash, arg.hash());
-  }
   return Expression(std::move(call));
 }
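The final patch guarantees Call::hash is computed in exactly one place, the Call-taking constructor, and keeps the earlier replacement of `call.hash ^= arg.hash()` with arrow::internal::hash_combine. XOR folding is order-insensitive and self-cancelling, so calls that differ only in argument order, or that repeat an argument, would collide systematically. Below is a standalone demonstration of that failure mode using a boost-style combiner analogous in spirit to arrow::internal::hash_combine; this is illustrative code, not part of the patch series:

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>

// Boost-style combiner: order-sensitive, unlike plain XOR folding
void HashCombine(std::size_t& seed, std::size_t value) {
  seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

int main() {
  const std::size_t a = std::hash<std::string>{}("arg_a");
  const std::size_t b = std::hash<std::string>{}("arg_b");

  // XOR folding: f(a, b) and f(b, a) hash identically, f(a, a) folds to zero
  std::cout << "xor(a,b) == xor(b,a): " << ((a ^ b) == (b ^ a)) << "\n";  // prints 1
  std::cout << "xor(a,a): " << (a ^ a) << "\n";                           // always 0

  // Order-sensitive combination distinguishes the two argument orders
  std::size_t ab = 0, ba = 0;
  HashCombine(ab, a);
  HashCombine(ab, b);
  HashCombine(ba, b);
  HashCombine(ba, a);
  std::cout << "combine(a,b) != combine(b,a): " << (ab != ba) << "\n";  // 1, almost surely
  return 0;
}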