From f3165784b7b98b06e4859b02b43dc3665fe9c01e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 13 Aug 2019 14:26:07 -0400 Subject: [PATCH 01/28] add basic filter expressions --- cpp/src/arrow/dataset/CMakeLists.txt | 33 +-- cpp/src/arrow/dataset/api.h | 1 + cpp/src/arrow/dataset/filter.cc | 375 +++++++++++++++++++++++++++ cpp/src/arrow/dataset/filter.h | 187 ++++++++++++- cpp/src/arrow/dataset/filter_test.cc | 46 ++++ 5 files changed, 624 insertions(+), 18 deletions(-) create mode 100644 cpp/src/arrow/dataset/filter.cc create mode 100644 cpp/src/arrow/dataset/filter_test.cc diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt index 9f102ac1601..6f94297e272 100644 --- a/cpp/src/arrow/dataset/CMakeLists.txt +++ b/cpp/src/arrow/dataset/CMakeLists.txt @@ -23,7 +23,7 @@ arrow_install_all_headers("arrow/dataset") # pkg-config support arrow_add_pkg_config("arrow-dataset") -set(ARROW_DATASET_SRCS dataset.cc file_base.cc scanner.cc) +set(ARROW_DATASET_SRCS dataset.cc file_base.cc filter.cc scanner.cc) set(ARROW_DATASET_LINK_STATIC arrow_static) set(ARROW_DATASET_LINK_SHARED arrow_shared) @@ -79,22 +79,23 @@ function(ADD_ARROW_DATASET_TEST REL_TEST_NAME) set(LABELS "arrow_dataset") endif() - if(NOT WIN32) - add_arrow_test(${REL_TEST_NAME} - EXTRA_LINK_LIBS - ${ARROW_DATASET_TEST_LINK_LIBS} - PREFIX - ${PREFIX} - LABELS - ${LABELS} - ${ARG_UNPARSED_ARGUMENTS}) - endif() + add_arrow_test(${REL_TEST_NAME} + EXTRA_LINK_LIBS + ${ARROW_DATASET_TEST_LINK_LIBS} + PREFIX + ${PREFIX} + LABELS + ${LABELS} + ${ARG_UNPARSED_ARGUMENTS}) endfunction() -add_arrow_dataset_test(dataset_test) -add_arrow_dataset_test(file_test) -add_arrow_dataset_test(scanner_test) +if(NOT WIN32) + add_arrow_dataset_test(dataset_test) + add_arrow_dataset_test(file_test) + add_arrow_dataset_test(filter_test) + add_arrow_dataset_test(scanner_test) -if(ARROW_PARQUET) - add_arrow_dataset_test(file_parquet_test) + if(ARROW_PARQUET) + add_arrow_dataset_test(file_parquet_test) + endif() endif() diff --git a/cpp/src/arrow/dataset/api.h b/cpp/src/arrow/dataset/api.h index 18622f3a448..f9e49f20994 100644 --- a/cpp/src/arrow/dataset/api.h +++ b/cpp/src/arrow/dataset/api.h @@ -22,4 +22,5 @@ #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_csv.h" #include "arrow/dataset/file_feather.h" +#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc new file mode 100644 index 00000000000..581fe574964 --- /dev/null +++ b/cpp/src/arrow/dataset/filter.cc @@ -0,0 +1,375 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/dataset/filter.h" + +#include + +#include "arrow/buffer.h" +#include "arrow/util/logging.h" +#include "arrow/visitor_inline.h" + +namespace arrow { +namespace dataset { + +using internal::checked_cast; + +std::shared_ptr ScalarExpression::Make(std::string value) { + return std::make_shared( + std::make_shared(Buffer::FromString(std::move(value)))); +} + +// return a pair +std::pair IsBoolean(const Expression& e) { + if (e.type() == ExpressionType::SCALAR) { + auto value = checked_cast(e).value(); + if (value->type->id() == Type::BOOL) { + // FIXME(bkietz) null scalars do what? + return {true, checked_cast(*value).value}; + } + } + return {false, false}; +} + +Result> AssumeIfOperator(const std::shared_ptr& e, + const Expression& given) { + if (!e->IsOperatorExpression()) { + return e; + } + return checked_cast(*e).Assume(given); +} + +struct CompareVisitor { + Status Visit(const BooleanType&) { + result_ = checked_cast(lhs_).value - + checked_cast(rhs_).value; + return Status::OK(); + } + + Status Visit(const Int64Type&) { + result_ = checked_cast(lhs_).value - + checked_cast(rhs_).value; + return Status::OK(); + } + + Status Visit(const DoubleType&) { + double result = checked_cast(lhs_).value - + checked_cast(rhs_).value; + result_ = result < 0.0 ? -1 : result > 0.0 ? +1 : 0; + return Status::OK(); + } + + Status Visit(const StringType&) { + auto lhs = checked_cast(lhs_).value; + auto rhs = checked_cast(rhs_).value; + result_ = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); + if (result_ == 0) { + result_ = lhs->size() - rhs->size(); + } + return Status::OK(); + } + + Status Visit(const DataType&) { + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + + int64_t result_; + const Scalar& lhs_; + const Scalar& rhs_; +}; + +Result Compare(const Scalar& lhs, const Scalar& rhs) { + CompareVisitor vis{0, lhs, rhs}; + RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); + return vis.result_; +} + +Result> AssumeComparison(const OperatorExpression& e, + const OperatorExpression& given) { + auto e_rhs = checked_cast(*e.operands()[1]).value(); + // TODO(bkietz) allow the RHS of given to be FROM_STRING + auto given_rhs = checked_cast(*given.operands()[1]).value(); + + ARROW_ASSIGN_OR_RAISE(auto cmp, Compare(*e_rhs, *given_rhs)); + + static auto always = ScalarExpression::Make(true); + static auto never = ScalarExpression::Make(false); + auto unsimplified = MakeShared(e); + + if (cmp == 0) { + // the rhs of the comparisons are equal + switch (e.type()) { + case ExpressionType::EQUAL: + switch (given.type()) { + case ExpressionType::NOT_EQUAL: + case ExpressionType::GREATER: + case ExpressionType::LESS: + return never; + case ExpressionType::EQUAL: + case ExpressionType::GREATER_EQUAL: + case ExpressionType::LESS_EQUAL: + return always; + default: + return unsimplified; + } + case ExpressionType::NOT_EQUAL: + switch (given.type()) { + case ExpressionType::EQUAL: + return never; + case ExpressionType::NOT_EQUAL: + case ExpressionType::GREATER: + case ExpressionType::LESS: + return always; + default: + return unsimplified; + } + case ExpressionType::GREATER: + switch (given.type()) { + case ExpressionType::EQUAL: + case ExpressionType::LESS_EQUAL: + case ExpressionType::LESS: + return never; + case ExpressionType::GREATER: + return always; + default: + return unsimplified; + } + case ExpressionType::GREATER_EQUAL: + switch (given.type()) { + case ExpressionType::LESS: + return never; + case ExpressionType::EQUAL: + case ExpressionType::GREATER: + case ExpressionType::GREATER_EQUAL: + return always; + default: + return unsimplified; + } + case ExpressionType::LESS: + switch (given.type()) { + case ExpressionType::EQUAL: + case ExpressionType::GREATER: + case ExpressionType::GREATER_EQUAL: + return never; + case ExpressionType::LESS: + return always; + default: + return unsimplified; + } + case ExpressionType::LESS_EQUAL: + switch (given.type()) { + case ExpressionType::GREATER: + return never; + case ExpressionType::EQUAL: + case ExpressionType::LESS: + case ExpressionType::LESS_EQUAL: + return always; + default: + return unsimplified; + } + default: + return unsimplified; + } + } else if (cmp > 0) { + // the rhs of e is greater than that of given + switch (e.type()) { + case ExpressionType::EQUAL: + case ExpressionType::GREATER: + case ExpressionType::GREATER_EQUAL: + switch (given.type()) { + case ExpressionType::EQUAL: + case ExpressionType::LESS: + case ExpressionType::LESS_EQUAL: + return never; + default: + return unsimplified; + } + case ExpressionType::NOT_EQUAL: + case ExpressionType::LESS: + case ExpressionType::LESS_EQUAL: + switch (given.type()) { + case ExpressionType::EQUAL: + case ExpressionType::LESS: + case ExpressionType::LESS_EQUAL: + return always; + default: + return unsimplified; + } + default: + return unsimplified; + } + } else { + // the rhs of e is less than that of given + switch (e.type()) { + case ExpressionType::EQUAL: + case ExpressionType::LESS: + case ExpressionType::LESS_EQUAL: + switch (given.type()) { + case ExpressionType::EQUAL: + case ExpressionType::GREATER: + case ExpressionType::GREATER_EQUAL: + return never; + default: + return unsimplified; + } + case ExpressionType::NOT_EQUAL: + case ExpressionType::GREATER: + case ExpressionType::GREATER_EQUAL: + switch (given.type()) { + case ExpressionType::EQUAL: + case ExpressionType::GREATER: + case ExpressionType::GREATER_EQUAL: + return always; + default: + return unsimplified; + } + default: + return unsimplified; + } + } + + return unsimplified; +} + +std::string GetName(const Expression& e) { + DCHECK_EQ(e.type(), ExpressionType::FIELD); + return checked_cast(e).name(); +} + +Result> OperatorExpression::Assume( + const Expression& given) const { + auto unsimplified = MakeShared(*this); + + if (IsComparisonExpression()) { + if (!given.IsOperatorExpression()) { + return unsimplified; + } + + const auto& given_op = checked_cast(given); + + if (given.IsComparisonExpression()) { + // Both this and given are simple comparisons. If they constrain + // the same field, try to simplify this assuming given + DCHECK_EQ(operands_.size(), 2); + DCHECK_EQ(given_op.operands_.size(), 2); + if (GetName(*operands_[0]) != GetName(*given_op.operands_[0])) { + return unsimplified; + } + return AssumeComparison(*this, given_op); + } + + // must be NOT, AND, OR- decompose given + switch (given.type()) { + case ExpressionType::NOT: { + return unsimplified; + } + + case ExpressionType::OR: { + bool simplify_to_always = true; + bool simplify_to_never = true; + for (auto operand : given_op.operands_) { + ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); + auto isbool_value = IsBoolean(*simplified); + if (!isbool_value.first) { + return unsimplified; + } + + if (isbool_value.second) { + simplify_to_never = false; + } else { + simplify_to_always = false; + } + } + + if (simplify_to_always) { + return ScalarExpression::Make(true); + } + + if (simplify_to_never) { + return ScalarExpression::Make(false); + } + + return unsimplified; + } + + case ExpressionType::AND: { + std::shared_ptr simplified = unsimplified; + for (auto operand : given_op.operands_) { + auto isbool_value = IsBoolean(*simplified); + if (isbool_value.first) { + break; + } + DCHECK(simplified->IsOperatorExpression()); + const auto& simplified_op = + checked_cast(*simplified); + ARROW_ASSIGN_OR_RAISE(simplified, simplified_op.Assume(*operand)); + } + return simplified; + } + + default: + DCHECK(false); + } + } + + switch (type_) { + case ExpressionType::NOT: { + DCHECK_EQ(operands_.size(), 1); + ARROW_ASSIGN_OR_RAISE(auto operand, AssumeIfOperator(operands_[0], given)); + auto isbool_value = IsBoolean(*operand); + if (isbool_value.first) { + return ScalarExpression::Make(!isbool_value.second); + } + return std::make_shared( + ExpressionType::NOT, std::vector>{operand}); + } + + case ExpressionType::OR: + case ExpressionType::AND: { + // if any of the operands matches trivial_condition, we can return a trivial + // expression: + // anything OR true => true + // anything AND false => false + bool trivial_condition = type_ == ExpressionType::OR; + + std::vector> operands; + for (auto operand : operands_) { + ARROW_ASSIGN_OR_RAISE(operand, AssumeIfOperator(operand, given)); + + auto isbool_value = IsBoolean(*operand); + if (isbool_value.first) { + if (isbool_value.second == trivial_condition) { + return ScalarExpression::Make(trivial_condition); + } + continue; + } + + operands.push_back(operand); + } + + return std::make_shared(type_, std::move(operands)); + } + + default: + DCHECK(false); + } + + return unsimplified; +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index a727b1ce4b8..70460209558 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -20,12 +20,13 @@ #include #include "arrow/dataset/visibility.h" +#include "arrow/result.h" +#include "arrow/scalar.h" namespace arrow { namespace dataset { -class ARROW_DS_EXPORT Filter { - public: +struct FilterType { enum type { /// Simple boolean predicate consisting of comparisons and boolean /// logic (AND, OR, NOT) involving Schema fields @@ -36,5 +37,187 @@ class ARROW_DS_EXPORT Filter { }; }; +class ARROW_DS_EXPORT Filter { + public: + explicit Filter(FilterType::type type) : type_(type) {} + + virtual ~Filter() = 0; + + FilterType::type type() const { return type_; } + + private: + FilterType::type type_; +}; + +class Expression; + +class ARROW_DS_EXPORT ExpressionFilter : public Filter { + public: + explicit ExpressionFilter(const std::shared_ptr& expression) + : Filter(FilterType::EXPRESSION), expression_(std::move(expression)) {} + + const std::shared_ptr& expression() const { return expression_; } + + private: + std::shared_ptr expression_; +}; + +struct ExpressionType { + enum type { + FIELD, + SCALAR, + FROM_STRING, + + NOT, + AND, + OR, + + EQUAL, + NOT_EQUAL, + GREATER, + GREATER_EQUAL, + LESS, + LESS_EQUAL, + }; +}; + +template +std::shared_ptr::type> MakeShared(T&& t) { + return std::make_shared::type>(std::forward(t)); +} + +class ARROW_DS_EXPORT Expression { + public: + explicit Expression(ExpressionType::type type) : type_(type) {} + + virtual ~Expression() = default; + + ExpressionType::type type() const { return type_; } + + bool IsOperatorExpression() const { + return static_cast(type_) >= static_cast(ExpressionType::NOT) && + static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); + } + + bool IsComparisonExpression() const { + return static_cast(type_) >= static_cast(ExpressionType::EQUAL) && + static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); + } + + protected: + ExpressionType::type type_; +}; + +class ARROW_DS_EXPORT OperatorExpression : public Expression { + public: + OperatorExpression(ExpressionType::type type, + std::vector> operands) + : Expression(type), operands_(std::move(operands)) {} + + const std::vector>& operands() const { return operands_; } + + OperatorExpression operator and(const OperatorExpression& other) const { + return OperatorExpression(ExpressionType::AND, + {MakeShared(*this), MakeShared(other)}); + } + + OperatorExpression operator or(const OperatorExpression& other) const { + return OperatorExpression(ExpressionType::OR, {MakeShared(*this), MakeShared(other)}); + } + + OperatorExpression operator not() const { + return OperatorExpression(ExpressionType::NOT, {MakeShared(*this)}); + } + + Result> Assume(const Expression& given) const; + + private: + std::vector> operands_; +}; + +class ARROW_DS_EXPORT ScalarExpression : public Expression { + public: + explicit ScalarExpression(const std::shared_ptr& value) + : Expression(ExpressionType::SCALAR), value_(std::move(value)) {} + + const std::shared_ptr& value() const { return value_; } + + static std::shared_ptr Make(bool value) { + return std::make_shared(std::make_shared(value)); + } + + template + static typename std::enable_if::value, + std::shared_ptr>::type + Make(T value) { + return std::make_shared(std::make_shared(value)); + } + + template + static typename std::enable_if::value, + std::shared_ptr>::type + Make(T value) { + return std::make_shared(std::make_shared(value)); + } + + static std::shared_ptr Make(std::string value); + + private: + std::shared_ptr value_; +}; + +class ARROW_DS_EXPORT FieldReferenceExpression : public Expression { + public: + explicit FieldReferenceExpression(std::string name) + : Expression(ExpressionType::FIELD), name_(std::move(name)) {} + + std::string name() const { return name_; } + + template + OperatorExpression operator==(T&& value) const { + return Comparison(ExpressionType::EQUAL, std::forward(value)); + } + + template + OperatorExpression operator!=(T&& value) const { + return Comparison(ExpressionType::NOT_EQUAL, std::forward(value)); + } + + template + OperatorExpression operator>(T&& value) const { + return Comparison(ExpressionType::GREATER, std::forward(value)); + } + + template + OperatorExpression operator>=(T&& value) const { + return Comparison(ExpressionType::GREATER_EQUAL, std::forward(value)); + } + + template + OperatorExpression operator<(T&& value) const { + return Comparison(ExpressionType::LESS, std::forward(value)); + } + + template + OperatorExpression operator<=(T&& value) const { + return Comparison(ExpressionType::LESS_EQUAL, std::forward(value)); + } + + private: + template + OperatorExpression Comparison(ExpressionType::type type, T&& value) const { + return OperatorExpression( + type, {MakeShared(*this), ScalarExpression::Make(std::forward(value))}); + } + + std::string name_; +}; + +inline namespace string_literals { +FieldReferenceExpression operator""_(const char* name, size_t name_length) { + return FieldReferenceExpression({name, name_length}); +} +} // namespace string_literals + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc new file mode 100644 index 00000000000..9673becdb07 --- /dev/null +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include +#include + +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" + +#include "arrow/dataset/api.h" + +namespace arrow { +namespace dataset { + +TEST(Expressions, Basics) { + using namespace string_literals; + + auto simplified = ("b"_ == 3).Assume("b"_ > 5 and "b"_ < 10); + ASSERT_OK(simplified.status()); + ASSERT_EQ(simplified.ValueOrDie()->type(), ExpressionType::SCALAR); + auto value = internal::checked_cast(**simplified).value(); + ASSERT_EQ(value->type->id(), Type::BOOL); + ASSERT_FALSE(internal::checked_cast(*value).value); +} + +} // namespace dataset +} // namespace arrow From 6ccd636c7ba89f1cce32d1f2532de245baad2699 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 13 Aug 2019 16:59:49 -0400 Subject: [PATCH 02/28] add execution of filter expressions using compute kernels --- cpp/src/arrow/dataset/filter.cc | 151 +++++++++++++++++++++++++++ cpp/src/arrow/dataset/filter.h | 32 +++++- cpp/src/arrow/dataset/filter_test.cc | 62 +++++++++-- cpp/src/arrow/record_batch.cc | 5 + cpp/src/arrow/record_batch.h | 3 + 5 files changed, 243 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 581fe574964..df6d76a71ab 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -20,6 +20,9 @@ #include #include "arrow/buffer.h" +#include "arrow/compute/kernels/boolean.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/record_batch.h" #include "arrow/util/logging.h" #include "arrow/visitor_inline.h" @@ -28,11 +31,69 @@ namespace dataset { using internal::checked_cast; +Status ExpressionFilter::Execute(compute::FunctionContext* ctx, const RecordBatch& batch, + std::shared_ptr* filter) const { + using arrow::compute::Datum; + if (!expression_->IsOperatorExpression()) { + return Status::Invalid("can't execute expression ", expression_->ToString()); + } + + const auto& op = checked_cast(*expression_).operands(); + + if (expression_->IsComparisonExpression()) { + const auto& lhs = checked_cast(*op[0]); + const auto& rhs = checked_cast(*op[1]); + using arrow::compute::CompareOperator; + arrow::compute::CompareOptions opts(static_cast( + static_cast(CompareOperator::EQUAL) + static_cast(expression_->type()) - + static_cast(ExpressionType::EQUAL))); + Datum out; + RETURN_NOT_OK(arrow::compute::Compare(ctx, Datum(batch.GetColumnByName(lhs.name())), + Datum(rhs.value()), opts, &out)); + *filter = internal::checked_pointer_cast(out.make_array()); + return Status::OK(); + } + + if (expression_->type() == ExpressionType::NOT) { + std::shared_ptr to_invert; + RETURN_NOT_OK(ExpressionFilter(op[0]).Execute(ctx, batch, &to_invert)); + Datum out; + RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); + *filter = internal::checked_pointer_cast(out.make_array()); + return Status::OK(); + } + + DCHECK(expression_->type() == ExpressionType::OR || + expression_->type() == ExpressionType::AND); + + std::shared_ptr next; + RETURN_NOT_OK(ExpressionFilter(op[0]).Execute(ctx, batch, &next)); + Datum acc(next); + + for (size_t i_next = 1; i_next < op.size(); ++i_next) { + RETURN_NOT_OK(ExpressionFilter(op[i_next]).Execute(ctx, batch, &next)); + + if (expression_->type() == ExpressionType::OR) { + RETURN_NOT_OK(arrow::compute::Or(ctx, Datum(acc), Datum(next), &acc)); + } else { + RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); + } + } + + *filter = internal::checked_pointer_cast(acc.make_array()); + return Status::OK(); +} + std::shared_ptr ScalarExpression::Make(std::string value) { return std::make_shared( std::make_shared(Buffer::FromString(std::move(value)))); } +std::shared_ptr ScalarExpression::Make(const char* value) { + return std::make_shared( + std::make_shared(Buffer::Wrap(value, std::strlen(value)))); +} + // return a pair std::pair IsBoolean(const Expression& e) { if (e.type() == ExpressionType::SCALAR) { @@ -100,6 +161,7 @@ Result Compare(const Scalar& lhs, const Scalar& rhs) { Result> AssumeComparison(const OperatorExpression& e, const OperatorExpression& given) { + // TODO(bkietz) allow the RHS of e to be FIELD auto e_rhs = checked_cast(*e.operands()[1]).value(); // TODO(bkietz) allow the RHS of given to be FROM_STRING auto given_rhs = checked_cast(*given.operands()[1]).value(); @@ -371,5 +433,94 @@ Result> OperatorExpression::Assume( return unsimplified; } +std::string FieldReferenceExpression::ToString() const { + return std::string("field(") + name_ + ")"; +} + +std::string OperatorName(ExpressionType::type type) { + switch (type) { + case ExpressionType::AND: + return "AND"; + case ExpressionType::OR: + return "OR"; + case ExpressionType::NOT: + return "NOT"; + case ExpressionType::EQUAL: + return "EQUAL"; + case ExpressionType::NOT_EQUAL: + return "NOT_EQUAL"; + case ExpressionType::LESS: + return "LESS"; + case ExpressionType::LESS_EQUAL: + return "LESS_EQUAL"; + case ExpressionType::GREATER: + return "GREATER"; + case ExpressionType::GREATER_EQUAL: + return "GREATER_EQUAL"; + default: + DCHECK(false); + } + return ""; +} + +std::string OperatorExpression::ToString() const { + auto out = OperatorName(type_) + "("; + bool comma = false; + for (const auto& operand : operands_) { + if (comma) { + out += ", "; + } else { + comma = true; + } + out += operand->ToString(); + } + out += ")"; + return out; +} + +std::string ScalarExpression::ToString() const { + return "scalar<" + value_->type->ToString() + ">(TODO)"; +} + +bool Expression::Equals(const Expression& other) const { + if (type_ != other.type()) { + return false; + } + + // FIXME(bkietz) create FromStringExpression + DCHECK_NE(type_, ExpressionType::FROM_STRING); + + switch (type_) { + case ExpressionType::FIELD: + return checked_cast(*this).name() == + checked_cast(other).name(); + + case ExpressionType::SCALAR: { + auto this_value = checked_cast(*this).value(); + auto other_value = checked_cast(other).value(); + return this_value->Equals(other_value); + } + + default: { + DCHECK(IsOperatorExpression()); + const auto& this_op = checked_cast(*this).operands(); + const auto& other_op = checked_cast(other).operands(); + if (this_op.size() != other_op.size()) { + return false; + } + + for (size_t i = 0; i < this_op.size(); ++i) { + if (!this_op[i]->Equals(*other_op[i])) { + return false; + } + } + + return true; + } + } + + return true; +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 70460209558..69dfe5c5d79 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -24,6 +24,11 @@ #include "arrow/scalar.h" namespace arrow { + +namespace compute { +class FunctionContext; +} + namespace dataset { struct FilterType { @@ -41,10 +46,13 @@ class ARROW_DS_EXPORT Filter { public: explicit Filter(FilterType::type type) : type_(type) {} - virtual ~Filter() = 0; + virtual ~Filter() = default; FilterType::type type() const { return type_; } + virtual Status Execute(compute::FunctionContext* ctx, const RecordBatch& batch, + std::shared_ptr* filter) const = 0; + private: FilterType::type type_; }; @@ -58,6 +66,9 @@ class ARROW_DS_EXPORT ExpressionFilter : public Filter { const std::shared_ptr& expression() const { return expression_; } + Status Execute(compute::FunctionContext* ctx, const RecordBatch& batch, + std::shared_ptr* filter) const override; + private: std::shared_ptr expression_; }; @@ -92,6 +103,17 @@ class ARROW_DS_EXPORT Expression { virtual ~Expression() = default; + bool Equals(const Expression& other) const; + + bool Equals(const std::shared_ptr& other) const { + if (other == NULLPTR) { + return false; + } + return Equals(*other); + } + + virtual std::string ToString() const = 0; + ExpressionType::type type() const { return type_; } bool IsOperatorExpression() const { @@ -116,6 +138,8 @@ class ARROW_DS_EXPORT OperatorExpression : public Expression { const std::vector>& operands() const { return operands_; } + virtual std::string ToString() const override; + OperatorExpression operator and(const OperatorExpression& other) const { return OperatorExpression(ExpressionType::AND, {MakeShared(*this), MakeShared(other)}); @@ -142,6 +166,8 @@ class ARROW_DS_EXPORT ScalarExpression : public Expression { const std::shared_ptr& value() const { return value_; } + virtual std::string ToString() const override; + static std::shared_ptr Make(bool value) { return std::make_shared(std::make_shared(value)); } @@ -162,6 +188,8 @@ class ARROW_DS_EXPORT ScalarExpression : public Expression { static std::shared_ptr Make(std::string value); + static std::shared_ptr Make(const char* value); + private: std::shared_ptr value_; }; @@ -173,6 +201,8 @@ class ARROW_DS_EXPORT FieldReferenceExpression : public Expression { std::string name() const { return name_; } + virtual std::string ToString() const override; + template OperatorExpression operator==(T&& value) const { return Comparison(ExpressionType::EQUAL, std::forward(value)); diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 9673becdb07..3432ac2319a 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -23,23 +23,67 @@ #include #include +#include "arrow/compute/api.h" +#include "arrow/dataset/api.h" +#include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" -#include "arrow/dataset/api.h" - namespace arrow { namespace dataset { -TEST(Expressions, Basics) { +class ExpressionsTest : public ::testing::Test { + public: + void AssertSimplifiesTo(OperatorExpression expr, const Expression& given, + const Expression& expected) { + auto simplified = expr.Assume(given); + ASSERT_OK(simplified.status()); + if (!simplified.ValueOrDie()->Equals(expected)) { + FAIL() << " simplification of: " << expr.ToString() << std::endl + << " given: " << expr.ToString() << std::endl + << " expected: " << expected.ToString() << std::endl + << " was: " << simplified.ValueOrDie()->ToString(); + } + } + + std::shared_ptr always = ScalarExpression::Make(true); + std::shared_ptr never = ScalarExpression::Make(false); +}; + +TEST_F(ExpressionsTest, Basics) { using namespace string_literals; - auto simplified = ("b"_ == 3).Assume("b"_ > 5 and "b"_ < 10); - ASSERT_OK(simplified.status()); - ASSERT_EQ(simplified.ValueOrDie()->type(), ExpressionType::SCALAR); - auto value = internal::checked_cast(**simplified).value(); - ASSERT_EQ(value->type->id(), Type::BOOL); - ASSERT_FALSE(internal::checked_cast(*value).value); + ASSERT_TRUE("a"_.Equals("a"_)); + ASSERT_FALSE("a"_.Equals("b"_)); + + ASSERT_TRUE(("b"_ == 3).Equals("b"_ == 3)); + ASSERT_FALSE(("b"_ == 3).Equals("b"_ < 3)); + ASSERT_FALSE(("b"_ == 3).Equals("b"_)); + + AssertSimplifiesTo("b"_ == 3, "b"_ > 5 and "b"_ < 10, *never); + AssertSimplifiesTo("b"_ > 3, "b"_ > 5 and "b"_ < 10, *always); +} + +TEST(FilterTest, Basics) { + using namespace string_literals; + + auto batch = RecordBatch::FromStructArray( + ArrayFromJSON(struct_({field("a", int64()), field("b", float64())}), R"([ + {"a": 0, "b": -0.1}, + {"a": 0, "b": 0.3}, + {"a": 1, "b": 0.2}, + {"a": 2, "b": -0.1}, + {"a": 0, "b": 0.1}, + {"a": 0, "b": 1.0} + ])")); + + arrow::compute::FunctionContext ctx; + std::shared_ptr filter; + ASSERT_OK(ExpressionFilter(MakeShared("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0)) + .Execute(&ctx, *batch, &filter)); + + auto expected_filter = ArrayFromJSON(boolean(), "[0, 1, 0, 0, 1, 0]"); + ASSERT_ARRAYS_EQUAL(*expected_filter, *filter); } } // namespace dataset diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 7ae865af5f0..881f57453ec 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -197,6 +197,11 @@ std::shared_ptr RecordBatch::Make( DCHECK_EQ(schema->num_fields(), static_cast(columns.size())); return std::make_shared(schema, num_rows, columns); } +std::shared_ptr RecordBatch::FromStructArray( + const std::shared_ptr& array) { + return Make(arrow::schema(array->type()->children()), array->length(), + array->data()->child_data); +} const std::string& RecordBatch::column_name(int i) const { return schema_->field(i)->name(); diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 6c67f363c64..0b419bad7db 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -71,6 +71,9 @@ class ARROW_EXPORT RecordBatch { const std::shared_ptr& schema, int64_t num_rows, const std::vector>& columns); + static std::shared_ptr FromStructArray( + const std::shared_ptr& array); + /// \brief Determine if two record batches are exactly equal /// \return true if batches are equal bool Equals(const RecordBatch& other) const; From 9282284a03a9b4dc3d77ecf3e63ea64c9b9352b5 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 14 Aug 2019 10:03:11 -0400 Subject: [PATCH 03/28] simplify filter testing --- cpp/src/arrow/dataset/filter_test.cc | 50 ++++++++++++++++++---------- cpp/src/arrow/record_batch.cc | 12 +++++-- cpp/src/arrow/record_batch.h | 4 +-- 3 files changed, 43 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 3432ac2319a..68fe5f5a059 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -64,26 +64,40 @@ TEST_F(ExpressionsTest, Basics) { AssertSimplifiesTo("b"_ > 3, "b"_ > 5 and "b"_ < 10, *always); } -TEST(FilterTest, Basics) { +class FilterTest : public ::testing::Test { + public: + void AssertFilter(OperatorExpression expr, std::vector> fields, + std::string batch_json) { + // expected filter result is in the "in" field + fields.push_back(field("in", boolean())); + + auto batch_array = ArrayFromJSON(struct_(std::move(fields)), std::move(batch_json)); + std::shared_ptr batch; + ASSERT_OK(RecordBatch::FromStructArray(batch_array, &batch)); + + auto expected_filter = batch->GetColumnByName("in"); + + std::shared_ptr filter; + ASSERT_OK(ExpressionFilter(MakeShared(expr)).Execute(&ctx_, *batch, &filter)); + + ASSERT_ARRAYS_EQUAL(*expected_filter, *filter); + } + + arrow::compute::FunctionContext ctx_; +}; + +TEST_F(FilterTest, Basics) { using namespace string_literals; - auto batch = RecordBatch::FromStructArray( - ArrayFromJSON(struct_({field("a", int64()), field("b", float64())}), R"([ - {"a": 0, "b": -0.1}, - {"a": 0, "b": 0.3}, - {"a": 1, "b": 0.2}, - {"a": 2, "b": -0.1}, - {"a": 0, "b": 0.1}, - {"a": 0, "b": 1.0} - ])")); - - arrow::compute::FunctionContext ctx; - std::shared_ptr filter; - ASSERT_OK(ExpressionFilter(MakeShared("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0)) - .Execute(&ctx, *batch, &filter)); - - auto expected_filter = ArrayFromJSON(boolean(), "[0, 1, 0, 0, 1, 0]"); - ASSERT_ARRAYS_EQUAL(*expected_filter, *filter); + AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0, + {field("a", int64()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.3, "in": 1}, + {"a": 1, "b": 0.2, "in": 0}, + {"a": 2, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.1, "in": 1}, + {"a": 0, "b": 1.0, "in": 0} + ])"); } } // namespace dataset diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 881f57453ec..1e4ad45a3d2 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -197,10 +197,16 @@ std::shared_ptr RecordBatch::Make( DCHECK_EQ(schema->num_fields(), static_cast(columns.size())); return std::make_shared(schema, num_rows, columns); } -std::shared_ptr RecordBatch::FromStructArray( - const std::shared_ptr& array) { - return Make(arrow::schema(array->type()->children()), array->length(), + +Status RecordBatch::FromStructArray(const std::shared_ptr& array, + std::shared_ptr* out) { + if (array->type_id() != Type::STRUCT) { + return Status::Invalid("Cannot construct record batch from array of type ", + *array->type()); + } + *out = Make(arrow::schema(array->type()->children()), array->length(), array->data()->child_data); + return Status::OK(); } const std::string& RecordBatch::column_name(int i) const { diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 0b419bad7db..c3b003edf68 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -71,8 +71,8 @@ class ARROW_EXPORT RecordBatch { const std::shared_ptr& schema, int64_t num_rows, const std::vector>& columns); - static std::shared_ptr FromStructArray( - const std::shared_ptr& array); + static Status FromStructArray(const std::shared_ptr& array, + std::shared_ptr* out); /// \brief Determine if two record batches are exactly equal /// \return true if batches are equal From a67b330a1c1d1d56e7c2b1b42431fb670e50bbae Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 14 Aug 2019 13:04:14 -0400 Subject: [PATCH 04/28] add comments, more tests, simplify operator overloads --- cpp/src/arrow/dataset/filter.cc | 74 ++++++++++++-- cpp/src/arrow/dataset/filter.h | 141 +++++++++++++-------------- cpp/src/arrow/dataset/filter_test.cc | 27 ++++- 3 files changed, 161 insertions(+), 81 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index df6d76a71ab..feb2e38f0b1 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -163,14 +163,13 @@ Result> AssumeComparison(const OperatorExpression& e const OperatorExpression& given) { // TODO(bkietz) allow the RHS of e to be FIELD auto e_rhs = checked_cast(*e.operands()[1]).value(); - // TODO(bkietz) allow the RHS of given to be FROM_STRING auto given_rhs = checked_cast(*given.operands()[1]).value(); ARROW_ASSIGN_OR_RAISE(auto cmp, Compare(*e_rhs, *given_rhs)); static auto always = ScalarExpression::Make(true); static auto never = ScalarExpression::Make(false); - auto unsimplified = MakeShared(e); + auto unsimplified = e.Copy(); if (cmp == 0) { // the rhs of the comparisons are equal @@ -314,7 +313,7 @@ std::string GetName(const Expression& e) { Result> OperatorExpression::Assume( const Expression& given) const { - auto unsimplified = MakeShared(*this); + auto unsimplified = Copy(); if (IsComparisonExpression()) { if (!given.IsOperatorExpression()) { @@ -487,9 +486,6 @@ bool Expression::Equals(const Expression& other) const { return false; } - // FIXME(bkietz) create FromStringExpression - DCHECK_NE(type_, ExpressionType::FROM_STRING); - switch (type_) { case ExpressionType::FIELD: return checked_cast(*this).name() == @@ -522,5 +518,71 @@ bool Expression::Equals(const Expression& other) const { return true; } +bool Expression::Equals(const std::shared_ptr& other) const { + if (other == NULLPTR) { + return false; + } + return Equals(*other); +} + +bool Expression::IsOperatorExpression() const { + return static_cast(type_) >= static_cast(ExpressionType::NOT) && + static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); +} + +bool Expression::IsComparisonExpression() const { + return static_cast(type_) >= static_cast(ExpressionType::EQUAL) && + static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); +} + +std::shared_ptr OperatorExpression::Copy() const { + return std::make_shared(*this); +} + +std::shared_ptr FieldReferenceExpression::Copy() const { + return std::make_shared(*this); +} + +std::shared_ptr ScalarExpression::Copy() const { + return std::make_shared(*this); +} + +// flatten chains of and/or to a single OperatorExpression +OperatorExpression MaybeCombine(ExpressionType::type type, const OperatorExpression& lhs, + const OperatorExpression& rhs) { + if (lhs.type() != type && rhs.type() != type) { + return OperatorExpression(type, {lhs.Copy(), rhs.Copy()}); + } + std::vector> operands; + if (lhs.type() == type) { + operands = lhs.operands(); + if (rhs.type() == type) { + for (auto operand : rhs.operands()) { + operands.emplace_back(std::move(operand)); + } + } else { + operands.emplace_back(rhs.Copy()); + } + } else { + operands = rhs.operands(); + operands.emplace(operands.begin(), lhs.Copy()); + } + return OperatorExpression(type, std::move(operands)); +} + +OperatorExpression operator and(const OperatorExpression& lhs, + const OperatorExpression& rhs) { + return MaybeCombine(ExpressionType::AND, lhs, rhs); +} + +OperatorExpression operator or(const OperatorExpression& lhs, + const OperatorExpression& rhs) { + return MaybeCombine(ExpressionType::OR, lhs, rhs); +} + +OperatorExpression operator not(const OperatorExpression& rhs) { + return OperatorExpression(ExpressionType::NOT, {rhs.Copy()}); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 69dfe5c5d79..51faa023793 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -37,7 +37,7 @@ struct FilterType { /// logic (AND, OR, NOT) involving Schema fields EXPRESSION, - /// + /// Non decomposable filter; must be evaluated against every record batch GENERIC }; }; @@ -50,6 +50,8 @@ class ARROW_DS_EXPORT Filter { FilterType::type type() const { return type_; } + /// Evaluate this filter producing a boolean array which encodes whether each row + /// satisfies the filter virtual Status Execute(compute::FunctionContext* ctx, const RecordBatch& batch, std::shared_ptr* filter) const = 0; @@ -59,6 +61,8 @@ class ARROW_DS_EXPORT Filter { class Expression; +/// Filter subclass encapsulating a simple boolean predicate consisting of comparisons and +/// boolean logic (AND, OR, NOT) involving Schema fields class ARROW_DS_EXPORT ExpressionFilter : public Filter { public: explicit ExpressionFilter(const std::shared_ptr& expression) @@ -77,7 +81,6 @@ struct ExpressionType { enum type { FIELD, SCALAR, - FROM_STRING, NOT, AND, @@ -92,74 +95,81 @@ struct ExpressionType { }; }; -template -std::shared_ptr::type> MakeShared(T&& t) { - return std::make_shared::type>(std::forward(t)); -} - +/// Represents an expression tree. The expression can be evaluated against a +/// RecordBatch via ExpressionFilter class ARROW_DS_EXPORT Expression { public: explicit Expression(ExpressionType::type type) : type_(type) {} virtual ~Expression() = default; + /// Returns true iff the expressions are identical; does not check for equivalence. + /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to + /// (false). bool Equals(const Expression& other) const; - bool Equals(const std::shared_ptr& other) const { - if (other == NULLPTR) { - return false; - } - return Equals(*other); - } + bool Equals(const std::shared_ptr& other) const; + /// Validate this expression for execution against a schema. This will check that all + /// reference fields are present and all subexpressions are executable. Returns a copy + /// of this expression with schema information incorporated: + /// - Scalars are cast to other data types if necessary to ensure comparisons are + /// between data of identical type + // virtual Result> Validate(const Schema&) const; + + /// returns a debug string representing this expression virtual std::string ToString() const = 0; ExpressionType::type type() const { return type_; } - bool IsOperatorExpression() const { - return static_cast(type_) >= static_cast(ExpressionType::NOT) && - static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); - } + /// If true, this Expression may be safely cast to OperatorExpression + bool IsOperatorExpression() const; - bool IsComparisonExpression() const { - return static_cast(type_) >= static_cast(ExpressionType::EQUAL) && - static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); - } + /// If true, this Expression may be safely cast to OperatorExpression + /// and there will be exactly two operands representing the left and right hand sides of + /// a comparison + bool IsComparisonExpression() const; + + /// Copy this expression into a shared pointer. + virtual std::shared_ptr Copy() const = 0; protected: ExpressionType::type type_; }; -class ARROW_DS_EXPORT OperatorExpression : public Expression { +/// Represents an compound expression; for example comparison between a field and a scalar +/// or a union of other expressions +class ARROW_DS_EXPORT OperatorExpression final : public Expression { public: OperatorExpression(ExpressionType::type type, std::vector> operands) : Expression(type), operands_(std::move(operands)) {} + /// Return a simplified form of this expression given some known conditions. + /// For example, (a > 3).Assume(a == 5) == (true). This can be used to do less work + /// in ExpressionFilter when partition conditions guarantee some of this expression. + /// In the example above, *no* filtering need be done on record batches in the partition + /// since (a == 5). + Result> Assume(const Expression& given) const; + const std::vector>& operands() const { return operands_; } virtual std::string ToString() const override; - OperatorExpression operator and(const OperatorExpression& other) const { - return OperatorExpression(ExpressionType::AND, - {MakeShared(*this), MakeShared(other)}); - } - - OperatorExpression operator or(const OperatorExpression& other) const { - return OperatorExpression(ExpressionType::OR, {MakeShared(*this), MakeShared(other)}); - } - - OperatorExpression operator not() const { - return OperatorExpression(ExpressionType::NOT, {MakeShared(*this)}); - } - - Result> Assume(const Expression& given) const; + std::shared_ptr Copy() const override; private: std::vector> operands_; }; -class ARROW_DS_EXPORT ScalarExpression : public Expression { +OperatorExpression operator and(const OperatorExpression& lhs, + const OperatorExpression& rhs); +OperatorExpression operator or(const OperatorExpression& lhs, + const OperatorExpression& rhs); +OperatorExpression operator not(const OperatorExpression& rhs); + +/// Represents a scalar value; thin wrapper around arrow::Scalar +class ARROW_DS_EXPORT ScalarExpression final : public Expression { public: explicit ScalarExpression(const std::shared_ptr& value) : Expression(ExpressionType::SCALAR), value_(std::move(value)) {} @@ -190,11 +200,15 @@ class ARROW_DS_EXPORT ScalarExpression : public Expression { static std::shared_ptr Make(const char* value); + std::shared_ptr Copy() const override; + private: std::shared_ptr value_; }; -class ARROW_DS_EXPORT FieldReferenceExpression : public Expression { +/// Represents a reference to a field. Stores only the field's name (type and other +/// information is known only when a Schema is provided) +class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { public: explicit FieldReferenceExpression(std::string name) : Expression(ExpressionType::FIELD), name_(std::move(name)) {} @@ -203,46 +217,27 @@ class ARROW_DS_EXPORT FieldReferenceExpression : public Expression { virtual std::string ToString() const override; - template - OperatorExpression operator==(T&& value) const { - return Comparison(ExpressionType::EQUAL, std::forward(value)); - } - - template - OperatorExpression operator!=(T&& value) const { - return Comparison(ExpressionType::NOT_EQUAL, std::forward(value)); - } - - template - OperatorExpression operator>(T&& value) const { - return Comparison(ExpressionType::GREATER, std::forward(value)); - } - - template - OperatorExpression operator>=(T&& value) const { - return Comparison(ExpressionType::GREATER_EQUAL, std::forward(value)); - } - - template - OperatorExpression operator<(T&& value) const { - return Comparison(ExpressionType::LESS, std::forward(value)); - } - - template - OperatorExpression operator<=(T&& value) const { - return Comparison(ExpressionType::LESS_EQUAL, std::forward(value)); - } + std::shared_ptr Copy() const override; private: - template - OperatorExpression Comparison(ExpressionType::type type, T&& value) const { - return OperatorExpression( - type, {MakeShared(*this), ScalarExpression::Make(std::forward(value))}); - } - std::string name_; }; +#define COMPARISON_FACTORY(NAME, OP) \ + template \ + OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ + return OperatorExpression( \ + ExpressionType::NAME, \ + {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ + } +COMPARISON_FACTORY(EQUAL, ==) +COMPARISON_FACTORY(NOT_EQUAL, !=) +COMPARISON_FACTORY(GREATER, >) +COMPARISON_FACTORY(GREATER_EQUAL, >=) +COMPARISON_FACTORY(LESS, <) +COMPARISON_FACTORY(LESS_EQUAL, <=) +#undef COMPARISON_FACTORY + inline namespace string_literals { FieldReferenceExpression operator""_(const char* name, size_t name_length) { return FieldReferenceExpression({name, name_length}); diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 68fe5f5a059..62b84e289c4 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -46,11 +46,23 @@ class ExpressionsTest : public ::testing::Test { } } + template + void AssertOperandsAre(OperatorExpression expr, ExpressionType::type type, + T... expected_operands) { + ASSERT_EQ(expr.type(), type); + ASSERT_EQ(expr.operands().size(), sizeof...(T)); + std::shared_ptr expected_operand_ptrs[] = {expected_operands.Copy()...}; + + for (size_t i = 0; i < sizeof...(T); ++i) { + ASSERT_TRUE(expr.operands()[i]->Equals(expected_operand_ptrs[i])); + } + } + std::shared_ptr always = ScalarExpression::Make(true); std::shared_ptr never = ScalarExpression::Make(false); }; -TEST_F(ExpressionsTest, Basics) { +TEST_F(ExpressionsTest, Equality) { using namespace string_literals; ASSERT_TRUE("a"_.Equals("a"_)); @@ -60,6 +72,17 @@ TEST_F(ExpressionsTest, Basics) { ASSERT_FALSE(("b"_ == 3).Equals("b"_ < 3)); ASSERT_FALSE(("b"_ == 3).Equals("b"_)); + // ordering matters + ASSERT_FALSE(("b"_ > 2 and "b"_ < 3).Equals("b"_ < 3 and "b"_ > 2)); +} + +TEST_F(ExpressionsTest, Simplification) { + using namespace string_literals; + + // chained "and" expressions are flattened + auto multi_and = "b"_ > 5 and "b"_ < 10 and "b"_ != 7; + AssertOperandsAre(multi_and, ExpressionType::AND, "b"_ > 5, "b"_ < 10, "b"_ != 7); + AssertSimplifiesTo("b"_ == 3, "b"_ > 5 and "b"_ < 10, *never); AssertSimplifiesTo("b"_ > 3, "b"_ > 5 and "b"_ < 10, *always); } @@ -78,7 +101,7 @@ class FilterTest : public ::testing::Test { auto expected_filter = batch->GetColumnByName("in"); std::shared_ptr filter; - ASSERT_OK(ExpressionFilter(MakeShared(expr)).Execute(&ctx_, *batch, &filter)); + ASSERT_OK(ExpressionFilter(expr.Copy()).Execute(&ctx_, *batch, &filter)); ASSERT_ARRAYS_EQUAL(*expected_filter, *filter); } From 06b25becb4f064e9eab6e4e256bff0730bb95906 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 14 Aug 2019 15:35:32 -0400 Subject: [PATCH 05/28] move expression evaluation to a free function --- cpp/src/arrow/dataset/filter.cc | 133 +++++++++++++++++---------- cpp/src/arrow/dataset/filter.h | 22 +++-- cpp/src/arrow/dataset/filter_test.cc | 16 +++- cpp/src/arrow/scalar.h | 13 +++ 4 files changed, 125 insertions(+), 59 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index feb2e38f0b1..3723ca1f10f 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -31,21 +31,22 @@ namespace dataset { using internal::checked_cast; -Status ExpressionFilter::Execute(compute::FunctionContext* ctx, const RecordBatch& batch, - std::shared_ptr* filter) const { +Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condition, + const RecordBatch& batch, + std::shared_ptr* filter) { using arrow::compute::Datum; - if (!expression_->IsOperatorExpression()) { - return Status::Invalid("can't execute expression ", expression_->ToString()); + if (!condition.IsOperatorExpression()) { + return Status::Invalid("can't execute condition ", condition.ToString()); } - const auto& op = checked_cast(*expression_).operands(); + const auto& op = checked_cast(condition).operands(); - if (expression_->IsComparisonExpression()) { + if (condition.IsComparisonExpression()) { const auto& lhs = checked_cast(*op[0]); const auto& rhs = checked_cast(*op[1]); using arrow::compute::CompareOperator; arrow::compute::CompareOptions opts(static_cast( - static_cast(CompareOperator::EQUAL) + static_cast(expression_->type()) - + static_cast(CompareOperator::EQUAL) + static_cast(condition.type()) - static_cast(ExpressionType::EQUAL))); Datum out; RETURN_NOT_OK(arrow::compute::Compare(ctx, Datum(batch.GetColumnByName(lhs.name())), @@ -54,26 +55,26 @@ Status ExpressionFilter::Execute(compute::FunctionContext* ctx, const RecordBatc return Status::OK(); } - if (expression_->type() == ExpressionType::NOT) { + if (condition.type() == ExpressionType::NOT) { std::shared_ptr to_invert; - RETURN_NOT_OK(ExpressionFilter(op[0]).Execute(ctx, batch, &to_invert)); + RETURN_NOT_OK(EvaluateExpression(ctx, *op[0], batch, &to_invert)); Datum out; RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); *filter = internal::checked_pointer_cast(out.make_array()); return Status::OK(); } - DCHECK(expression_->type() == ExpressionType::OR || - expression_->type() == ExpressionType::AND); + DCHECK(condition.type() == ExpressionType::OR || + condition.type() == ExpressionType::AND); std::shared_ptr next; - RETURN_NOT_OK(ExpressionFilter(op[0]).Execute(ctx, batch, &next)); + RETURN_NOT_OK(EvaluateExpression(ctx, *op[0], batch, &next)); Datum acc(next); for (size_t i_next = 1; i_next < op.size(); ++i_next) { - RETURN_NOT_OK(ExpressionFilter(op[i_next]).Execute(ctx, batch, &next)); + RETURN_NOT_OK(EvaluateExpression(ctx, *op[i_next], batch, &next)); - if (expression_->type() == ExpressionType::OR) { + if (condition.type() == ExpressionType::OR) { RETURN_NOT_OK(arrow::compute::Or(ctx, Datum(acc), Datum(next), &acc)); } else { RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); @@ -94,18 +95,6 @@ std::shared_ptr ScalarExpression::Make(const char* value) { std::make_shared(Buffer::Wrap(value, std::strlen(value)))); } -// return a pair -std::pair IsBoolean(const Expression& e) { - if (e.type() == ExpressionType::SCALAR) { - auto value = checked_cast(e).value(); - if (value->type->id() == Type::BOOL) { - // FIXME(bkietz) null scalars do what? - return {true, checked_cast(*value).value}; - } - } - return {false, false}; -} - Result> AssumeIfOperator(const std::shared_ptr& e, const Expression& given) { if (!e->IsOperatorExpression()) { @@ -115,6 +104,7 @@ Result> AssumeIfOperator(const std::shared_ptr(lhs_).value - checked_cast(rhs_).value; @@ -181,8 +171,6 @@ Result> AssumeComparison(const OperatorExpression& e case ExpressionType::LESS: return never; case ExpressionType::EQUAL: - case ExpressionType::GREATER_EQUAL: - case ExpressionType::LESS_EQUAL: return always; default: return unsimplified; @@ -335,21 +323,23 @@ Result> OperatorExpression::Assume( // must be NOT, AND, OR- decompose given switch (given.type()) { - case ExpressionType::NOT: { + case ExpressionType::NOT: return unsimplified; - } case ExpressionType::OR: { bool simplify_to_always = true; bool simplify_to_never = true; for (auto operand : given_op.operands_) { ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); - auto isbool_value = IsBoolean(*simplified); - if (!isbool_value.first) { + BooleanScalar scalar; + if (!simplified->IsBooleanScalar(&scalar)) { return unsimplified; } - if (isbool_value.second) { + // an expression should never simplify to null + DCHECK(scalar.is_valid); + + if (scalar.value == true) { simplify_to_never = false; } else { simplify_to_always = false; @@ -370,10 +360,10 @@ Result> OperatorExpression::Assume( case ExpressionType::AND: { std::shared_ptr simplified = unsimplified; for (auto operand : given_op.operands_) { - auto isbool_value = IsBoolean(*simplified); - if (isbool_value.first) { + if (simplified->IsBooleanScalar()) { break; } + DCHECK(simplified->IsOperatorExpression()); const auto& simplified_op = checked_cast(*simplified); @@ -391,12 +381,16 @@ Result> OperatorExpression::Assume( case ExpressionType::NOT: { DCHECK_EQ(operands_.size(), 1); ARROW_ASSIGN_OR_RAISE(auto operand, AssumeIfOperator(operands_[0], given)); - auto isbool_value = IsBoolean(*operand); - if (isbool_value.first) { - return ScalarExpression::Make(!isbool_value.second); + + BooleanScalar scalar; + if (!operand->IsBooleanScalar(&scalar)) { + return unsimplified; } - return std::make_shared( - ExpressionType::NOT, std::vector>{operand}); + + // an expression should never simplify to null + DCHECK(scalar.is_valid); + + return ScalarExpression::Make(!scalar.value); } case ExpressionType::OR: @@ -411,15 +405,23 @@ Result> OperatorExpression::Assume( for (auto operand : operands_) { ARROW_ASSIGN_OR_RAISE(operand, AssumeIfOperator(operand, given)); - auto isbool_value = IsBoolean(*operand); - if (isbool_value.first) { - if (isbool_value.second == trivial_condition) { - return ScalarExpression::Make(trivial_condition); - } + BooleanScalar scalar; + if (!operand->IsBooleanScalar(&scalar)) { + operands.push_back(operand); continue; } - operands.push_back(operand); + if (scalar.value == trivial_condition) { + return ScalarExpression::Make(trivial_condition); + } + } + + if (operands.size() == 1) { + return operands[0]; + } + + if (operands.size() == 0) { + return ScalarExpression::Make(!trivial_condition); } return std::make_shared(type_, std::move(operands)); @@ -478,7 +480,26 @@ std::string OperatorExpression::ToString() const { } std::string ScalarExpression::ToString() const { - return "scalar<" + value_->type->ToString() + ">(TODO)"; + std::string value; + switch (value_->type->id()) { + case Type::BOOL: + value = checked_cast(*value_).value ? "true" : "false"; + break; + case Type::INT64: + value = std::to_string(checked_cast(*value_).value); + break; + case Type::DOUBLE: + value = std::to_string(checked_cast(*value_).value); + break; + case Type::STRING: + value = checked_cast(*value_).value->ToString(); + break; + default: + value = "TODO"; + break; + } + + return "scalar<" + value_->type->ToString() + ">(" + value + ")"; } bool Expression::Equals(const Expression& other) const { @@ -535,6 +556,24 @@ bool Expression::IsComparisonExpression() const { static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); } +bool Expression::IsBooleanScalar(BooleanScalar* out) const { + if (type_ != ExpressionType::SCALAR) { + return false; + } + + auto scalar = checked_cast(*this).value(); + if (scalar->type->id() != Type::BOOL) { + return false; + } + + if (out) { + out->type = boolean(); + out->is_valid = scalar->is_valid; + out->value = checked_cast(*scalar).value; + } + return true; +} + std::shared_ptr OperatorExpression::Copy() const { return std::make_shared(*this); } diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 51faa023793..1f3b5073eb6 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -50,11 +50,6 @@ class ARROW_DS_EXPORT Filter { FilterType::type type() const { return type_; } - /// Evaluate this filter producing a boolean array which encodes whether each row - /// satisfies the filter - virtual Status Execute(compute::FunctionContext* ctx, const RecordBatch& batch, - std::shared_ptr* filter) const = 0; - private: FilterType::type type_; }; @@ -70,13 +65,16 @@ class ARROW_DS_EXPORT ExpressionFilter : public Filter { const std::shared_ptr& expression() const { return expression_; } - Status Execute(compute::FunctionContext* ctx, const RecordBatch& batch, - std::shared_ptr* filter) const override; - private: std::shared_ptr expression_; }; +/// Evaluate an expression producing a boolean array which encodes whether each row +/// satisfies the condition +Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condition, + const RecordBatch& batch, + std::shared_ptr* filter); + struct ExpressionType { enum type { FIELD, @@ -130,6 +128,10 @@ class ARROW_DS_EXPORT Expression { /// a comparison bool IsComparisonExpression() const; + /// If true, this Expression is a ScalarExpression wrapping a boolean Scalar. Its value + /// may be retrieved at the same time + bool IsBooleanScalar(BooleanScalar* value = NULLPTR) const; + /// Copy this expression into a shared pointer. virtual std::shared_ptr Copy() const = 0; @@ -200,6 +202,10 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { static std::shared_ptr Make(const char* value); + static std::shared_ptr MakeNull() { + return std::make_shared(std::make_shared()); + } + std::shared_ptr Copy() const override; private: diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 62b84e289c4..08999b08a24 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -40,7 +40,7 @@ class ExpressionsTest : public ::testing::Test { ASSERT_OK(simplified.status()); if (!simplified.ValueOrDie()->Equals(expected)) { FAIL() << " simplification of: " << expr.ToString() << std::endl - << " given: " << expr.ToString() << std::endl + << " given: " << given.ToString() << std::endl << " expected: " << expected.ToString() << std::endl << " was: " << simplified.ValueOrDie()->ToString(); } @@ -83,8 +83,16 @@ TEST_F(ExpressionsTest, Simplification) { auto multi_and = "b"_ > 5 and "b"_ < 10 and "b"_ != 7; AssertOperandsAre(multi_and, ExpressionType::AND, "b"_ > 5, "b"_ < 10, "b"_ != 7); - AssertSimplifiesTo("b"_ == 3, "b"_ > 5 and "b"_ < 10, *never); - AssertSimplifiesTo("b"_ > 3, "b"_ > 5 and "b"_ < 10, *always); + AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); + AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); + AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); + AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); + AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); + + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 6, *never); + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ == 3, *always); + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); } class FilterTest : public ::testing::Test { @@ -101,7 +109,7 @@ class FilterTest : public ::testing::Test { auto expected_filter = batch->GetColumnByName("in"); std::shared_ptr filter; - ASSERT_OK(ExpressionFilter(expr.Copy()).Execute(&ctx_, *batch, &filter)); + ASSERT_OK(EvaluateExpression(&ctx_, expr, *batch, &filter)); ASSERT_ARRAYS_EQUAL(*expected_filter, *filter); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 09192896ba3..977acefcf32 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -76,6 +76,7 @@ struct ARROW_EXPORT BooleanScalar : public internal::PrimitiveScalar { bool value; explicit BooleanScalar(bool value, bool is_valid = true) : internal::PrimitiveScalar{boolean(), is_valid}, value(value) {} + BooleanScalar() : BooleanScalar(false, false) {} }; template @@ -86,6 +87,8 @@ struct NumericScalar : public internal::PrimitiveScalar { explicit NumericScalar(T value, bool is_valid = true) : NumericScalar(value, TypeTraits::type_singleton(), is_valid) {} + NumericScalar() : NumericScalar(0, false) {} + protected: explicit NumericScalar(T value, const std::shared_ptr& type, bool is_valid) : internal::PrimitiveScalar{type, is_valid}, value(value) {} @@ -99,12 +102,16 @@ struct BaseBinaryScalar : public Scalar { BaseBinaryScalar(const std::shared_ptr& value, const std::shared_ptr& type, bool is_valid = true) : Scalar{type, is_valid}, value(value) {} + + static std::shared_ptr Empty(); }; struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { explicit BinaryScalar(const std::shared_ptr& value, bool is_valid = true) : BaseBinaryScalar(value, binary(), is_valid) {} + BinaryScalar() : BinaryScalar(NULLPTR, false) {} + protected: using BaseBinaryScalar::BaseBinaryScalar; }; @@ -112,12 +119,16 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { struct ARROW_EXPORT StringScalar : public BinaryScalar { explicit StringScalar(const std::shared_ptr& value, bool is_valid = true) : BinaryScalar(value, utf8(), is_valid) {} + + StringScalar() : StringScalar(NULLPTR, false) {} }; struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { explicit LargeBinaryScalar(const std::shared_ptr& value, bool is_valid = true) : BaseBinaryScalar(value, large_binary(), is_valid) {} + LargeBinaryScalar() : LargeBinaryScalar(NULLPTR, false) {} + protected: using BaseBinaryScalar::BaseBinaryScalar; }; @@ -125,6 +136,8 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { explicit LargeStringScalar(const std::shared_ptr& value, bool is_valid = true) : LargeBinaryScalar(value, utf8(), is_valid) {} + + LargeStringScalar() : LargeStringScalar(NULLPTR, false) {} }; struct ARROW_EXPORT FixedSizeBinaryScalar : public BinaryScalar { From 3499fd2ec70aa1b24a4def3104611fa924ff2428 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 14 Aug 2019 17:42:03 -0400 Subject: [PATCH 06/28] add an expression simplification test --- cpp/src/arrow/dataset/filter_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 08999b08a24..99675d30233 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -93,6 +93,8 @@ TEST_F(ExpressionsTest, Simplification) { AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ == 3, *always); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); + + AssertSimplifiesTo("a"_ == 3 or "b"_ == 4, "a"_ == 0, "b"_ == 4); } class FilterTest : public ::testing::Test { From 3d355ff659ef2638c41b51cc90e6f0624042716a Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 15 Aug 2019 10:12:08 -0400 Subject: [PATCH 07/28] break up assumption logic, add function expression factories --- cpp/src/arrow/dataset/filter.cc | 400 +++++++++++++++++++++----------- cpp/src/arrow/dataset/filter.h | 65 ++++-- 2 files changed, 313 insertions(+), 152 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 3723ca1f10f..f39f011ef66 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -95,38 +95,42 @@ std::shared_ptr ScalarExpression::Make(const char* value) { std::make_shared(Buffer::Wrap(value, std::strlen(value)))); } -Result> AssumeIfOperator(const std::shared_ptr& e, - const Expression& given) { - if (!e->IsOperatorExpression()) { - return e; +struct CompareVisitor { + Status Visit(const NullType&) { + result_ = false; + return Status::OK(); } - return checked_cast(*e).Assume(given); -} -struct CompareVisitor { - // FIXME(bkietz) incomplete Status Visit(const BooleanType&) { result_ = checked_cast(lhs_).value - checked_cast(rhs_).value; return Status::OK(); } - Status Visit(const Int64Type&) { - result_ = checked_cast(lhs_).value - - checked_cast(rhs_).value; + template + using ScalarType = typename TypeTraits::ScalarType; + + template + typename std::enable_if::value, Status>::type + Visit(const T&) { + result_ = static_cast(checked_cast&>(lhs_).value - + checked_cast&>(rhs_).value); return Status::OK(); } - Status Visit(const DoubleType&) { - double result = checked_cast(lhs_).value - - checked_cast(rhs_).value; - result_ = result < 0.0 ? -1 : result > 0.0 ? +1 : 0; + template + enable_if_floating_point Visit(const T&) { + auto delta = checked_cast&>(lhs_).value - + checked_cast&>(rhs_).value; + constexpr decltype(delta) zero = 0; + result_ = delta < zero ? -1 : delta > zero ? +1 : 0; return Status::OK(); } - Status Visit(const StringType&) { - auto lhs = checked_cast(lhs_).value; - auto rhs = checked_cast(rhs_).value; + template + enable_if_binary_like Visit(const T&) { + auto lhs = checked_cast&>(lhs_).value; + auto rhs = checked_cast&>(rhs_).value; result_ = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); if (result_ == 0) { result_ = lhs->size() - rhs->size(); @@ -134,6 +138,19 @@ struct CompareVisitor { return Status::OK(); } + Status Visit(const Decimal128Type&) { + auto lhs = checked_cast(lhs_).value; + auto rhs = checked_cast(rhs_).value; + result_ = (lhs - rhs).Sign(); + return Status::OK(); + } + + // explicit because both integral and floating point conditions match half float + Status Visit(const HalfFloatType&) { + // TODO(bkietz) whenever we vendor a float16, this can be implemented + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + Status Visit(const DataType&) { return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); } @@ -143,25 +160,110 @@ struct CompareVisitor { const Scalar& rhs_; }; -Result Compare(const Scalar& lhs, const Scalar& rhs) { +// Compare two scalars +// lhs < rhs => return < 0 +// lhs == rhs => return == 0 +// lhs > rhs => return > 0 +// if either is null, return is null +Result Compare(const Scalar& lhs, const Scalar& rhs) { + if (!lhs.type->Equals(*rhs.type)) { + return Status::TypeError("cannot compare a scalars with differing type: ", *lhs.type, + " vs ", *rhs.type); + } + if (!lhs.is_valid || !rhs.is_valid) { + return Int64Scalar(); + } CompareVisitor vis{0, lhs, rhs}; RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); - return vis.result_; + return Int64Scalar(vis.result_); +} + +Result> Invert(const Expression& op) { + auto make_opposite = [&op](ExpressionType::type opposite_type) { + return std::make_shared( + opposite_type, checked_cast(op).operands()); + }; + + switch (op.type()) { + case ExpressionType::NOT: + return checked_cast(op).operands()[0]; + + case ExpressionType::AND: + case ExpressionType::OR: { + std::vector> operands; + for (auto operand : checked_cast(op).operands()) { + ARROW_ASSIGN_OR_RAISE(auto inverted_operand, Invert(*operand)); + operands.push_back(inverted_operand); + } + + auto opposite_type = + op.type() == ExpressionType::AND ? ExpressionType::OR : ExpressionType::AND; + return std::make_shared(opposite_type, std::move(operands)); + } + + case ExpressionType::EQUAL: + return make_opposite(ExpressionType::NOT_EQUAL); + + case ExpressionType::NOT_EQUAL: + return make_opposite(ExpressionType::EQUAL); + + case ExpressionType::LESS: + return make_opposite(ExpressionType::GREATER_EQUAL); + + case ExpressionType::LESS_EQUAL: + return make_opposite(ExpressionType::GREATER); + + case ExpressionType::GREATER: + return make_opposite(ExpressionType::LESS_EQUAL); + + case ExpressionType::GREATER_EQUAL: + return make_opposite(ExpressionType::LESS); + + default: + return Status::NotImplemented("can't invert this expression"); + } + + return op.Copy(); +} + +// If e can be cast to OperatorExpression try to simplify it against given. +// Otherwise pass e through unchanged. +Result> AssumeIfOperator(const std::shared_ptr& e, + const Expression& given) { + if (!e->IsOperatorExpression()) { + return e; + } + return checked_cast(*e).Assume(given); } -Result> AssumeComparison(const OperatorExpression& e, - const OperatorExpression& given) { - // TODO(bkietz) allow the RHS of e to be FIELD +// Try to simplify one comparison against another comparison. +// For example, +// (x > 3) is a subset of (x > 2), so (x > 2).Assume(x > 3) == (true) +// (x < 0) is disjoint with (x > 2), so (x > 2).Assume(x < 0) == (false) +// If simplification to (true) or (false) is not possible, pass e through unchanged. +Result> AssumeComparisonComparison( + const OperatorExpression& e, const OperatorExpression& given) { + if (e.operands()[1]->type() != ExpressionType::SCALAR || + given.operands()[1]->type() != ExpressionType::SCALAR) { + // TODO(bkietz) allow the RHS of e to be FIELD + return Status::Invalid("right hand side of comparison must be a scalar"); + } + auto e_rhs = checked_cast(*e.operands()[1]).value(); auto given_rhs = checked_cast(*given.operands()[1]).value(); ARROW_ASSIGN_OR_RAISE(auto cmp, Compare(*e_rhs, *given_rhs)); + if (!cmp.is_valid) { + // the RHS of e or given was null + return ScalarExpression::MakeNull(); + } + static auto always = ScalarExpression::Make(true); static auto never = ScalarExpression::Make(false); auto unsimplified = e.Copy(); - if (cmp == 0) { + if (cmp.value == 0) { // the rhs of the comparisons are equal switch (e.type()) { case ExpressionType::EQUAL: @@ -233,7 +335,7 @@ Result> AssumeComparison(const OperatorExpression& e default: return unsimplified; } - } else if (cmp > 0) { + } else if (cmp.value > 0) { // the rhs of e is greater than that of given switch (e.type()) { case ExpressionType::EQUAL: @@ -294,144 +396,159 @@ Result> AssumeComparison(const OperatorExpression& e return unsimplified; } -std::string GetName(const Expression& e) { - DCHECK_EQ(e.type(), ExpressionType::FIELD); - return checked_cast(e).name(); -} - -Result> OperatorExpression::Assume( - const Expression& given) const { - auto unsimplified = Copy(); +// Try to simplify a comparison against a compound expression. +// The operands of the compound expression must be examined individually. +Result> AssumeComparisonCompound( + const OperatorExpression& e, const OperatorExpression& given) { + auto unsimplified = e.Copy(); - if (IsComparisonExpression()) { - if (!given.IsOperatorExpression()) { - return unsimplified; + switch (given.type()) { + case ExpressionType::NOT: { + ARROW_ASSIGN_OR_RAISE(auto inverted, Invert(*given.operands()[0])); + return e.Assume(*inverted); } - const auto& given_op = checked_cast(given); + case ExpressionType::OR: { + bool simplify_to_always = true; + bool simplify_to_never = true; + for (auto operand : given.operands()) { + ARROW_ASSIGN_OR_RAISE(auto simplified, e.Assume(*operand)); + BooleanScalar scalar; + if (!simplified->IsBooleanScalar(&scalar)) { + return unsimplified; + } - if (given.IsComparisonExpression()) { - // Both this and given are simple comparisons. If they constrain - // the same field, try to simplify this assuming given - DCHECK_EQ(operands_.size(), 2); - DCHECK_EQ(given_op.operands_.size(), 2); - if (GetName(*operands_[0]) != GetName(*given_op.operands_[0])) { - return unsimplified; + // an expression should never simplify to null + DCHECK(scalar.is_valid); + + if (scalar.value == true) { + simplify_to_never = false; + } else { + simplify_to_always = false; + } } - return AssumeComparison(*this, given_op); - } - // must be NOT, AND, OR- decompose given - switch (given.type()) { - case ExpressionType::NOT: - return unsimplified; + if (simplify_to_always) { + return ScalarExpression::Make(true); + } - case ExpressionType::OR: { - bool simplify_to_always = true; - bool simplify_to_never = true; - for (auto operand : given_op.operands_) { - ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); - BooleanScalar scalar; - if (!simplified->IsBooleanScalar(&scalar)) { - return unsimplified; - } + if (simplify_to_never) { + return ScalarExpression::Make(false); + } - // an expression should never simplify to null - DCHECK(scalar.is_valid); + return unsimplified; + } - if (scalar.value == true) { - simplify_to_never = false; - } else { - simplify_to_always = false; - } + case ExpressionType::AND: { + std::shared_ptr simplified = unsimplified; + for (auto operand : given.operands()) { + if (simplified->IsBooleanScalar()) { + break; } - if (simplify_to_always) { - return ScalarExpression::Make(true); - } + DCHECK(simplified->IsOperatorExpression()); + const auto& simplified_op = checked_cast(*simplified); + ARROW_ASSIGN_OR_RAISE(simplified, simplified_op.Assume(*operand)); + } + return simplified; + } - if (simplify_to_never) { - return ScalarExpression::Make(false); - } + default: + DCHECK(false); + } - return unsimplified; - } + return unsimplified; +} - case ExpressionType::AND: { - std::shared_ptr simplified = unsimplified; - for (auto operand : given_op.operands_) { - if (simplified->IsBooleanScalar()) { - break; - } - - DCHECK(simplified->IsOperatorExpression()); - const auto& simplified_op = - checked_cast(*simplified); - ARROW_ASSIGN_OR_RAISE(simplified, simplified_op.Assume(*operand)); - } - return simplified; - } +Result> AssumeCompound(const OperatorExpression& e, + const Expression& given) { + auto unsimplified = e.Copy(); - default: - DCHECK(false); + if (e.type() == ExpressionType::NOT) { + DCHECK_EQ(e.operands().size(), 1); + ARROW_ASSIGN_OR_RAISE(auto operand, AssumeIfOperator(e.operands()[0], given)); + + if (operand->IsNullScalar()) { + return operand; + } + + BooleanScalar scalar; + if (!operand->IsBooleanScalar(&scalar)) { + return unsimplified; } + + return ScalarExpression::Make(!scalar.value); } - switch (type_) { - case ExpressionType::NOT: { - DCHECK_EQ(operands_.size(), 1); - ARROW_ASSIGN_OR_RAISE(auto operand, AssumeIfOperator(operands_[0], given)); + DCHECK(e.type() == ExpressionType::OR || e.type() == ExpressionType::AND); - BooleanScalar scalar; - if (!operand->IsBooleanScalar(&scalar)) { - return unsimplified; - } + // if any of the operands matches trivial_condition, we can return a trivial + // expression: + // anything OR true => true + // anything AND false => false + bool trivial_condition = e.type() == ExpressionType::OR; - // an expression should never simplify to null - DCHECK(scalar.is_valid); + std::vector> operands; + for (auto operand : e.operands()) { + ARROW_ASSIGN_OR_RAISE(operand, AssumeIfOperator(operand, given)); - return ScalarExpression::Make(!scalar.value); + if (operand->IsNullScalar()) { + return operand; } - case ExpressionType::OR: - case ExpressionType::AND: { - // if any of the operands matches trivial_condition, we can return a trivial - // expression: - // anything OR true => true - // anything AND false => false - bool trivial_condition = type_ == ExpressionType::OR; + BooleanScalar scalar; + if (!operand->IsBooleanScalar(&scalar)) { + operands.push_back(operand); + continue; + } - std::vector> operands; - for (auto operand : operands_) { - ARROW_ASSIGN_OR_RAISE(operand, AssumeIfOperator(operand, given)); + if (scalar.value == trivial_condition) { + return ScalarExpression::Make(trivial_condition); + } + } - BooleanScalar scalar; - if (!operand->IsBooleanScalar(&scalar)) { - operands.push_back(operand); - continue; - } + if (operands.size() == 1) { + return operands[0]; + } - if (scalar.value == trivial_condition) { - return ScalarExpression::Make(trivial_condition); - } - } + if (operands.size() == 0) { + return ScalarExpression::Make(!trivial_condition); + } - if (operands.size() == 1) { - return operands[0]; - } + return std::make_shared(e.type(), std::move(operands)); +} - if (operands.size() == 0) { - return ScalarExpression::Make(!trivial_condition); - } +Result> OperatorExpression::Assume( + const Expression& given) const { + auto unsimplified = Copy(); - return std::make_shared(type_, std::move(operands)); + if (IsComparisonExpression()) { + if (!given.IsOperatorExpression()) { + return unsimplified; } - default: - DCHECK(false); + const auto& given_op = checked_cast(given); + + if (given.IsComparisonExpression()) { + // Both this and given are simple comparisons. If they constrain + // the same field, try to simplify this assuming given + DCHECK_EQ(operands_.size(), 2); + DCHECK_EQ(given_op.operands_.size(), 2); + auto get_name = [](const Expression& e) { + DCHECK_EQ(e.type(), ExpressionType::FIELD); + return checked_cast(e).name(); + }; + if (get_name(*operands_[0]) != get_name(*given_op.operands_[0])) { + return unsimplified; + } + return AssumeComparisonComparison(*this, given_op); + } + + // must be NOT, AND, OR- decompose given + return AssumeComparisonCompound(*this, given_op); } - return unsimplified; + return AssumeCompound(*this, given); } std::string FieldReferenceExpression::ToString() const { @@ -556,6 +673,14 @@ bool Expression::IsComparisonExpression() const { static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); } +bool Expression::IsNullScalar() const { + if (type_ != ExpressionType::SCALAR) { + return false; + } + + return !checked_cast(*this).value()->is_valid; +} + bool Expression::IsBooleanScalar(BooleanScalar* out) const { if (type_ != ExpressionType::SCALAR) { return false; @@ -586,6 +711,21 @@ std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } +std::shared_ptr and_( + std::vector> operands) { + return std::make_shared(ExpressionType::AND, std::move(operands)); +} + +std::shared_ptr or_( + std::vector> operands) { + return std::make_shared(ExpressionType::OR, std::move(operands)); +} + +std::shared_ptr not_(std::shared_ptr operand) { + return std::make_shared( + ExpressionType::NOT, std::vector>{std::move(operand)}); +} + // flatten chains of and/or to a single OperatorExpression OperatorExpression MaybeCombine(ExpressionType::type type, const OperatorExpression& lhs, const OperatorExpression& rhs) { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 1f3b5073eb6..5232b98b21e 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -113,7 +113,7 @@ class ARROW_DS_EXPORT Expression { /// of this expression with schema information incorporated: /// - Scalars are cast to other data types if necessary to ensure comparisons are /// between data of identical type - // virtual Result> Validate(const Schema&) const; + // virtual Result> Validate(const Schema& schema) const = 0; /// returns a debug string representing this expression virtual std::string ToString() const = 0; @@ -128,6 +128,9 @@ class ARROW_DS_EXPORT Expression { /// a comparison bool IsComparisonExpression() const; + /// If true, this Expression is a ScalarExpression wrapping a null Scalar. + bool IsNullScalar() const; + /// If true, this Expression is a ScalarExpression wrapping a boolean Scalar. Its value /// may be retrieved at the same time bool IsBooleanScalar(BooleanScalar* value = NULLPTR) const; @@ -156,7 +159,9 @@ class ARROW_DS_EXPORT OperatorExpression final : public Expression { const std::vector>& operands() const { return operands_; } - virtual std::string ToString() const override; + std::string ToString() const override; + + // Result> Validate(const Schema& schema) const override; std::shared_ptr Copy() const override; @@ -164,11 +169,17 @@ class ARROW_DS_EXPORT OperatorExpression final : public Expression { std::vector> operands_; }; -OperatorExpression operator and(const OperatorExpression& lhs, - const OperatorExpression& rhs); -OperatorExpression operator or(const OperatorExpression& lhs, - const OperatorExpression& rhs); -OperatorExpression operator not(const OperatorExpression& rhs); +ARROW_DS_EXPORT std::shared_ptr and_( + std::vector> operands); +ARROW_DS_EXPORT OperatorExpression operator and(const OperatorExpression& lhs, + const OperatorExpression& rhs); +ARROW_DS_EXPORT std::shared_ptr or_( + std::vector> operands); +ARROW_DS_EXPORT OperatorExpression operator or(const OperatorExpression& lhs, + const OperatorExpression& rhs); +ARROW_DS_EXPORT std::shared_ptr not_( + std::shared_ptr operand); +ARROW_DS_EXPORT OperatorExpression operator not(const OperatorExpression& rhs); /// Represents a scalar value; thin wrapper around arrow::Scalar class ARROW_DS_EXPORT ScalarExpression final : public Expression { @@ -178,7 +189,9 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { const std::shared_ptr& value() const { return value_; } - virtual std::string ToString() const override; + std::string ToString() const override; + + // Result> Validate(const Schema& schema) const override; static std::shared_ptr Make(bool value) { return std::make_shared(std::make_shared(value)); @@ -221,7 +234,9 @@ class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { std::string name() const { return name_; } - virtual std::string ToString() const override; + std::string ToString() const override; + + // Result> Validate(const Schema& schema) const override; std::shared_ptr Copy() const override; @@ -229,23 +244,29 @@ class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { std::string name_; }; -#define COMPARISON_FACTORY(NAME, OP) \ - template \ - OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ - return OperatorExpression( \ - ExpressionType::NAME, \ - {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ +#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ + template \ + OperatorExpression FACTORY_NAME(const FieldReferenceExpression& lhs, T&& rhs) { \ + return OperatorExpression( \ + ExpressionType::NAME, \ + {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ + } \ + \ + template \ + OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ + return FACTORY_NAME(lhs, std::forward(rhs)); \ } -COMPARISON_FACTORY(EQUAL, ==) -COMPARISON_FACTORY(NOT_EQUAL, !=) -COMPARISON_FACTORY(GREATER, >) -COMPARISON_FACTORY(GREATER_EQUAL, >=) -COMPARISON_FACTORY(LESS, <) -COMPARISON_FACTORY(LESS_EQUAL, <=) +COMPARISON_FACTORY(EQUAL, equal, ==) +COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) +COMPARISON_FACTORY(GREATER, greater, >) +COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) +COMPARISON_FACTORY(LESS, less, <) +COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) #undef COMPARISON_FACTORY inline namespace string_literals { -FieldReferenceExpression operator""_(const char* name, size_t name_length) { +ARROW_DS_EXPORT FieldReferenceExpression operator""_(const char* name, + size_t name_length) { return FieldReferenceExpression({name, name_length}); } } // namespace string_literals From 55aa452309f028b5023603cf5ba7ec74b7c2e478 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 15 Aug 2019 10:34:13 -0400 Subject: [PATCH 08/28] implement more robust null handling --- cpp/src/arrow/dataset/filter.cc | 24 +++++++++++++++++++----- cpp/src/arrow/dataset/filter.h | 24 +++++++++++++----------- cpp/src/arrow/dataset/filter_test.cc | 26 ++++++++++++++++++++++---- 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index f39f011ef66..a2327b67399 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -166,13 +166,13 @@ struct CompareVisitor { // lhs > rhs => return > 0 // if either is null, return is null Result Compare(const Scalar& lhs, const Scalar& rhs) { - if (!lhs.type->Equals(*rhs.type)) { - return Status::TypeError("cannot compare a scalars with differing type: ", *lhs.type, - " vs ", *rhs.type); - } if (!lhs.is_valid || !rhs.is_valid) { return Int64Scalar(); } + if (!lhs.type->Equals(*rhs.type)) { + return Status::TypeError("cannot compare scalars with differing type: ", *lhs.type, + " vs ", *rhs.type); + } CompareVisitor vis{0, lhs, rhs}; RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); return Int64Scalar(vis.result_); @@ -487,6 +487,7 @@ Result> AssumeCompound(const OperatorExpression& e, // anything OR true => true // anything AND false => false bool trivial_condition = e.type() == ExpressionType::OR; + bool simplify_to_trivial = false; std::vector> operands; for (auto operand : e.operands()) { @@ -496,6 +497,10 @@ Result> AssumeCompound(const OperatorExpression& e, return operand; } + if (simplify_to_trivial) { + continue; + } + BooleanScalar scalar; if (!operand->IsBooleanScalar(&scalar)) { operands.push_back(operand); @@ -503,10 +508,15 @@ Result> AssumeCompound(const OperatorExpression& e, } if (scalar.value == trivial_condition) { - return ScalarExpression::Make(trivial_condition); + simplify_to_trivial = true; + continue; } } + if (simplify_to_trivial) { + return ScalarExpression::Make(trivial_condition); + } + if (operands.size() == 1) { return operands[0]; } @@ -597,6 +607,10 @@ std::string OperatorExpression::ToString() const { } std::string ScalarExpression::ToString() const { + if (!value_->is_valid) { + return "scalar<" + value_->type->ToString() + ", null>()"; + } + std::string value; switch (value_->type->id()) { case Type::BOOL: diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 5232b98b21e..fdecce49d73 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -244,17 +244,19 @@ class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { std::string name_; }; -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - template \ - OperatorExpression FACTORY_NAME(const FieldReferenceExpression& lhs, T&& rhs) { \ - return OperatorExpression( \ - ExpressionType::NAME, \ - {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ - } \ - \ - template \ - OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ - return FACTORY_NAME(lhs, std::forward(rhs)); \ +#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ + ARROW_DS_EXPORT std::shared_ptr FACTORY_NAME( \ + const FieldReferenceExpression& lhs, const ScalarExpression& rhs) { \ + return std::make_shared( \ + ExpressionType::NAME, \ + std::vector>{lhs.Copy(), rhs.Copy()}); \ + } \ + \ + template \ + OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ + return OperatorExpression( \ + ExpressionType::NAME, \ + {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ } COMPARISON_FACTORY(EQUAL, equal, ==) COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 99675d30233..98e73b8897b 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -76,7 +76,7 @@ TEST_F(ExpressionsTest, Equality) { ASSERT_FALSE(("b"_ > 2 and "b"_ < 3).Equals("b"_ < 3 and "b"_ > 2)); } -TEST_F(ExpressionsTest, Simplification) { +TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { using namespace string_literals; // chained "and" expressions are flattened @@ -85,18 +85,36 @@ TEST_F(ExpressionsTest, Simplification) { AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); - AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); - AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 6, *never); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ == 3, *always); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); + AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); + AssertSimplifiesTo("a"_ == 3 or "b"_ == 4, "a"_ == 0, "b"_ == 4); } +TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { + using namespace string_literals; + + AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); + AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); + AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); +} + +TEST_F(ExpressionsTest, SimplificationToNull) { + using namespace string_literals; + + auto null = ScalarExpression::MakeNull(); + + AssertSimplifiesTo(*equal("b"_, *null), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal("b"_, *null), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal("b"_, *null) and "b"_ > 3, "b"_ == 3, *null); + AssertSimplifiesTo("b"_ > 3 and *not_equal("b"_, *null), "b"_ == 3, *null); +} + class FilterTest : public ::testing::Test { public: void AssertFilter(OperatorExpression expr, std::vector> fields, From 60d9e088f65bf4e23242f6aa66d61fdc1aad4aad Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 15 Aug 2019 11:10:20 -0400 Subject: [PATCH 09/28] fix factory linkage, factory fns deal in shared_ptrs exclusively --- cpp/src/arrow/dataset/filter.h | 80 ++++++++++++++++------------ cpp/src/arrow/dataset/filter_test.cc | 8 +-- cpp/src/arrow/dataset/type_fwd.h | 12 +++++ 3 files changed, 62 insertions(+), 38 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index fdecce49d73..650c54cad47 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -19,16 +19,12 @@ #include +#include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/result.h" #include "arrow/scalar.h" namespace arrow { - -namespace compute { -class FunctionContext; -} - namespace dataset { struct FilterType { @@ -54,10 +50,8 @@ class ARROW_DS_EXPORT Filter { FilterType::type type_; }; -class Expression; - -/// Filter subclass encapsulating a simple boolean predicate consisting of comparisons and -/// boolean logic (AND, OR, NOT) involving Schema fields +/// Filter subclass encapsulating a simple boolean predicate consisting of comparisons +/// and boolean logic (AND, OR, NOT) involving Schema fields class ARROW_DS_EXPORT ExpressionFilter : public Filter { public: explicit ExpressionFilter(const std::shared_ptr& expression) @@ -113,7 +107,8 @@ class ARROW_DS_EXPORT Expression { /// of this expression with schema information incorporated: /// - Scalars are cast to other data types if necessary to ensure comparisons are /// between data of identical type - // virtual Result> Validate(const Schema& schema) const = 0; + // virtual Result> Validate(const Schema& schema) const = + // 0; /// returns a debug string representing this expression virtual std::string ToString() const = 0; @@ -124,15 +119,15 @@ class ARROW_DS_EXPORT Expression { bool IsOperatorExpression() const; /// If true, this Expression may be safely cast to OperatorExpression - /// and there will be exactly two operands representing the left and right hand sides of - /// a comparison + /// and there will be exactly two operands representing the left and right hand sides + /// of a comparison bool IsComparisonExpression() const; /// If true, this Expression is a ScalarExpression wrapping a null Scalar. bool IsNullScalar() const; - /// If true, this Expression is a ScalarExpression wrapping a boolean Scalar. Its value - /// may be retrieved at the same time + /// If true, this Expression is a ScalarExpression wrapping a boolean Scalar. Its + /// value may be retrieved at the same time bool IsBooleanScalar(BooleanScalar* value = NULLPTR) const; /// Copy this expression into a shared pointer. @@ -142,8 +137,8 @@ class ARROW_DS_EXPORT Expression { ExpressionType::type type_; }; -/// Represents an compound expression; for example comparison between a field and a scalar -/// or a union of other expressions +/// Represents an compound expression; for example comparison between a field and a +/// scalar or a union of other expressions class ARROW_DS_EXPORT OperatorExpression final : public Expression { public: OperatorExpression(ExpressionType::type type, @@ -153,8 +148,8 @@ class ARROW_DS_EXPORT OperatorExpression final : public Expression { /// Return a simplified form of this expression given some known conditions. /// For example, (a > 3).Assume(a == 5) == (true). This can be used to do less work /// in ExpressionFilter when partition conditions guarantee some of this expression. - /// In the example above, *no* filtering need be done on record batches in the partition - /// since (a == 5). + /// In the example above, *no* filtering need be done on record batches in the + /// partition since (a == 5). Result> Assume(const Expression& given) const; const std::vector>& operands() const { return operands_; } @@ -171,14 +166,19 @@ class ARROW_DS_EXPORT OperatorExpression final : public Expression { ARROW_DS_EXPORT std::shared_ptr and_( std::vector> operands); + ARROW_DS_EXPORT OperatorExpression operator and(const OperatorExpression& lhs, const OperatorExpression& rhs); + ARROW_DS_EXPORT std::shared_ptr or_( std::vector> operands); + ARROW_DS_EXPORT OperatorExpression operator or(const OperatorExpression& lhs, const OperatorExpression& rhs); + ARROW_DS_EXPORT std::shared_ptr not_( std::shared_ptr operand); + ARROW_DS_EXPORT OperatorExpression operator not(const OperatorExpression& rhs); /// Represents a scalar value; thin wrapper around arrow::Scalar @@ -215,8 +215,12 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { static std::shared_ptr Make(const char* value); + static std::shared_ptr Make(std::shared_ptr value) { + return std::make_shared(std::move(value)); + } + static std::shared_ptr MakeNull() { - return std::make_shared(std::make_shared()); + return std::make_shared(std::make_shared()); } std::shared_ptr Copy() const override; @@ -225,6 +229,11 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { std::shared_ptr value_; }; +template +std::shared_ptr scalar(T&& value) { + return ScalarExpression::Make(std::forward(value)); +} + /// Represents a reference to a field. Stores only the field's name (type and other /// information is known only when a Schema is provided) class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { @@ -244,19 +253,23 @@ class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { std::string name_; }; -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - ARROW_DS_EXPORT std::shared_ptr FACTORY_NAME( \ - const FieldReferenceExpression& lhs, const ScalarExpression& rhs) { \ - return std::make_shared( \ - ExpressionType::NAME, \ - std::vector>{lhs.Copy(), rhs.Copy()}); \ - } \ - \ - template \ - OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ - return OperatorExpression( \ - ExpressionType::NAME, \ - {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ +inline std::shared_ptr fieldRef(std::string name) { + return std::make_shared(std::move(name)); +} + +#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ + inline std::shared_ptr FACTORY_NAME( \ + const std::shared_ptr& lhs, \ + const std::shared_ptr& rhs) { \ + return std::make_shared( \ + ExpressionType::NAME, std::vector>{lhs, rhs}); \ + } \ + \ + template \ + OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ + return OperatorExpression( \ + ExpressionType::NAME, \ + {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ } COMPARISON_FACTORY(EQUAL, equal, ==) COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) @@ -267,8 +280,7 @@ COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) #undef COMPARISON_FACTORY inline namespace string_literals { -ARROW_DS_EXPORT FieldReferenceExpression operator""_(const char* name, - size_t name_length) { +inline FieldReferenceExpression operator""_(const char* name, size_t name_length) { return FieldReferenceExpression({name, name_length}); } } // namespace string_literals diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 98e73b8897b..87739e93185 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -109,10 +109,10 @@ TEST_F(ExpressionsTest, SimplificationToNull) { auto null = ScalarExpression::MakeNull(); - AssertSimplifiesTo(*equal("b"_, *null), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal("b"_, *null), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal("b"_, *null) and "b"_ > 3, "b"_ == 3, *null); - AssertSimplifiesTo("b"_ > 3 and *not_equal("b"_, *null), "b"_ == 3, *null); + AssertSimplifiesTo(*equal(fieldRef("b"), null), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(fieldRef("b"), null), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(fieldRef("b"), null) and "b"_ > 3, "b"_ == 3, *null); + AssertSimplifiesTo("b"_ > 3 and *not_equal(fieldRef("b"), null), "b"_ == 3, *null); } class FilterTest : public ::testing::Test { diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 8e3824625ed..a7d79cd501b 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -25,6 +25,12 @@ namespace arrow { +namespace compute { + +class FunctionContext; + +} // namespace compute + namespace fs { class FileSystem; @@ -50,6 +56,12 @@ class FileWriteOptions; class Filter; using FilterVector = std::vector>; +class Expression; +class OperatorExpression; +class ScalarExpression; +class FieldReferenceExpression; +using ExpressionVector = std::vector>; + class Partition; class PartitionKey; class PartitionScheme; From 523186d880c5d33039f3f73770a87810dc30d42a Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 20 Aug 2019 21:14:57 -0400 Subject: [PATCH 10/28] add support for evaluation of trivial expressions, tests --- cpp/src/arrow/dataset/filter.cc | 66 ++++++++++++++------ cpp/src/arrow/dataset/filter.h | 86 +++++++++++++------------- cpp/src/arrow/dataset/filter_test.cc | 91 ++++++++++++++++++++++------ 3 files changed, 163 insertions(+), 80 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index a2327b67399..818b9b17be8 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -19,7 +19,9 @@ #include +#include "arrow/buffer-builder.h" #include "arrow/buffer.h" +#include "arrow/compute/context.h" #include "arrow/compute/kernels/boolean.h" #include "arrow/compute/kernels/compare.h" #include "arrow/record_batch.h" @@ -29,12 +31,37 @@ namespace arrow { namespace dataset { +using arrow::compute::Datum; using internal::checked_cast; +using internal::checked_pointer_cast; + +Status TrivialMask(MemoryPool* pool, const BooleanScalar& value, const RecordBatch& batch, + std::shared_ptr* mask) { + if (!value.is_valid) { + std::shared_ptr mask_array; + RETURN_NOT_OK(MakeArrayOfNull(boolean(), batch.num_rows(), &mask_array)); + + *mask = checked_pointer_cast(mask_array); + return Status::OK(); + } + + TypedBufferBuilder builder; + RETURN_NOT_OK(builder.Append(batch.num_rows(), value.value)); + + std::shared_ptr values; + RETURN_NOT_OK(builder.Finish(&values)); + + *mask = std::make_shared(batch.num_rows(), values); + return Status::OK(); +} Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condition, - const RecordBatch& batch, - std::shared_ptr* filter) { - using arrow::compute::Datum; + const RecordBatch& batch, std::shared_ptr* mask) { + BooleanScalar value; + if (condition.IsNullScalar() || condition.IsBooleanScalar(&value)) { + return TrivialMask(ctx->memory_pool(), value, batch, mask); + } + if (!condition.IsOperatorExpression()) { return Status::Invalid("can't execute condition ", condition.ToString()); } @@ -44,14 +71,19 @@ Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condi if (condition.IsComparisonExpression()) { const auto& lhs = checked_cast(*op[0]); const auto& rhs = checked_cast(*op[1]); + Datum out; + auto lhs_array = batch.GetColumnByName(lhs.name()); + if (lhs_array == nullptr) { + // comparing a field absent from batch: return nulls + return TrivialMask(ctx->memory_pool(), BooleanScalar(), batch, mask); + } using arrow::compute::CompareOperator; arrow::compute::CompareOptions opts(static_cast( static_cast(CompareOperator::EQUAL) + static_cast(condition.type()) - static_cast(ExpressionType::EQUAL))); - Datum out; - RETURN_NOT_OK(arrow::compute::Compare(ctx, Datum(batch.GetColumnByName(lhs.name())), - Datum(rhs.value()), opts, &out)); - *filter = internal::checked_pointer_cast(out.make_array()); + RETURN_NOT_OK( + arrow::compute::Compare(ctx, Datum(lhs_array), Datum(rhs.value()), opts, &out)); + *mask = checked_pointer_cast(out.make_array()); return Status::OK(); } @@ -60,7 +92,7 @@ Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condi RETURN_NOT_OK(EvaluateExpression(ctx, *op[0], batch, &to_invert)); Datum out; RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); - *filter = internal::checked_pointer_cast(out.make_array()); + *mask = checked_pointer_cast(out.make_array()); return Status::OK(); } @@ -81,7 +113,7 @@ Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condi } } - *filter = internal::checked_pointer_cast(acc.make_array()); + *mask = checked_pointer_cast(acc.make_array()); return Status::OK(); } @@ -190,7 +222,7 @@ Result> Invert(const Expression& op) { case ExpressionType::AND: case ExpressionType::OR: { - std::vector> operands; + ExpressionVector operands; for (auto operand : checked_cast(op).operands()) { ARROW_ASSIGN_OR_RAISE(auto inverted_operand, Invert(*operand)); operands.push_back(inverted_operand); @@ -489,7 +521,7 @@ Result> AssumeCompound(const OperatorExpression& e, bool trivial_condition = e.type() == ExpressionType::OR; bool simplify_to_trivial = false; - std::vector> operands; + ExpressionVector operands; for (auto operand : e.operands()) { ARROW_ASSIGN_OR_RAISE(operand, AssumeIfOperator(operand, given)); @@ -725,19 +757,17 @@ std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } -std::shared_ptr and_( - std::vector> operands) { +std::shared_ptr and_(ExpressionVector operands) { return std::make_shared(ExpressionType::AND, std::move(operands)); } -std::shared_ptr or_( - std::vector> operands) { +std::shared_ptr or_(ExpressionVector operands) { return std::make_shared(ExpressionType::OR, std::move(operands)); } std::shared_ptr not_(std::shared_ptr operand) { - return std::make_shared( - ExpressionType::NOT, std::vector>{std::move(operand)}); + return std::make_shared(ExpressionType::NOT, + ExpressionVector{std::move(operand)}); } // flatten chains of and/or to a single OperatorExpression @@ -746,7 +776,7 @@ OperatorExpression MaybeCombine(ExpressionType::type type, const OperatorExpress if (lhs.type() != type && rhs.type() != type) { return OperatorExpression(type, {lhs.Copy(), rhs.Copy()}); } - std::vector> operands; + ExpressionVector operands; if (lhs.type() == type) { operands = lhs.operands(); if (rhs.type() == type) { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 650c54cad47..e1e4d27e746 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -66,8 +66,7 @@ class ARROW_DS_EXPORT ExpressionFilter : public Filter { /// Evaluate an expression producing a boolean array which encodes whether each row /// satisfies the condition Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condition, - const RecordBatch& batch, - std::shared_ptr* filter); + const RecordBatch& batch, std::shared_ptr* mask); struct ExpressionType { enum type { @@ -137,12 +136,11 @@ class ARROW_DS_EXPORT Expression { ExpressionType::type type_; }; -/// Represents an compound expression; for example comparison between a field and a +/// Represents a compound expression; for example comparison between a field and a /// scalar or a union of other expressions class ARROW_DS_EXPORT OperatorExpression final : public Expression { public: - OperatorExpression(ExpressionType::type type, - std::vector> operands) + OperatorExpression(ExpressionType::type type, ExpressionVector operands) : Expression(type), operands_(std::move(operands)) {} /// Return a simplified form of this expression given some known conditions. @@ -152,7 +150,7 @@ class ARROW_DS_EXPORT OperatorExpression final : public Expression { /// partition since (a == 5). Result> Assume(const Expression& given) const; - const std::vector>& operands() const { return operands_; } + const ExpressionVector& operands() const { return operands_; } std::string ToString() const override; @@ -161,26 +159,9 @@ class ARROW_DS_EXPORT OperatorExpression final : public Expression { std::shared_ptr Copy() const override; private: - std::vector> operands_; + ExpressionVector operands_; }; -ARROW_DS_EXPORT std::shared_ptr and_( - std::vector> operands); - -ARROW_DS_EXPORT OperatorExpression operator and(const OperatorExpression& lhs, - const OperatorExpression& rhs); - -ARROW_DS_EXPORT std::shared_ptr or_( - std::vector> operands); - -ARROW_DS_EXPORT OperatorExpression operator or(const OperatorExpression& lhs, - const OperatorExpression& rhs); - -ARROW_DS_EXPORT std::shared_ptr not_( - std::shared_ptr operand); - -ARROW_DS_EXPORT OperatorExpression operator not(const OperatorExpression& rhs); - /// Represents a scalar value; thin wrapper around arrow::Scalar class ARROW_DS_EXPORT ScalarExpression final : public Expression { public: @@ -229,11 +210,6 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { std::shared_ptr value_; }; -template -std::shared_ptr scalar(T&& value) { - return ScalarExpression::Make(std::forward(value)); -} - /// Represents a reference to a field. Stores only the field's name (type and other /// information is known only when a Schema is provided) class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { @@ -253,23 +229,34 @@ class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { std::string name_; }; -inline std::shared_ptr fieldRef(std::string name) { - return std::make_shared(std::move(name)); -} +ARROW_DS_EXPORT std::shared_ptr and_(ExpressionVector operands); -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - inline std::shared_ptr FACTORY_NAME( \ - const std::shared_ptr& lhs, \ - const std::shared_ptr& rhs) { \ - return std::make_shared( \ - ExpressionType::NAME, std::vector>{lhs, rhs}); \ - } \ - \ - template \ - OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ - return OperatorExpression( \ - ExpressionType::NAME, \ - {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ +ARROW_DS_EXPORT OperatorExpression operator and(const OperatorExpression& lhs, + const OperatorExpression& rhs); + +ARROW_DS_EXPORT std::shared_ptr or_(ExpressionVector operands); + +ARROW_DS_EXPORT OperatorExpression operator or(const OperatorExpression& lhs, + const OperatorExpression& rhs); + +ARROW_DS_EXPORT std::shared_ptr not_( + std::shared_ptr operand); + +ARROW_DS_EXPORT OperatorExpression operator not(const OperatorExpression& rhs); + +#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ + inline std::shared_ptr FACTORY_NAME( \ + const std::shared_ptr& lhs, \ + const std::shared_ptr& rhs) { \ + return std::make_shared(ExpressionType::NAME, \ + ExpressionVector{lhs, rhs}); \ + } \ + \ + template \ + OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ + return OperatorExpression( \ + ExpressionType::NAME, \ + {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ } COMPARISON_FACTORY(EQUAL, equal, ==) COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) @@ -279,6 +266,15 @@ COMPARISON_FACTORY(LESS, less, <) COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) #undef COMPARISON_FACTORY +template +auto scalar(T&& value) -> decltype(ScalarExpression::Make(std::forward(value))) { + return ScalarExpression::Make(std::forward(value)); +} + +inline std::shared_ptr fieldRef(std::string name) { + return std::make_shared(std::move(name)); +} + inline namespace string_literals { inline FieldReferenceExpression operator""_(const char* name, size_t name_length) { return FieldReferenceExpression({name, name_length}); diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 87739e93185..0f15313f6c3 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -32,6 +32,9 @@ namespace arrow { namespace dataset { +using namespace string_literals; +using internal::checked_pointer_cast; + class ExpressionsTest : public ::testing::Test { public: void AssertSimplifiesTo(OperatorExpression expr, const Expression& given, @@ -63,8 +66,6 @@ class ExpressionsTest : public ::testing::Test { }; TEST_F(ExpressionsTest, Equality) { - using namespace string_literals; - ASSERT_TRUE("a"_.Equals("a"_)); ASSERT_FALSE("a"_.Equals("b"_)); @@ -77,8 +78,6 @@ TEST_F(ExpressionsTest, Equality) { } TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { - using namespace string_literals; - // chained "and" expressions are flattened auto multi_and = "b"_ > 5 and "b"_ < 10 and "b"_ != 7; AssertOperandsAre(multi_and, ExpressionType::AND, "b"_ > 5, "b"_ < 10, "b"_ != 7); @@ -97,16 +96,12 @@ TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { } TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { - using namespace string_literals; - AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); } TEST_F(ExpressionsTest, SimplificationToNull) { - using namespace string_literals; - auto null = ScalarExpression::MakeNull(); AssertSimplifiesTo(*equal(fieldRef("b"), null), "b"_ == 3, *null); @@ -117,29 +112,66 @@ TEST_F(ExpressionsTest, SimplificationToNull) { class FilterTest : public ::testing::Test { public: - void AssertFilter(OperatorExpression expr, std::vector> fields, - std::string batch_json) { + Status DoFilter(const Expression& expr, std::vector> fields, + std::string batch_json, std::shared_ptr* mask, + std::shared_ptr* expected_mask) { // expected filter result is in the "in" field fields.push_back(field("in", boolean())); auto batch_array = ArrayFromJSON(struct_(std::move(fields)), std::move(batch_json)); std::shared_ptr batch; - ASSERT_OK(RecordBatch::FromStructArray(batch_array, &batch)); + RETURN_NOT_OK(RecordBatch::FromStructArray(batch_array, &batch)); - auto expected_filter = batch->GetColumnByName("in"); + *expected_mask = checked_pointer_cast(batch->GetColumnByName("in")); - std::shared_ptr filter; - ASSERT_OK(EvaluateExpression(&ctx_, expr, *batch, &filter)); + return EvaluateExpression(&ctx_, expr, *batch, mask); + } - ASSERT_ARRAYS_EQUAL(*expected_filter, *filter); + void AssertFilter(const Expression& expr, std::vector> fields, + std::string batch_json) { + std::shared_ptr mask, expected_mask; + ASSERT_OK( + DoFilter(expr, std::move(fields), std::move(batch_json), &mask, &expected_mask)); + ASSERT_ARRAYS_EQUAL(*expected_mask, *mask); } arrow::compute::FunctionContext ctx_; }; -TEST_F(FilterTest, Basics) { - using namespace string_literals; +TEST_F(FilterTest, Trivial) { + AssertFilter(*scalar(true), {field("a", int64()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": 1}, + {"a": 0, "b": 0.3, "in": 1}, + {"a": 1, "b": 0.2, "in": 1}, + {"a": 2, "b": -0.1, "in": 1}, + {"a": 0, "b": 0.1, "in": 1}, + {"a": 0, "b": null, "in": 1}, + {"a": 0, "b": 1.0, "in": 1} + ])"); + + AssertFilter(*scalar(false), {field("a", int64()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.3, "in": 0}, + {"a": 1, "b": 0.2, "in": 0}, + {"a": 2, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.1, "in": 0}, + {"a": 0, "b": null, "in": 0}, + {"a": 0, "b": 1.0, "in": 0} + ])"); + + AssertFilter(*ScalarExpression::MakeNull(), + {field("a", int64()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": null}, + {"a": 0, "b": 0.3, "in": null}, + {"a": 1, "b": 0.2, "in": null}, + {"a": 2, "b": -0.1, "in": null}, + {"a": 0, "b": 0.1, "in": null}, + {"a": 0, "b": null, "in": null}, + {"a": 0, "b": 1.0, "in": null} + ])"); +} +TEST_F(FilterTest, Basics) { AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0, {field("a", int64()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": 0}, @@ -147,9 +179,34 @@ TEST_F(FilterTest, Basics) { {"a": 1, "b": 0.2, "in": 0}, {"a": 2, "b": -0.1, "in": 0}, {"a": 0, "b": 0.1, "in": 1}, + {"a": 0, "b": null, "in": null}, + {"a": 0, "b": 1.0, "in": 0} + ])"); + + AssertFilter("a"_ != 0 and "b"_ > 0.1, {field("a", int64()), field("b", float64())}, + R"([ + {"a": 0, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.3, "in": 0}, + {"a": 1, "b": 0.2, "in": 1}, + {"a": 2, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.1, "in": 0}, + {"a": 0, "b": null, "in": null}, {"a": 0, "b": 1.0, "in": 0} ])"); } +TEST_F(FilterTest, ConditionOnAbsentColumn) { + AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0 and "absent"_ == 0, + {field("a", int64()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": null}, + {"a": 0, "b": 0.3, "in": null}, + {"a": 1, "b": 0.2, "in": null}, + {"a": 2, "b": -0.1, "in": null}, + {"a": 0, "b": 0.1, "in": null}, + {"a": 0, "b": null, "in": null}, + {"a": 0, "b": 1.0, "in": null} + ])"); +} + } // namespace dataset } // namespace arrow From 02d94a6f1f3bc9f2ebf6f29855105ac3d71bcf65 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 20 Aug 2019 21:30:57 -0400 Subject: [PATCH 11/28] use explicit enumeration of comparison results --- cpp/src/arrow/dataset/filter.cc | 87 +++++++++++++++++---------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 818b9b17be8..4ec72f28ec0 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -127,55 +127,43 @@ std::shared_ptr ScalarExpression::Make(const char* value) { std::make_shared(Buffer::Wrap(value, std::strlen(value)))); } -struct CompareVisitor { - Status Visit(const NullType&) { - result_ = false; - return Status::OK(); - } - - Status Visit(const BooleanType&) { - result_ = checked_cast(lhs_).value - - checked_cast(rhs_).value; - return Status::OK(); - } +struct Comparison { + enum type { + LESS, + EQUAL, + GREATER, + NULL_, + }; +}; +struct CompareVisitor { template using ScalarType = typename TypeTraits::ScalarType; - template - typename std::enable_if::value, Status>::type - Visit(const T&) { - result_ = static_cast(checked_cast&>(lhs_).value - - checked_cast&>(rhs_).value); + Status Visit(const NullType&) { + result_ = Comparison::NULL_; return Status::OK(); } + Status Visit(const BooleanType&) { return CompareValues(); } + template - enable_if_floating_point Visit(const T&) { - auto delta = checked_cast&>(lhs_).value - - checked_cast&>(rhs_).value; - constexpr decltype(delta) zero = 0; - result_ = delta < zero ? -1 : delta > zero ? +1 : 0; - return Status::OK(); + enable_if_number Visit(const T&) { + return CompareValues(); } template enable_if_binary_like Visit(const T&) { auto lhs = checked_cast&>(lhs_).value; auto rhs = checked_cast&>(rhs_).value; - result_ = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (result_ == 0) { - result_ = lhs->size() - rhs->size(); + auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); + if (cmp == 0) { + return CompareValues(lhs->size(), rhs->size()); } - return Status::OK(); + return CompareValues(cmp, 0); } - Status Visit(const Decimal128Type&) { - auto lhs = checked_cast(lhs_).value; - auto rhs = checked_cast(rhs_).value; - result_ = (lhs - rhs).Sign(); - return Status::OK(); - } + Status Visit(const Decimal128Type&) { return CompareValues(); } // explicit because both integral and floating point conditions match half float Status Visit(const HalfFloatType&) { @@ -187,27 +175,40 @@ struct CompareVisitor { return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); } - int64_t result_; + // defer comparison to ScalarType::value + template + Status CompareValues() { + auto lhs = checked_cast&>(lhs_).value; + auto rhs = checked_cast&>(rhs_).value; + return CompareValues(lhs, rhs); + } + + // defer comparison to explicit values + template + Status CompareValues(Value lhs, Value rhs) { + result_ = lhs < rhs ? Comparison::LESS + : lhs == rhs ? Comparison::EQUAL : Comparison::GREATER; + return Status::OK(); + } + + Comparison::type result_; const Scalar& lhs_; const Scalar& rhs_; }; // Compare two scalars -// lhs < rhs => return < 0 -// lhs == rhs => return == 0 -// lhs > rhs => return > 0 // if either is null, return is null -Result Compare(const Scalar& lhs, const Scalar& rhs) { +Result Compare(const Scalar& lhs, const Scalar& rhs) { if (!lhs.is_valid || !rhs.is_valid) { - return Int64Scalar(); + return Comparison::NULL_; } if (!lhs.type->Equals(*rhs.type)) { return Status::TypeError("cannot compare scalars with differing type: ", *lhs.type, " vs ", *rhs.type); } - CompareVisitor vis{0, lhs, rhs}; + CompareVisitor vis{Comparison::NULL_, lhs, rhs}; RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); - return Int64Scalar(vis.result_); + return vis.result_; } Result> Invert(const Expression& op) { @@ -286,7 +287,7 @@ Result> AssumeComparisonComparison( ARROW_ASSIGN_OR_RAISE(auto cmp, Compare(*e_rhs, *given_rhs)); - if (!cmp.is_valid) { + if (cmp == Comparison::NULL_) { // the RHS of e or given was null return ScalarExpression::MakeNull(); } @@ -295,7 +296,7 @@ Result> AssumeComparisonComparison( static auto never = ScalarExpression::Make(false); auto unsimplified = e.Copy(); - if (cmp.value == 0) { + if (cmp == Comparison::EQUAL) { // the rhs of the comparisons are equal switch (e.type()) { case ExpressionType::EQUAL: @@ -367,7 +368,7 @@ Result> AssumeComparisonComparison( default: return unsimplified; } - } else if (cmp.value > 0) { + } else if (cmp == Comparison::GREATER) { // the rhs of e is greater than that of given switch (e.type()) { case ExpressionType::EQUAL: From 58990674f0e2e99b43e43e3fc01aa67cb008fb14 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 21 Aug 2019 12:10:10 -0400 Subject: [PATCH 12/28] break OperatorExpression into multiple classes --- cpp/src/arrow/dataset/filter.cc | 787 ++++++++++++++------------- cpp/src/arrow/dataset/filter.h | 254 ++++++--- cpp/src/arrow/dataset/filter_test.cc | 27 +- cpp/src/arrow/dataset/type_fwd.h | 5 +- 4 files changed, 611 insertions(+), 462 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 4ec72f28ec0..5cf10f1f562 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -35,86 +35,95 @@ using arrow::compute::Datum; using internal::checked_cast; using internal::checked_pointer_cast; -Status TrivialMask(MemoryPool* pool, const BooleanScalar& value, const RecordBatch& batch, - std::shared_ptr* mask) { - if (!value.is_valid) { +Result> ScalarExpression::Evaluate( + compute::FunctionContext* ctx, const RecordBatch& batch) const { + if (!value_->is_valid) { std::shared_ptr mask_array; - RETURN_NOT_OK(MakeArrayOfNull(boolean(), batch.num_rows(), &mask_array)); + RETURN_NOT_OK( + MakeArrayOfNull(ctx->memory_pool(), boolean(), batch.num_rows(), &mask_array)); - *mask = checked_pointer_cast(mask_array); - return Status::OK(); + return checked_pointer_cast(mask_array); + } + + if (!value_->type->Equals(boolean())) { + return Status::Invalid("can't evaluate ", ToString()); } TypedBufferBuilder builder; - RETURN_NOT_OK(builder.Append(batch.num_rows(), value.value)); + RETURN_NOT_OK(builder.Append(batch.num_rows(), + checked_cast(*value_).value)); std::shared_ptr values; RETURN_NOT_OK(builder.Finish(&values)); - *mask = std::make_shared(batch.num_rows(), values); - return Status::OK(); + return std::make_shared(batch.num_rows(), values); } -Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condition, - const RecordBatch& batch, std::shared_ptr* mask) { - BooleanScalar value; - if (condition.IsNullScalar() || condition.IsBooleanScalar(&value)) { - return TrivialMask(ctx->memory_pool(), value, batch, mask); - } +Result> NotExpression::Evaluate( + compute::FunctionContext* ctx, const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto to_invert, operand_->Evaluate(ctx, batch)); + Datum out; + RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); + return checked_pointer_cast(out.make_array()); +} - if (!condition.IsOperatorExpression()) { - return Status::Invalid("can't execute condition ", condition.ToString()); - } +template +Result> EvaluateNnary(const Nnary& nnary, + compute::FunctionContext* ctx, + const RecordBatch& batch) { + const auto& operands = nnary.operands(); - const auto& op = checked_cast(condition).operands(); + ARROW_ASSIGN_OR_RAISE(auto next, operands[0]->Evaluate(ctx, batch)); + Datum acc(next); + + for (size_t i_next = 1; i_next < operands.size(); ++i_next) { + ARROW_ASSIGN_OR_RAISE(next, operands[i_next]->Evaluate(ctx, batch)); - if (condition.IsComparisonExpression()) { - const auto& lhs = checked_cast(*op[0]); - const auto& rhs = checked_cast(*op[1]); - Datum out; - auto lhs_array = batch.GetColumnByName(lhs.name()); - if (lhs_array == nullptr) { - // comparing a field absent from batch: return nulls - return TrivialMask(ctx->memory_pool(), BooleanScalar(), batch, mask); + if (std::is_same::value) { + RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); } - using arrow::compute::CompareOperator; - arrow::compute::CompareOptions opts(static_cast( - static_cast(CompareOperator::EQUAL) + static_cast(condition.type()) - - static_cast(ExpressionType::EQUAL))); - RETURN_NOT_OK( - arrow::compute::Compare(ctx, Datum(lhs_array), Datum(rhs.value()), opts, &out)); - *mask = checked_pointer_cast(out.make_array()); - return Status::OK(); - } - if (condition.type() == ExpressionType::NOT) { - std::shared_ptr to_invert; - RETURN_NOT_OK(EvaluateExpression(ctx, *op[0], batch, &to_invert)); - Datum out; - RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); - *mask = checked_pointer_cast(out.make_array()); - return Status::OK(); + if (std::is_same::value) { + RETURN_NOT_OK(arrow::compute::Or(ctx, Datum(acc), Datum(next), &acc)); + } } - DCHECK(condition.type() == ExpressionType::OR || - condition.type() == ExpressionType::AND); + return checked_pointer_cast(acc.make_array()); +} - std::shared_ptr next; - RETURN_NOT_OK(EvaluateExpression(ctx, *op[0], batch, &next)); - Datum acc(next); +Result> AndExpression::Evaluate( + compute::FunctionContext* ctx, const RecordBatch& batch) const { + return EvaluateNnary(*this, ctx, batch); +} - for (size_t i_next = 1; i_next < op.size(); ++i_next) { - RETURN_NOT_OK(EvaluateExpression(ctx, *op[i_next], batch, &next)); +Result> OrExpression::Evaluate( + compute::FunctionContext* ctx, const RecordBatch& batch) const { + return EvaluateNnary(*this, ctx, batch); +} - if (condition.type() == ExpressionType::OR) { - RETURN_NOT_OK(arrow::compute::Or(ctx, Datum(acc), Datum(next), &acc)); - } else { - RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); - } +Result> ComparisonExpression::Evaluate( + compute::FunctionContext* ctx, const RecordBatch& batch) const { + if (left_operand_->type() != ExpressionType::FIELD) { + return Status::Invalid("left hand side of comparison must be a field reference"); + } + + if (right_operand_->type() != ExpressionType::SCALAR) { + return Status::Invalid("right hand side of comparison must be a scalar"); + } + + const auto& lhs = checked_cast(*left_operand_); + const auto& rhs = checked_cast(*right_operand_); + + auto lhs_array = batch.GetColumnByName(lhs.name()); + if (lhs_array == nullptr) { + // comparing a field absent from batch: return nulls + return ScalarExpression::MakeNull()->Evaluate(ctx, batch); } - *mask = checked_pointer_cast(acc.make_array()); - return Status::OK(); + Datum out; + RETURN_NOT_OK(arrow::compute::Compare(ctx, Datum(lhs_array), Datum(rhs.value()), + arrow::compute::CompareOptions(op_), &out)); + return checked_pointer_cast(out.make_array()); } std::shared_ptr ScalarExpression::Make(std::string value) { @@ -211,6 +220,7 @@ Result Compare(const Scalar& lhs, const Scalar& rhs) { return vis.result_; } +/* Result> Invert(const Expression& op) { auto make_opposite = [&op](ExpressionType::type opposite_type) { return std::make_shared( @@ -258,9 +268,11 @@ Result> Invert(const Expression& op) { return op.Copy(); } +*/ // If e can be cast to OperatorExpression try to simplify it against given. // Otherwise pass e through unchanged. +/* Result> AssumeIfOperator(const std::shared_ptr& e, const Expression& given) { if (!e->IsOperatorExpression()) { @@ -268,24 +280,106 @@ Result> AssumeIfOperator(const std::shared_ptr(*e).Assume(given); } +*/ + +Result> ComparisonExpression::Assume( + const Expression& given) const { + switch (given.type()) { + case ExpressionType::COMPARISON: { + return AssumeGivenComparison(checked_cast(given)); + } + + case ExpressionType::NOT: { + // const auto& to_invert = checked_cast(given).operand(); + // ARROW_ASSIGN_OR_RAISE(auto inverted, Invert(*to_invert)); + // return Assume(*inverted); + return Copy(); + } + + case ExpressionType::OR: { + bool simplify_to_always = true; + bool simplify_to_never = true; + for (const auto& operand : checked_cast(given).operands()) { + ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); + + BooleanScalar scalar; + if (!simplified->IsTrivialCondition(&scalar)) { + simplify_to_never = false; + simplify_to_always = false; + } + + if (!scalar.is_valid) { + // some subexpression of given is always null, return null + return ScalarExpression::MakeNull(); + } + + if (scalar.value == true) { + simplify_to_never = false; + } else { + simplify_to_always = false; + } + } + + if (simplify_to_always) { + return ScalarExpression::Make(true); + } + + if (simplify_to_never) { + return ScalarExpression::Make(false); + } + + return Copy(); + } + + case ExpressionType::AND: { + auto simplified = Copy(); + for (const auto& operand : checked_cast(given).operands()) { + BooleanScalar value; + if (simplified->IsTrivialCondition(&value)) { + // FIXME(bkietz) but what if something later is null? + break; + } + + ARROW_ASSIGN_OR_RAISE(simplified, simplified->Assume(*operand)); + } + return simplified; + } + + default: + break; + } + + return Copy(); +} // Try to simplify one comparison against another comparison. // For example, // (x > 3) is a subset of (x > 2), so (x > 2).Assume(x > 3) == (true) // (x < 0) is disjoint with (x > 2), so (x > 2).Assume(x < 0) == (false) // If simplification to (true) or (false) is not possible, pass e through unchanged. -Result> AssumeComparisonComparison( - const OperatorExpression& e, const OperatorExpression& given) { - if (e.operands()[1]->type() != ExpressionType::SCALAR || - given.operands()[1]->type() != ExpressionType::SCALAR) { - // TODO(bkietz) allow the RHS of e to be FIELD - return Status::Invalid("right hand side of comparison must be a scalar"); +Result> ComparisonExpression::AssumeGivenComparison( + const ComparisonExpression& given) const { + for (auto comparison : {this, &given}) { + if (comparison->left_operand_->type() != ExpressionType::FIELD) { + return Status::Invalid("left hand side of comparison must be a field reference"); + } + + if (comparison->right_operand_->type() != ExpressionType::SCALAR) { + return Status::Invalid("right hand side of comparison must be a scalar"); + } } - auto e_rhs = checked_cast(*e.operands()[1]).value(); - auto given_rhs = checked_cast(*given.operands()[1]).value(); + const auto& this_lhs = checked_cast(*left_operand_); + const auto& given_lhs = + checked_cast(*given.left_operand_); + if (this_lhs.name() != given_lhs.name()) { + return Copy(); + } - ARROW_ASSIGN_OR_RAISE(auto cmp, Compare(*e_rhs, *given_rhs)); + const auto& this_rhs = checked_cast(*right_operand_).value(); + const auto& given_rhs = + checked_cast(*given.right_operand_).value(); + ARROW_ASSIGN_OR_RAISE(auto cmp, Compare(*this_rhs, *given_rhs)); if (cmp == Comparison::NULL_) { // the RHS of e or given was null @@ -294,255 +388,170 @@ Result> AssumeComparisonComparison( static auto always = ScalarExpression::Make(true); static auto never = ScalarExpression::Make(false); - auto unsimplified = e.Copy(); + + using compute::CompareOperator; if (cmp == Comparison::EQUAL) { // the rhs of the comparisons are equal - switch (e.type()) { - case ExpressionType::EQUAL: - switch (given.type()) { - case ExpressionType::NOT_EQUAL: - case ExpressionType::GREATER: - case ExpressionType::LESS: + switch (op_) { + case CompareOperator::EQUAL: + switch (given.op()) { + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: return never; - case ExpressionType::EQUAL: + case CompareOperator::EQUAL: return always; default: - return unsimplified; + return Copy(); } - case ExpressionType::NOT_EQUAL: - switch (given.type()) { - case ExpressionType::EQUAL: + case CompareOperator::NOT_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: return never; - case ExpressionType::NOT_EQUAL: - case ExpressionType::GREATER: - case ExpressionType::LESS: + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: return always; default: - return unsimplified; + return Copy(); } - case ExpressionType::GREATER: - switch (given.type()) { - case ExpressionType::EQUAL: - case ExpressionType::LESS_EQUAL: - case ExpressionType::LESS: + case CompareOperator::GREATER: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS_EQUAL: + case CompareOperator::LESS: return never; - case ExpressionType::GREATER: + case CompareOperator::GREATER: return always; default: - return unsimplified; + return Copy(); } - case ExpressionType::GREATER_EQUAL: - switch (given.type()) { - case ExpressionType::LESS: + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::LESS: return never; - case ExpressionType::EQUAL: - case ExpressionType::GREATER: - case ExpressionType::GREATER_EQUAL: + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: return always; default: - return unsimplified; + return Copy(); } - case ExpressionType::LESS: - switch (given.type()) { - case ExpressionType::EQUAL: - case ExpressionType::GREATER: - case ExpressionType::GREATER_EQUAL: + case CompareOperator::LESS: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: return never; - case ExpressionType::LESS: + case CompareOperator::LESS: return always; default: - return unsimplified; + return Copy(); } - case ExpressionType::LESS_EQUAL: - switch (given.type()) { - case ExpressionType::GREATER: + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::GREATER: return never; - case ExpressionType::EQUAL: - case ExpressionType::LESS: - case ExpressionType::LESS_EQUAL: + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: return always; default: - return unsimplified; + return Copy(); } default: - return unsimplified; + return Copy(); } } else if (cmp == Comparison::GREATER) { // the rhs of e is greater than that of given - switch (e.type()) { - case ExpressionType::EQUAL: - case ExpressionType::GREATER: - case ExpressionType::GREATER_EQUAL: - switch (given.type()) { - case ExpressionType::EQUAL: - case ExpressionType::LESS: - case ExpressionType::LESS_EQUAL: + switch (op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: return never; default: - return unsimplified; + return Copy(); } - case ExpressionType::NOT_EQUAL: - case ExpressionType::LESS: - case ExpressionType::LESS_EQUAL: - switch (given.type()) { - case ExpressionType::EQUAL: - case ExpressionType::LESS: - case ExpressionType::LESS_EQUAL: + case CompareOperator::NOT_EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: return always; default: - return unsimplified; + return Copy(); } default: - return unsimplified; + return Copy(); } } else { // the rhs of e is less than that of given - switch (e.type()) { - case ExpressionType::EQUAL: - case ExpressionType::LESS: - case ExpressionType::LESS_EQUAL: - switch (given.type()) { - case ExpressionType::EQUAL: - case ExpressionType::GREATER: - case ExpressionType::GREATER_EQUAL: + switch (op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: return never; default: - return unsimplified; + return Copy(); } - case ExpressionType::NOT_EQUAL: - case ExpressionType::GREATER: - case ExpressionType::GREATER_EQUAL: - switch (given.type()) { - case ExpressionType::EQUAL: - case ExpressionType::GREATER: - case ExpressionType::GREATER_EQUAL: + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: return always; default: - return unsimplified; + return Copy(); } default: - return unsimplified; - } - } - - return unsimplified; -} - -// Try to simplify a comparison against a compound expression. -// The operands of the compound expression must be examined individually. -Result> AssumeComparisonCompound( - const OperatorExpression& e, const OperatorExpression& given) { - auto unsimplified = e.Copy(); - - switch (given.type()) { - case ExpressionType::NOT: { - ARROW_ASSIGN_OR_RAISE(auto inverted, Invert(*given.operands()[0])); - return e.Assume(*inverted); - } - - case ExpressionType::OR: { - bool simplify_to_always = true; - bool simplify_to_never = true; - for (auto operand : given.operands()) { - ARROW_ASSIGN_OR_RAISE(auto simplified, e.Assume(*operand)); - BooleanScalar scalar; - if (!simplified->IsBooleanScalar(&scalar)) { - return unsimplified; - } - - // an expression should never simplify to null - DCHECK(scalar.is_valid); - - if (scalar.value == true) { - simplify_to_never = false; - } else { - simplify_to_always = false; - } - } - - if (simplify_to_always) { - return ScalarExpression::Make(true); - } - - if (simplify_to_never) { - return ScalarExpression::Make(false); - } - - return unsimplified; - } - - case ExpressionType::AND: { - std::shared_ptr simplified = unsimplified; - for (auto operand : given.operands()) { - if (simplified->IsBooleanScalar()) { - break; - } - - DCHECK(simplified->IsOperatorExpression()); - const auto& simplified_op = checked_cast(*simplified); - ARROW_ASSIGN_OR_RAISE(simplified, simplified_op.Assume(*operand)); - } - return simplified; + return Copy(); } - - default: - DCHECK(false); } - return unsimplified; + return Copy(); } -Result> AssumeCompound(const OperatorExpression& e, - const Expression& given) { - auto unsimplified = e.Copy(); - - if (e.type() == ExpressionType::NOT) { - DCHECK_EQ(e.operands().size(), 1); - ARROW_ASSIGN_OR_RAISE(auto operand, AssumeIfOperator(e.operands()[0], given)); - - if (operand->IsNullScalar()) { - return operand; - } - - BooleanScalar scalar; - if (!operand->IsBooleanScalar(&scalar)) { - return unsimplified; - } - - return ScalarExpression::Make(!scalar.value); - } - - DCHECK(e.type() == ExpressionType::OR || e.type() == ExpressionType::AND); - +template +Result> AssumeNnary(const Nnary& nnary, + const Expression& given) { // if any of the operands matches trivial_condition, we can return a trivial // expression: // anything OR true => true // anything AND false => false - bool trivial_condition = e.type() == ExpressionType::OR; + constexpr bool trivial_condition = std::is_same::value; bool simplify_to_trivial = false; ExpressionVector operands; - for (auto operand : e.operands()) { - ARROW_ASSIGN_OR_RAISE(operand, AssumeIfOperator(operand, given)); + for (auto operand : nnary.operands()) { + ARROW_ASSIGN_OR_RAISE(operand, operand->Assume(given)); - if (operand->IsNullScalar()) { - return operand; - } + BooleanScalar scalar; + if (operand->IsTrivialCondition(&scalar)) { + if (!scalar.is_valid) { + return ScalarExpression::MakeNull(); + } - if (simplify_to_trivial) { + if (scalar.value == trivial_condition) { + simplify_to_trivial = true; + } continue; } - BooleanScalar scalar; - if (!operand->IsBooleanScalar(&scalar)) { + if (!simplify_to_trivial) { operands.push_back(operand); - continue; - } - - if (scalar.value == trivial_condition) { - simplify_to_trivial = true; - continue; } } @@ -558,65 +567,50 @@ Result> AssumeCompound(const OperatorExpression& e, return ScalarExpression::Make(!trivial_condition); } - return std::make_shared(e.type(), std::move(operands)); + return std::make_shared(std::move(operands)); } -Result> OperatorExpression::Assume( - const Expression& given) const { - auto unsimplified = Copy(); +Result> AndExpression::Assume(const Expression& given) const { + return AssumeNnary(*this, given); +} - if (IsComparisonExpression()) { - if (!given.IsOperatorExpression()) { - return unsimplified; - } +Result> OrExpression::Assume(const Expression& given) const { + return AssumeNnary(*this, given); +} - const auto& given_op = checked_cast(given); - - if (given.IsComparisonExpression()) { - // Both this and given are simple comparisons. If they constrain - // the same field, try to simplify this assuming given - DCHECK_EQ(operands_.size(), 2); - DCHECK_EQ(given_op.operands_.size(), 2); - auto get_name = [](const Expression& e) { - DCHECK_EQ(e.type(), ExpressionType::FIELD); - return checked_cast(e).name(); - }; - if (get_name(*operands_[0]) != get_name(*given_op.operands_[0])) { - return unsimplified; - } - return AssumeComparisonComparison(*this, given_op); - } +Result> NotExpression::Assume(const Expression& given) const { + ARROW_ASSIGN_OR_RAISE(auto operand, operand_->Assume(given)); - // must be NOT, AND, OR- decompose given - return AssumeComparisonCompound(*this, given_op); + BooleanScalar scalar; + if (operand->IsTrivialCondition(&scalar)) { + return Copy(); } - return AssumeCompound(*this, given); + if (!scalar.is_valid) { + return ScalarExpression::MakeNull(); + } + + return ScalarExpression::Make(!scalar.value); } std::string FieldReferenceExpression::ToString() const { return std::string("field(") + name_ + ")"; } -std::string OperatorName(ExpressionType::type type) { - switch (type) { - case ExpressionType::AND: - return "AND"; - case ExpressionType::OR: - return "OR"; - case ExpressionType::NOT: - return "NOT"; - case ExpressionType::EQUAL: +std::string OperatorName(compute::CompareOperator op) { + using compute::CompareOperator; + switch (op) { + case CompareOperator::EQUAL: return "EQUAL"; - case ExpressionType::NOT_EQUAL: + case CompareOperator::NOT_EQUAL: return "NOT_EQUAL"; - case ExpressionType::LESS: + case CompareOperator::LESS: return "LESS"; - case ExpressionType::LESS_EQUAL: + case CompareOperator::LESS_EQUAL: return "LESS_EQUAL"; - case ExpressionType::GREATER: + case CompareOperator::GREATER: return "GREATER"; - case ExpressionType::GREATER_EQUAL: + case CompareOperator::GREATER_EQUAL: return "GREATER_EQUAL"; default: DCHECK(false); @@ -624,21 +618,6 @@ std::string OperatorName(ExpressionType::type type) { return ""; } -std::string OperatorExpression::ToString() const { - auto out = OperatorName(type_) + "("; - bool comma = false; - for (const auto& operand : operands_) { - if (comma) { - out += ", "; - } else { - comma = true; - } - out += operand->ToString(); - } - out += ")"; - return out; -} - std::string ScalarExpression::ToString() const { if (!value_->is_valid) { return "scalar<" + value_->type->ToString() + ", null>()"; @@ -666,90 +645,128 @@ std::string ScalarExpression::ToString() const { return "scalar<" + value_->type->ToString() + ">(" + value + ")"; } -bool Expression::Equals(const Expression& other) const { - if (type_ != other.type()) { - return false; +static std::string EulerNotation(std::string fn, const ExpressionVector& operands) { + fn += "("; + bool comma = false; + for (const auto& operand : operands) { + if (comma) { + fn += ", "; + } else { + comma = true; + } + fn += operand->ToString(); } + fn += ")"; + return fn; +} - switch (type_) { - case ExpressionType::FIELD: - return checked_cast(*this).name() == - checked_cast(other).name(); +std::string AndExpression::ToString() const { return EulerNotation("AND", operands_); } - case ExpressionType::SCALAR: { - auto this_value = checked_cast(*this).value(); - auto other_value = checked_cast(other).value(); - return this_value->Equals(other_value); - } +std::string OrExpression::ToString() const { return EulerNotation("OR", operands_); } - default: { - DCHECK(IsOperatorExpression()); - const auto& this_op = checked_cast(*this).operands(); - const auto& other_op = checked_cast(other).operands(); - if (this_op.size() != other_op.size()) { - return false; - } +std::string ComparisonExpression::ToString() const { + return EulerNotation(OperatorName(op()), {left_operand_, right_operand_}); +} - for (size_t i = 0; i < this_op.size(); ++i) { - if (!this_op[i]->Equals(*other_op[i])) { - return false; - } - } +bool UnaryExpression::OperandsEqual(const UnaryExpression& other) const { + return operand_->Equals(other.operand_); +} - return true; +bool BinaryExpression::OperandsEqual(const BinaryExpression& other) const { + return left_operand_->Equals(other.left_operand_) && + right_operand_->Equals(other.right_operand_); +} + +bool NnaryExpression::OperandsEqual(const NnaryExpression& other) const { + if (operands_.size() != other.operands_.size()) { + return false; + } + for (size_t i = 0; i < operands_.size(); ++i) { + if (!operands_[i]->Equals(other.operands_[i])) { + return false; } } - return true; } -bool Expression::Equals(const std::shared_ptr& other) const { - if (other == NULLPTR) { - return false; +struct ExpressionEqual { + Status Visit(const FieldReferenceExpression& rhs) { + result_ = checked_cast(lhs_).name() == rhs.name(); + return Status::OK(); } - return Equals(*other); -} -bool Expression::IsOperatorExpression() const { - return static_cast(type_) >= static_cast(ExpressionType::NOT) && - static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); -} + Status Visit(const ScalarExpression& rhs) { + result_ = checked_cast(lhs_).value()->Equals(rhs.value()); + return Status::OK(); + } -bool Expression::IsComparisonExpression() const { - return static_cast(type_) >= static_cast(ExpressionType::EQUAL) && - static_cast(type_) <= static_cast(ExpressionType::LESS_EQUAL); -} + Status Visit(const ComparisonExpression& rhs) { + const auto& lhs = checked_cast(lhs_); + result_ = lhs.op() == rhs.op() && lhs.OperandsEqual(rhs); + return Status::OK(); + } -bool Expression::IsNullScalar() const { - if (type_ != ExpressionType::SCALAR) { + Status Visit(const NnaryExpression& rhs) { + const auto& lhs = checked_cast(lhs_); + result_ = lhs.OperandsEqual(rhs); + return Status::OK(); + } + + Status Visit(const UnaryExpression& rhs) { + const auto& lhs = checked_cast(lhs_); + result_ = lhs.OperandsEqual(rhs); + return Status::OK(); + } + + Status Visit(const Expression& rhs) { return Status::NotImplemented("halp"); } + + bool Compare(const Expression& rhs) && { + DCHECK_OK(rhs.Accept(*this)); + return result_; + } + + const Expression& lhs_; + bool result_; +}; + +bool Expression::Equals(const Expression& other) const { + if (type_ != other.type()) { return false; } - return !checked_cast(*this).value()->is_valid; + return ExpressionEqual{*this, false}.Compare(other); +} + +bool Expression::Equals(const std::shared_ptr& other) const { + if (other == NULLPTR) { + return false; + } + return Equals(*other); } -bool Expression::IsBooleanScalar(BooleanScalar* out) const { +bool Expression::IsTrivialCondition(BooleanScalar* out) const { if (type_ != ExpressionType::SCALAR) { return false; } - auto scalar = checked_cast(*this).value(); + const auto& scalar = checked_cast(*this).value(); + if (!scalar->is_valid) { + if (out) { + *out = BooleanScalar(); + } + return true; + } + if (scalar->type->id() != Type::BOOL) { return false; } if (out) { - out->type = boolean(); - out->is_valid = scalar->is_valid; - out->value = checked_cast(*scalar).value; + *out = BooleanScalar(checked_cast(*scalar).value); } return true; } -std::shared_ptr OperatorExpression::Copy() const { - return std::make_shared(*this); -} - std::shared_ptr FieldReferenceExpression::Copy() const { return std::make_shared(*this); } @@ -758,55 +775,49 @@ std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } -std::shared_ptr and_(ExpressionVector operands) { - return std::make_shared(ExpressionType::AND, std::move(operands)); +std::shared_ptr and_(ExpressionVector operands) { + return std::make_shared(std::move(operands)); } -std::shared_ptr or_(ExpressionVector operands) { - return std::make_shared(ExpressionType::OR, std::move(operands)); +std::shared_ptr or_(ExpressionVector operands) { + return std::make_shared(std::move(operands)); } -std::shared_ptr not_(std::shared_ptr operand) { - return std::make_shared(ExpressionType::NOT, - ExpressionVector{std::move(operand)}); +std::shared_ptr not_(std::shared_ptr operand) { + return std::make_shared(std::move(operand)); } // flatten chains of and/or to a single OperatorExpression -OperatorExpression MaybeCombine(ExpressionType::type type, const OperatorExpression& lhs, - const OperatorExpression& rhs) { - if (lhs.type() != type && rhs.type() != type) { - return OperatorExpression(type, {lhs.Copy(), rhs.Copy()}); +template +Out MaybeCombine(const Expression& lhs, const Expression& rhs) { + if (lhs.type() != Out::expression_type && rhs.type() != Out::expression_type) { + return Out(ExpressionVector{lhs.Copy(), rhs.Copy()}); } + ExpressionVector operands; - if (lhs.type() == type) { - operands = lhs.operands(); - if (rhs.type() == type) { - for (auto operand : rhs.operands()) { - operands.emplace_back(std::move(operand)); - } - } else { - operands.emplace_back(rhs.Copy()); + for (auto side : {&lhs, &rhs}) { + if (side->type() != Out::expression_type) { + operands.emplace_back(side->Copy()); + continue; + } + + for (auto operand : checked_cast(*side).operands()) { + operands.emplace_back(std::move(operand)); } - } else { - operands = rhs.operands(); - operands.emplace(operands.begin(), lhs.Copy()); } - return OperatorExpression(type, std::move(operands)); -} -OperatorExpression operator and(const OperatorExpression& lhs, - const OperatorExpression& rhs) { - return MaybeCombine(ExpressionType::AND, lhs, rhs); + return Out(std::move(operands)); } -OperatorExpression operator or(const OperatorExpression& lhs, - const OperatorExpression& rhs) { - return MaybeCombine(ExpressionType::OR, lhs, rhs); +AndExpression operator and(const Expression& lhs, const Expression& rhs) { + return MaybeCombine(lhs, rhs); } -OperatorExpression operator not(const OperatorExpression& rhs) { - return OperatorExpression(ExpressionType::NOT, {rhs.Copy()}); +OrExpression operator or(const Expression& lhs, const Expression& rhs) { + return MaybeCombine(lhs, rhs); } +NotExpression operator not(const Expression& rhs) { return NotExpression(rhs.Copy()); } + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index e1e4d27e746..257952b0f9a 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -19,6 +19,7 @@ #include +#include "arrow/compute/kernels/compare.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/result.h" @@ -63,11 +64,6 @@ class ARROW_DS_EXPORT ExpressionFilter : public Filter { std::shared_ptr expression_; }; -/// Evaluate an expression producing a boolean array which encodes whether each row -/// satisfies the condition -Status EvaluateExpression(compute::FunctionContext* ctx, const Expression& condition, - const RecordBatch& batch, std::shared_ptr* mask); - struct ExpressionType { enum type { FIELD, @@ -77,12 +73,7 @@ struct ExpressionType { AND, OR, - EQUAL, - NOT_EQUAL, - GREATER, - GREATER_EQUAL, - LESS, - LESS_EQUAL, + COMPARISON, }; }; @@ -109,57 +100,178 @@ class ARROW_DS_EXPORT Expression { // virtual Result> Validate(const Schema& schema) const = // 0; - /// returns a debug string representing this expression - virtual std::string ToString() const = 0; - - ExpressionType::type type() const { return type_; } + /// Return a simplified form of this expression given some known conditions. + /// For example, (a > 3).Assume(a == 5) == (true). This can be used to do less work + /// in ExpressionFilter when partition conditions guarantee some of this expression. + /// In the example above, *no* filtering need be done on record batches in the + /// partition since (a == 5). + virtual Result> Assume(const Expression& given) const { + return Copy(); + } - /// If true, this Expression may be safely cast to OperatorExpression - bool IsOperatorExpression() const; + virtual Result> Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + return Status::Invalid("can't evaluate ", ToString()); + } - /// If true, this Expression may be safely cast to OperatorExpression - /// and there will be exactly two operands representing the left and right hand sides - /// of a comparison - bool IsComparisonExpression() const; + /// returns a debug string representing this expression + virtual std::string ToString() const { return "FIXME"; } - /// If true, this Expression is a ScalarExpression wrapping a null Scalar. - bool IsNullScalar() const; + ExpressionType::type type() const { return type_; } - /// If true, this Expression is a ScalarExpression wrapping a boolean Scalar. Its - /// value may be retrieved at the same time - bool IsBooleanScalar(BooleanScalar* value = NULLPTR) const; + /// If true, this Expression is a ScalarExpression wrapping either a null Scalar or a + /// non-null BooleanScalar. Its value may be retrieved at the same time. + bool IsTrivialCondition(BooleanScalar* value = NULLPTR) const; /// Copy this expression into a shared pointer. virtual std::shared_ptr Copy() const = 0; + template + Status Accept(Visitor&& visitor) const; + protected: ExpressionType::type type_; }; -/// Represents a compound expression; for example comparison between a field and a -/// scalar or a union of other expressions -class ARROW_DS_EXPORT OperatorExpression final : public Expression { +/// Helper class which implements Copy and forwards construction +template +class ExpressionImpl : public Base { public: - OperatorExpression(ExpressionType::type type, ExpressionVector operands) - : Expression(type), operands_(std::move(operands)) {} + static constexpr ExpressionType::type expression_type = E; - /// Return a simplified form of this expression given some known conditions. - /// For example, (a > 3).Assume(a == 5) == (true). This can be used to do less work - /// in ExpressionFilter when partition conditions guarantee some of this expression. - /// In the example above, *no* filtering need be done on record batches in the - /// partition since (a == 5). - Result> Assume(const Expression& given) const; + template + ExpressionImpl(A&&... args) : Base(expression_type, std::forward(args)...) {} + + std::shared_ptr Copy() const override { + return std::make_shared(internal::checked_cast(*this)); + } +}; + +/// Represents an expression with exactly one operand; for example negation +class ARROW_DS_EXPORT UnaryExpression : public Expression { + public: + const std::shared_ptr& operand() const { return operand_; } + + protected: + UnaryExpression(ExpressionType::type type, std::shared_ptr operand) + : Expression(type), operand_(std::move(operand)) {} + + bool OperandsEqual(const UnaryExpression& other) const; + + friend struct ExpressionEqual; + + std::shared_ptr operand_; +}; + +/// Represents an expression with exactly two operands; for example a comparison of two +/// expressions +class ARROW_DS_EXPORT BinaryExpression : public Expression { + public: + const std::shared_ptr& left_operand() const { return left_operand_; } + const std::shared_ptr& right_operand() const { return right_operand_; } + + protected: + BinaryExpression(ExpressionType::type type, std::shared_ptr left_operand, + std::shared_ptr right_operand) + : Expression(type), + left_operand_(std::move(left_operand)), + right_operand_(std::move(right_operand)) {} + + bool OperandsEqual(const BinaryExpression& other) const; + + friend struct ExpressionEqual; + + std::shared_ptr left_operand_, right_operand_; +}; + +/// Represents an expression with multiple operands; for example a conjunction or +/// disjunction of other expressions +class ARROW_DS_EXPORT NnaryExpression : public Expression { + public: const ExpressionVector& operands() const { return operands_; } + protected: + NnaryExpression(ExpressionType::type type, ExpressionVector operands) + : Expression(type), operands_(std::move(operands)) {} + + bool OperandsEqual(const NnaryExpression& other) const; + + friend struct ExpressionEqual; + + ExpressionVector operands_; +}; + +class ARROW_DS_EXPORT ComparisonExpression final + : public ExpressionImpl { + public: + ComparisonExpression(compute::CompareOperator op, + std::shared_ptr left_operand, + std::shared_ptr right_operand) + : ExpressionImpl(std::move(left_operand), std::move(right_operand)), op_(op) {} + std::string ToString() const override; // Result> Validate(const Schema& schema) const override; - std::shared_ptr Copy() const override; + Result> Assume(const Expression& given) const override; + + compute::CompareOperator op() const { return op_; } + + Result> Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; private: - ExpressionVector operands_; + Result> AssumeGivenComparison( + const ComparisonExpression& given) const; + + compute::CompareOperator op_; +}; + +class ARROW_DS_EXPORT AndExpression final + : public ExpressionImpl { + public: + using ExpressionImpl::ExpressionImpl; + + std::string ToString() const override; + + // Result> Validate(const Schema& schema) const override; + + Result> Assume(const Expression& given) const override; + + Result> Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; +}; + +class ARROW_DS_EXPORT OrExpression final + : public ExpressionImpl { + public: + using ExpressionImpl::ExpressionImpl; + + std::string ToString() const override; + + // Result> Validate(const Schema& schema) const override; + + Result> Assume(const Expression& given) const override; + + Result> Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; +}; + +class ARROW_DS_EXPORT NotExpression final + : public ExpressionImpl { + public: + using ExpressionImpl::ExpressionImpl; + + // std::string ToString() const override; + + // Result> Validate(const Schema& schema) const override; + + Result> Assume(const Expression& given) const override; + + Result> Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; }; /// Represents a scalar value; thin wrapper around arrow::Scalar @@ -204,6 +316,9 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { return std::make_shared(std::make_shared()); } + Result> Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; + std::shared_ptr Copy() const override; private: @@ -229,34 +344,30 @@ class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { std::string name_; }; -ARROW_DS_EXPORT std::shared_ptr and_(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr and_(ExpressionVector operands); -ARROW_DS_EXPORT OperatorExpression operator and(const OperatorExpression& lhs, - const OperatorExpression& rhs); +ARROW_DS_EXPORT AndExpression operator and(const Expression& lhs, const Expression& rhs); -ARROW_DS_EXPORT std::shared_ptr or_(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr or_(ExpressionVector operands); -ARROW_DS_EXPORT OperatorExpression operator or(const OperatorExpression& lhs, - const OperatorExpression& rhs); +ARROW_DS_EXPORT OrExpression operator or(const Expression& lhs, const Expression& rhs); -ARROW_DS_EXPORT std::shared_ptr not_( - std::shared_ptr operand); +ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); -ARROW_DS_EXPORT OperatorExpression operator not(const OperatorExpression& rhs); +ARROW_DS_EXPORT NotExpression operator not(const Expression& rhs); -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - inline std::shared_ptr FACTORY_NAME( \ - const std::shared_ptr& lhs, \ - const std::shared_ptr& rhs) { \ - return std::make_shared(ExpressionType::NAME, \ - ExpressionVector{lhs, rhs}); \ - } \ - \ - template \ - OperatorExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ - return OperatorExpression( \ - ExpressionType::NAME, \ - {lhs.Copy(), ScalarExpression::Make(std::forward(rhs))}); \ +#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ + inline std::shared_ptr FACTORY_NAME( \ + const std::shared_ptr& lhs, \ + const std::shared_ptr& rhs) { \ + return std::make_shared(compute::CompareOperator::NAME, lhs, \ + rhs); \ + } \ + \ + template \ + ComparisonExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ + return ComparisonExpression(compute::CompareOperator::NAME, lhs.Copy(), \ + ScalarExpression::Make(std::forward(rhs))); \ } COMPARISON_FACTORY(EQUAL, equal, ==) COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) @@ -281,5 +392,26 @@ inline FieldReferenceExpression operator""_(const char* name, size_t name_length } } // namespace string_literals +#define EXPRESSION_VISIT_CASE(NAME, CLASS) \ + case ExpressionType::NAME: \ + return visitor.Visit(internal::checked_cast(*this)) + +template +Status Expression::Accept(Visitor&& visitor) const { + switch (type_) { + EXPRESSION_VISIT_CASE(FIELD, FieldReference); + EXPRESSION_VISIT_CASE(SCALAR, Scalar); + EXPRESSION_VISIT_CASE(AND, And); + EXPRESSION_VISIT_CASE(OR, Or); + EXPRESSION_VISIT_CASE(NOT, Not); + EXPRESSION_VISIT_CASE(COMPARISON, Comparison); + default: + break; + } + return Status::TypeError("unknown ExpressionType"); +} + +#undef EXPRESSION_VISIT_CASE + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 0f15313f6c3..8e688de37ad 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -25,6 +25,7 @@ #include "arrow/compute/api.h" #include "arrow/dataset/api.h" +#include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" @@ -37,7 +38,7 @@ using internal::checked_pointer_cast; class ExpressionsTest : public ::testing::Test { public: - void AssertSimplifiesTo(OperatorExpression expr, const Expression& given, + void AssertSimplifiesTo(const Expression& expr, const Expression& given, const Expression& expected) { auto simplified = expr.Assume(given); ASSERT_OK(simplified.status()); @@ -49,8 +50,8 @@ class ExpressionsTest : public ::testing::Test { } } - template - void AssertOperandsAre(OperatorExpression expr, ExpressionType::type type, + template + void AssertOperandsAre(const NnaryExpression& expr, ExpressionType::type type, T... expected_operands) { ASSERT_EQ(expr.type(), type); ASSERT_EQ(expr.operands().size(), sizeof...(T)); @@ -112,9 +113,9 @@ TEST_F(ExpressionsTest, SimplificationToNull) { class FilterTest : public ::testing::Test { public: - Status DoFilter(const Expression& expr, std::vector> fields, - std::string batch_json, std::shared_ptr* mask, - std::shared_ptr* expected_mask) { + Result> DoFilter( + const Expression& expr, std::vector> fields, + std::string batch_json, std::shared_ptr* expected_mask = nullptr) { // expected filter result is in the "in" field fields.push_back(field("in", boolean())); @@ -122,17 +123,19 @@ class FilterTest : public ::testing::Test { std::shared_ptr batch; RETURN_NOT_OK(RecordBatch::FromStructArray(batch_array, &batch)); - *expected_mask = checked_pointer_cast(batch->GetColumnByName("in")); + if (expected_mask) { + *expected_mask = checked_pointer_cast(batch->GetColumnByName("in")); + } - return EvaluateExpression(&ctx_, expr, *batch, mask); + return expr.Evaluate(&ctx_, *batch); } void AssertFilter(const Expression& expr, std::vector> fields, std::string batch_json) { - std::shared_ptr mask, expected_mask; - ASSERT_OK( - DoFilter(expr, std::move(fields), std::move(batch_json), &mask, &expected_mask)); - ASSERT_ARRAYS_EQUAL(*expected_mask, *mask); + std::shared_ptr expected_mask; + auto mask = DoFilter(expr, std::move(fields), std::move(batch_json), &expected_mask); + ASSERT_OK(mask.status()); + ASSERT_ARRAYS_EQUAL(*expected_mask, *mask.ValueOrDie()); } arrow::compute::FunctionContext ctx_; diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index a7d79cd501b..4f195334e2f 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -57,7 +57,10 @@ class Filter; using FilterVector = std::vector>; class Expression; -class OperatorExpression; +class ComparisonExpression; +class AndExpression; +class OrExpression; +class NotExpression; class ScalarExpression; class FieldReferenceExpression; using ExpressionVector = std::vector>; From e9af9354b3ef9eb91a70a8ad5ea458866a7e5d3d Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 10:07:32 -0400 Subject: [PATCH 13/28] lint fixes --- cpp/build-support/run_cpplint.py | 1 + cpp/src/arrow/dataset/filter.cc | 5 ++++- cpp/src/arrow/dataset/filter.h | 5 ++++- cpp/src/arrow/dataset/filter_test.cc | 5 +++-- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/cpp/build-support/run_cpplint.py b/cpp/build-support/run_cpplint.py index 9ee28af727e..bf2b51b8b2f 100755 --- a/cpp/build-support/run_cpplint.py +++ b/cpp/build-support/run_cpplint.py @@ -35,6 +35,7 @@ -whitespace/comments -readability/casting -readability/todo +-readability/alt_tokens -build/header_guard -build/c++11 -runtime/references diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 5cf10f1f562..76d2373118a 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -17,10 +17,13 @@ #include "arrow/dataset/filter.h" +#include #include +#include +#include -#include "arrow/buffer-builder.h" #include "arrow/buffer.h" +#include "arrow/buffer_builder.h" #include "arrow/compute/context.h" #include "arrow/compute/kernels/boolean.h" #include "arrow/compute/kernels/compare.h" diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 257952b0f9a..d7d1a250804 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -18,6 +18,8 @@ #pragma once #include +#include +#include #include "arrow/compute/kernels/compare.h" #include "arrow/dataset/type_fwd.h" @@ -140,7 +142,8 @@ class ExpressionImpl : public Base { static constexpr ExpressionType::type expression_type = E; template - ExpressionImpl(A&&... args) : Base(expression_type, std::forward(args)...) {} + explicit ExpressionImpl(A&&... args) + : Base(expression_type, std::forward(args)...) {} std::shared_ptr Copy() const override { return std::make_shared(internal::checked_cast(*this)); diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 8e688de37ad..5dadd92a7df 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/dataset/filter.h" + #include #include #include @@ -24,7 +26,6 @@ #include #include "arrow/compute/api.h" -#include "arrow/dataset/api.h" #include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -33,7 +34,7 @@ namespace arrow { namespace dataset { -using namespace string_literals; +using string_literals::operator""_; using internal::checked_pointer_cast; class ExpressionsTest : public ::testing::Test { From f9e0c0887c3b7207f508af3e1694ff69f8abe7cd Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 10:11:00 -0400 Subject: [PATCH 14/28] remove unused Empty() method --- cpp/src/arrow/scalar.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 977acefcf32..39a38f4f883 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -102,8 +102,6 @@ struct BaseBinaryScalar : public Scalar { BaseBinaryScalar(const std::shared_ptr& value, const std::shared_ptr& type, bool is_valid = true) : Scalar{type, is_valid}, value(value) {} - - static std::shared_ptr Empty(); }; struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { From 5bae59004f7026245d8fe6b67ac9184a3707f6db Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 10:30:55 -0400 Subject: [PATCH 15/28] re-enable Invert --- cpp/src/arrow/dataset/filter.cc | 100 ++++++++++++++------------- cpp/src/arrow/dataset/filter_test.cc | 3 + 2 files changed, 54 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 76d2373118a..be7e5cd8c56 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -223,67 +223,67 @@ Result Compare(const Scalar& lhs, const Scalar& rhs) { return vis.result_; } -/* -Result> Invert(const Expression& op) { - auto make_opposite = [&op](ExpressionType::type opposite_type) { - return std::make_shared( - opposite_type, checked_cast(op).operands()); +std::shared_ptr Invert(const ComparisonExpression& comparison) { + using compute::CompareOperator; + auto make_opposite = [&](CompareOperator opposite) { + return std::make_shared(opposite, comparison.left_operand(), + comparison.right_operand()); }; + switch (comparison.op()) { + case CompareOperator::EQUAL: + return make_opposite(CompareOperator::NOT_EQUAL); + + case CompareOperator::NOT_EQUAL: + return make_opposite(CompareOperator::EQUAL); + + case CompareOperator::GREATER: + return make_opposite(CompareOperator::LESS_EQUAL); + + case CompareOperator::GREATER_EQUAL: + return make_opposite(CompareOperator::LESS); + + case CompareOperator::LESS: + return make_opposite(CompareOperator::GREATER_EQUAL); + + case CompareOperator::LESS_EQUAL: + return make_opposite(CompareOperator::GREATER); + + default: + break; + } + + DCHECK(false); + return nullptr; +} + +Result> Invert(const Expression& op) { switch (op.type()) { case ExpressionType::NOT: - return checked_cast(op).operands()[0]; + return checked_cast(op).operand(); case ExpressionType::AND: case ExpressionType::OR: { - ExpressionVector operands; - for (auto operand : checked_cast(op).operands()) { + ExpressionVector inverted_operands; + for (auto operand : checked_cast(op).operands()) { ARROW_ASSIGN_OR_RAISE(auto inverted_operand, Invert(*operand)); - operands.push_back(inverted_operand); + inverted_operands.push_back(inverted_operand); } - auto opposite_type = - op.type() == ExpressionType::AND ? ExpressionType::OR : ExpressionType::AND; - return std::make_shared(opposite_type, std::move(operands)); + if (op.type() == ExpressionType::AND) { + return std::make_shared(std::move(inverted_operands)); + } + return std::make_shared(std::move(inverted_operands)); } - case ExpressionType::EQUAL: - return make_opposite(ExpressionType::NOT_EQUAL); - - case ExpressionType::NOT_EQUAL: - return make_opposite(ExpressionType::EQUAL); - - case ExpressionType::LESS: - return make_opposite(ExpressionType::GREATER_EQUAL); - - case ExpressionType::LESS_EQUAL: - return make_opposite(ExpressionType::GREATER); - - case ExpressionType::GREATER: - return make_opposite(ExpressionType::LESS_EQUAL); - - case ExpressionType::GREATER_EQUAL: - return make_opposite(ExpressionType::LESS); + case ExpressionType::COMPARISON: + return Invert(checked_cast(op)); default: - return Status::NotImplemented("can't invert this expression"); - } - - return op.Copy(); -} -*/ - -// If e can be cast to OperatorExpression try to simplify it against given. -// Otherwise pass e through unchanged. -/* -Result> AssumeIfOperator(const std::shared_ptr& e, - const Expression& given) { - if (!e->IsOperatorExpression()) { - return e; + break; } - return checked_cast(*e).Assume(given); + return Status::NotImplemented("can't invert this expression"); } -*/ Result> ComparisonExpression::Assume( const Expression& given) const { @@ -293,10 +293,12 @@ Result> ComparisonExpression::Assume( } case ExpressionType::NOT: { - // const auto& to_invert = checked_cast(given).operand(); - // ARROW_ASSIGN_OR_RAISE(auto inverted, Invert(*to_invert)); - // return Assume(*inverted); - return Copy(); + const auto& to_invert = checked_cast(given).operand(); + auto inverted = Invert(*to_invert); + if (!inverted.ok()) { + return Copy(); + } + return Assume(*inverted.ValueOrDie()); } case ExpressionType::OR: { diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 5dadd92a7df..b64dc9cf774 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -92,6 +92,9 @@ TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); + AssertSimplifiesTo("b"_ > 0.5 and "b"_ < 1.5, not("b"_ < 0.0 or "b"_ > 1.0), + "b"_ > 0.5); + AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); AssertSimplifiesTo("a"_ == 3 or "b"_ == 4, "a"_ == 0, "b"_ == 4); From cb569064f9e539fdbc6f3759e5da6803fe63ec81 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 10:56:11 -0400 Subject: [PATCH 16/28] rename FieldRef -> Field, and_ -> all, or_ -> any, add comments to ExpressionType --- cpp/src/arrow/dataset/filter.cc | 73 ++++++++++++++-------------- cpp/src/arrow/dataset/filter.h | 68 ++++++++++++++++---------- cpp/src/arrow/dataset/filter_test.cc | 2 +- 3 files changed, 80 insertions(+), 63 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index be7e5cd8c56..cdbe4e25d02 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -82,11 +82,11 @@ Result> EvaluateNnary(const Nnary& nnary, for (size_t i_next = 1; i_next < operands.size(); ++i_next) { ARROW_ASSIGN_OR_RAISE(next, operands[i_next]->Evaluate(ctx, batch)); - if (std::is_same::value) { + if (std::is_same::value) { RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); } - if (std::is_same::value) { + if (std::is_same::value) { RETURN_NOT_OK(arrow::compute::Or(ctx, Datum(acc), Datum(next), &acc)); } } @@ -94,12 +94,12 @@ Result> EvaluateNnary(const Nnary& nnary, return checked_pointer_cast(acc.make_array()); } -Result> AndExpression::Evaluate( +Result> AllExpression::Evaluate( compute::FunctionContext* ctx, const RecordBatch& batch) const { return EvaluateNnary(*this, ctx, batch); } -Result> OrExpression::Evaluate( +Result> AnyExpression::Evaluate( compute::FunctionContext* ctx, const RecordBatch& batch) const { return EvaluateNnary(*this, ctx, batch); } @@ -114,7 +114,7 @@ Result> ComparisonExpression::Evaluate( return Status::Invalid("right hand side of comparison must be a scalar"); } - const auto& lhs = checked_cast(*left_operand_); + const auto& lhs = checked_cast(*left_operand_); const auto& rhs = checked_cast(*right_operand_); auto lhs_array = batch.GetColumnByName(lhs.name()); @@ -262,18 +262,18 @@ Result> Invert(const Expression& op) { case ExpressionType::NOT: return checked_cast(op).operand(); - case ExpressionType::AND: - case ExpressionType::OR: { + case ExpressionType::ALL: + case ExpressionType::ANY: { ExpressionVector inverted_operands; for (auto operand : checked_cast(op).operands()) { ARROW_ASSIGN_OR_RAISE(auto inverted_operand, Invert(*operand)); inverted_operands.push_back(inverted_operand); } - if (op.type() == ExpressionType::AND) { - return std::make_shared(std::move(inverted_operands)); + if (op.type() == ExpressionType::ALL) { + return std::make_shared(std::move(inverted_operands)); } - return std::make_shared(std::move(inverted_operands)); + return std::make_shared(std::move(inverted_operands)); } case ExpressionType::COMPARISON: @@ -301,10 +301,10 @@ Result> ComparisonExpression::Assume( return Assume(*inverted.ValueOrDie()); } - case ExpressionType::OR: { + case ExpressionType::ANY: { bool simplify_to_always = true; bool simplify_to_never = true; - for (const auto& operand : checked_cast(given).operands()) { + for (const auto& operand : checked_cast(given).operands()) { ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); BooleanScalar scalar; @@ -336,9 +336,9 @@ Result> ComparisonExpression::Assume( return Copy(); } - case ExpressionType::AND: { + case ExpressionType::ALL: { auto simplified = Copy(); - for (const auto& operand : checked_cast(given).operands()) { + for (const auto& operand : checked_cast(given).operands()) { BooleanScalar value; if (simplified->IsTrivialCondition(&value)) { // FIXME(bkietz) but what if something later is null? @@ -374,9 +374,8 @@ Result> ComparisonExpression::AssumeGivenComparison( } } - const auto& this_lhs = checked_cast(*left_operand_); - const auto& given_lhs = - checked_cast(*given.left_operand_); + const auto& this_lhs = checked_cast(*left_operand_); + const auto& given_lhs = checked_cast(*given.left_operand_); if (this_lhs.name() != given_lhs.name()) { return Copy(); } @@ -534,9 +533,9 @@ Result> AssumeNnary(const Nnary& nnary, const Expression& given) { // if any of the operands matches trivial_condition, we can return a trivial // expression: - // anything OR true => true - // anything AND false => false - constexpr bool trivial_condition = std::is_same::value; + // anything ANY true => true + // anything ALL false => false + constexpr bool trivial_condition = std::is_same::value; bool simplify_to_trivial = false; ExpressionVector operands; @@ -575,11 +574,11 @@ Result> AssumeNnary(const Nnary& nnary, return std::make_shared(std::move(operands)); } -Result> AndExpression::Assume(const Expression& given) const { +Result> AllExpression::Assume(const Expression& given) const { return AssumeNnary(*this, given); } -Result> OrExpression::Assume(const Expression& given) const { +Result> AnyExpression::Assume(const Expression& given) const { return AssumeNnary(*this, given); } @@ -598,7 +597,7 @@ Result> NotExpression::Assume(const Expression& give return ScalarExpression::Make(!scalar.value); } -std::string FieldReferenceExpression::ToString() const { +std::string FieldExpression::ToString() const { return std::string("field(") + name_ + ")"; } @@ -665,9 +664,9 @@ static std::string EulerNotation(std::string fn, const ExpressionVector& operand return fn; } -std::string AndExpression::ToString() const { return EulerNotation("AND", operands_); } +std::string AllExpression::ToString() const { return EulerNotation("ALL", operands_); } -std::string OrExpression::ToString() const { return EulerNotation("OR", operands_); } +std::string AnyExpression::ToString() const { return EulerNotation("ANY", operands_); } std::string ComparisonExpression::ToString() const { return EulerNotation(OperatorName(op()), {left_operand_, right_operand_}); @@ -695,8 +694,8 @@ bool NnaryExpression::OperandsEqual(const NnaryExpression& other) const { } struct ExpressionEqual { - Status Visit(const FieldReferenceExpression& rhs) { - result_ = checked_cast(lhs_).name() == rhs.name(); + Status Visit(const FieldExpression& rhs) { + result_ = checked_cast(lhs_).name() == rhs.name(); return Status::OK(); } @@ -772,20 +771,20 @@ bool Expression::IsTrivialCondition(BooleanScalar* out) const { return true; } -std::shared_ptr FieldReferenceExpression::Copy() const { - return std::make_shared(*this); +std::shared_ptr FieldExpression::Copy() const { + return std::make_shared(*this); } std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } -std::shared_ptr and_(ExpressionVector operands) { - return std::make_shared(std::move(operands)); +std::shared_ptr all(ExpressionVector operands) { + return std::make_shared(std::move(operands)); } -std::shared_ptr or_(ExpressionVector operands) { - return std::make_shared(std::move(operands)); +std::shared_ptr any(ExpressionVector operands) { + return std::make_shared(std::move(operands)); } std::shared_ptr not_(std::shared_ptr operand) { @@ -814,12 +813,12 @@ Out MaybeCombine(const Expression& lhs, const Expression& rhs) { return Out(std::move(operands)); } -AndExpression operator and(const Expression& lhs, const Expression& rhs) { - return MaybeCombine(lhs, rhs); +AllExpression operator and(const Expression& lhs, const Expression& rhs) { + return MaybeCombine(lhs, rhs); } -OrExpression operator or(const Expression& lhs, const Expression& rhs) { - return MaybeCombine(lhs, rhs); +AnyExpression operator or(const Expression& lhs, const Expression& rhs) { + return MaybeCombine(lhs, rhs); } NotExpression operator not(const Expression& rhs) { return NotExpression(rhs.Copy()); } diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index d7d1a250804..2d439d6e5d4 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -33,7 +33,7 @@ namespace dataset { struct FilterType { enum type { /// Simple boolean predicate consisting of comparisons and boolean - /// logic (AND, OR, NOT) involving Schema fields + /// logic (ALL, ANY, NOT) involving Schema fields EXPRESSION, /// Non decomposable filter; must be evaluated against every record batch @@ -54,7 +54,7 @@ class ARROW_DS_EXPORT Filter { }; /// Filter subclass encapsulating a simple boolean predicate consisting of comparisons -/// and boolean logic (AND, OR, NOT) involving Schema fields +/// and boolean logic (ALL, ANY, NOT) involving Schema fields class ARROW_DS_EXPORT ExpressionFilter : public Filter { public: explicit ExpressionFilter(const std::shared_ptr& expression) @@ -68,19 +68,36 @@ class ARROW_DS_EXPORT ExpressionFilter : public Filter { struct ExpressionType { enum type { + /// a reference to a column within a record batch, will evaluate to an array FIELD, + + /// a literal singular value encapuslated in a Scalar SCALAR, + /// a literal Array + // TODO(bkietz) ARRAY, + + /// an inversion of another expression NOT, - AND, - OR, + /// cast an expression to a given DataType + // TODO(bkietz) CAST, + + /// a conjunction of multiple expressions (true if all operands are true) + ALL, + + /// a disjunction of multiple expressions (true if any operand is true) + ANY, + + /// a comparison of two other expressions COMPARISON, + + /// extract validity as an expression (true if operand is valid) + // TODO(bkietz) VALIDITY, }; }; -/// Represents an expression tree. The expression can be evaluated against a -/// RecordBatch via ExpressionFilter +/// Represents an expression tree class ARROW_DS_EXPORT Expression { public: explicit Expression(ExpressionType::type type) : type_(type) {} @@ -111,6 +128,7 @@ class ARROW_DS_EXPORT Expression { return Copy(); } + // Evaluate an expression against a RecordBatch virtual Result> Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const { return Status::Invalid("can't evaluate ", ToString()); @@ -232,8 +250,8 @@ class ARROW_DS_EXPORT ComparisonExpression final compute::CompareOperator op_; }; -class ARROW_DS_EXPORT AndExpression final - : public ExpressionImpl { +class ARROW_DS_EXPORT AllExpression final + : public ExpressionImpl { public: using ExpressionImpl::ExpressionImpl; @@ -247,8 +265,8 @@ class ARROW_DS_EXPORT AndExpression final const RecordBatch& batch) const override; }; -class ARROW_DS_EXPORT OrExpression final - : public ExpressionImpl { +class ARROW_DS_EXPORT AnyExpression final + : public ExpressionImpl { public: using ExpressionImpl::ExpressionImpl; @@ -330,9 +348,9 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { /// Represents a reference to a field. Stores only the field's name (type and other /// information is known only when a Schema is provided) -class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { +class ARROW_DS_EXPORT FieldExpression final : public Expression { public: - explicit FieldReferenceExpression(std::string name) + explicit FieldExpression(std::string name) : Expression(ExpressionType::FIELD), name_(std::move(name)) {} std::string name() const { return name_; } @@ -347,13 +365,13 @@ class ARROW_DS_EXPORT FieldReferenceExpression final : public Expression { std::string name_; }; -ARROW_DS_EXPORT std::shared_ptr and_(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr all(ExpressionVector operands); -ARROW_DS_EXPORT AndExpression operator and(const Expression& lhs, const Expression& rhs); +ARROW_DS_EXPORT AllExpression operator and(const Expression& lhs, const Expression& rhs); -ARROW_DS_EXPORT std::shared_ptr or_(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr any(ExpressionVector operands); -ARROW_DS_EXPORT OrExpression operator or(const Expression& lhs, const Expression& rhs); +ARROW_DS_EXPORT AnyExpression operator or(const Expression& lhs, const Expression& rhs); ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); @@ -361,14 +379,14 @@ ARROW_DS_EXPORT NotExpression operator not(const Expression& rhs); #define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ inline std::shared_ptr FACTORY_NAME( \ - const std::shared_ptr& lhs, \ + const std::shared_ptr& lhs, \ const std::shared_ptr& rhs) { \ return std::make_shared(compute::CompareOperator::NAME, lhs, \ rhs); \ } \ \ template \ - ComparisonExpression operator OP(const FieldReferenceExpression& lhs, T&& rhs) { \ + ComparisonExpression operator OP(const FieldExpression& lhs, T&& rhs) { \ return ComparisonExpression(compute::CompareOperator::NAME, lhs.Copy(), \ ScalarExpression::Make(std::forward(rhs))); \ } @@ -385,13 +403,13 @@ auto scalar(T&& value) -> decltype(ScalarExpression::Make(std::forward(value) return ScalarExpression::Make(std::forward(value)); } -inline std::shared_ptr fieldRef(std::string name) { - return std::make_shared(std::move(name)); +inline std::shared_ptr fieldRef(std::string name) { + return std::make_shared(std::move(name)); } inline namespace string_literals { -inline FieldReferenceExpression operator""_(const char* name, size_t name_length) { - return FieldReferenceExpression({name, name_length}); +inline FieldExpression operator""_(const char* name, size_t name_length) { + return FieldExpression({name, name_length}); } } // namespace string_literals @@ -402,10 +420,10 @@ inline FieldReferenceExpression operator""_(const char* name, size_t name_length template Status Expression::Accept(Visitor&& visitor) const { switch (type_) { - EXPRESSION_VISIT_CASE(FIELD, FieldReference); + EXPRESSION_VISIT_CASE(FIELD, Field); EXPRESSION_VISIT_CASE(SCALAR, Scalar); - EXPRESSION_VISIT_CASE(AND, And); - EXPRESSION_VISIT_CASE(OR, Or); + EXPRESSION_VISIT_CASE(ALL, All); + EXPRESSION_VISIT_CASE(ANY, Any); EXPRESSION_VISIT_CASE(NOT, Not); EXPRESSION_VISIT_CASE(COMPARISON, Comparison); default: diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index b64dc9cf774..8035494516f 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -82,7 +82,7 @@ TEST_F(ExpressionsTest, Equality) { TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { // chained "and" expressions are flattened auto multi_and = "b"_ > 5 and "b"_ < 10 and "b"_ != 7; - AssertOperandsAre(multi_and, ExpressionType::AND, "b"_ > 5, "b"_ < 10, "b"_ != 7); + AssertOperandsAre(multi_and, ExpressionType::ALL, "b"_ > 5, "b"_ < 10, "b"_ != 7); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); From d7f3ed2399995f63c47a43a6086e136ca3f63f88 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 11:13:52 -0400 Subject: [PATCH 17/28] simplify Expression::Equals --- cpp/src/arrow/dataset/filter.cc | 231 ++++++++++++++------------------ cpp/src/arrow/dataset/filter.h | 26 ++-- 2 files changed, 117 insertions(+), 140 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index cdbe4e25d02..ec5ffd0c4cc 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -395,79 +395,7 @@ Result> ComparisonExpression::AssumeGivenComparison( using compute::CompareOperator; - if (cmp == Comparison::EQUAL) { - // the rhs of the comparisons are equal - switch (op_) { - case CompareOperator::EQUAL: - switch (given.op()) { - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - return never; - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS_EQUAL: - case CompareOperator::LESS: - return never; - case CompareOperator::GREATER: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::LESS: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::GREATER: - return never; - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } else if (cmp == Comparison::GREATER) { + if (cmp == Comparison::GREATER) { // the rhs of e is greater than that of given switch (op()) { case CompareOperator::EQUAL: @@ -495,7 +423,9 @@ Result> ComparisonExpression::AssumeGivenComparison( default: return Copy(); } - } else { + } + + if (cmp == Comparison::LESS) { // the rhs of e is less than that of given switch (op()) { case CompareOperator::EQUAL: @@ -525,7 +455,79 @@ Result> ComparisonExpression::AssumeGivenComparison( } } - return Copy(); + DCHECK_EQ(cmp, Comparison::EQUAL); + + // the rhs of the comparisons are equal + switch (op_) { + case CompareOperator::EQUAL: + switch (given.op()) { + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: + return never; + case CompareOperator::EQUAL: + return always; + default: + return Copy(); + } + case CompareOperator::NOT_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + return never; + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: + return always; + default: + return Copy(); + } + case CompareOperator::GREATER: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS_EQUAL: + case CompareOperator::LESS: + return never; + case CompareOperator::GREATER: + return always; + default: + return Copy(); + } + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::LESS: + return never; + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return always; + default: + return Copy(); + } + case CompareOperator::LESS: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return never; + case CompareOperator::LESS: + return always; + default: + return Copy(); + } + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::GREATER: + return never; + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + return always; + default: + return Copy(); + } + default: + return Copy(); + } } template @@ -672,73 +674,48 @@ std::string ComparisonExpression::ToString() const { return EulerNotation(OperatorName(op()), {left_operand_, right_operand_}); } -bool UnaryExpression::OperandsEqual(const UnaryExpression& other) const { - return operand_->Equals(other.operand_); +bool UnaryExpression::Equals(const Expression& other) const { + return type_ == other.type() && + operand_->Equals(checked_cast(other).operand_); } -bool BinaryExpression::OperandsEqual(const BinaryExpression& other) const { - return left_operand_->Equals(other.left_operand_) && - right_operand_->Equals(other.right_operand_); +bool BinaryExpression::Equals(const Expression& other) const { + return type_ == other.type() && + left_operand_->Equals( + checked_cast(other).left_operand_) && + right_operand_->Equals( + checked_cast(other).right_operand_); } -bool NnaryExpression::OperandsEqual(const NnaryExpression& other) const { - if (operands_.size() != other.operands_.size()) { +bool NnaryExpression::Equals(const Expression& other) const { + if (type_ != other.type()) { + return false; + } + const auto& other_operands = checked_cast(other).operands_; + if (operands_.size() != other_operands.size()) { return false; } for (size_t i = 0; i < operands_.size(); ++i) { - if (!operands_[i]->Equals(other.operands_[i])) { + if (!operands_[i]->Equals(other_operands[i])) { return false; } } return true; } -struct ExpressionEqual { - Status Visit(const FieldExpression& rhs) { - result_ = checked_cast(lhs_).name() == rhs.name(); - return Status::OK(); - } - - Status Visit(const ScalarExpression& rhs) { - result_ = checked_cast(lhs_).value()->Equals(rhs.value()); - return Status::OK(); - } - - Status Visit(const ComparisonExpression& rhs) { - const auto& lhs = checked_cast(lhs_); - result_ = lhs.op() == rhs.op() && lhs.OperandsEqual(rhs); - return Status::OK(); - } - - Status Visit(const NnaryExpression& rhs) { - const auto& lhs = checked_cast(lhs_); - result_ = lhs.OperandsEqual(rhs); - return Status::OK(); - } - - Status Visit(const UnaryExpression& rhs) { - const auto& lhs = checked_cast(lhs_); - result_ = lhs.OperandsEqual(rhs); - return Status::OK(); - } - - Status Visit(const Expression& rhs) { return Status::NotImplemented("halp"); } - - bool Compare(const Expression& rhs) && { - DCHECK_OK(rhs.Accept(*this)); - return result_; - } - - const Expression& lhs_; - bool result_; -}; +bool ComparisonExpression::Equals(const Expression& other) const { + return BinaryExpression::Equals(other) && + op_ == checked_cast(other).op_; +} -bool Expression::Equals(const Expression& other) const { - if (type_ != other.type()) { - return false; - } +bool ScalarExpression::Equals(const Expression& other) const { + return other.type() == ExpressionType::SCALAR && + value_->Equals(checked_cast(other).value_); +} - return ExpressionEqual{*this, false}.Compare(other); +bool FieldExpression::Equals(const Expression& other) const { + return other.type() == ExpressionType::FIELD && + name_ == checked_cast(other).name_; } bool Expression::Equals(const std::shared_ptr& other) const { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 2d439d6e5d4..e9e66363306 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -107,7 +107,7 @@ class ARROW_DS_EXPORT Expression { /// Returns true iff the expressions are identical; does not check for equivalence. /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to /// (false). - bool Equals(const Expression& other) const; + virtual bool Equals(const Expression& other) const = 0; bool Equals(const std::shared_ptr& other) const; @@ -173,14 +173,12 @@ class ARROW_DS_EXPORT UnaryExpression : public Expression { public: const std::shared_ptr& operand() const { return operand_; } + bool Equals(const Expression& other) const override; + protected: UnaryExpression(ExpressionType::type type, std::shared_ptr operand) : Expression(type), operand_(std::move(operand)) {} - bool OperandsEqual(const UnaryExpression& other) const; - - friend struct ExpressionEqual; - std::shared_ptr operand_; }; @@ -192,6 +190,8 @@ class ARROW_DS_EXPORT BinaryExpression : public Expression { const std::shared_ptr& right_operand() const { return right_operand_; } + bool Equals(const Expression& other) const override; + protected: BinaryExpression(ExpressionType::type type, std::shared_ptr left_operand, std::shared_ptr right_operand) @@ -199,10 +199,6 @@ class ARROW_DS_EXPORT BinaryExpression : public Expression { left_operand_(std::move(left_operand)), right_operand_(std::move(right_operand)) {} - bool OperandsEqual(const BinaryExpression& other) const; - - friend struct ExpressionEqual; - std::shared_ptr left_operand_, right_operand_; }; @@ -212,14 +208,12 @@ class ARROW_DS_EXPORT NnaryExpression : public Expression { public: const ExpressionVector& operands() const { return operands_; } + bool Equals(const Expression& other) const override; + protected: NnaryExpression(ExpressionType::type type, ExpressionVector operands) : Expression(type), operands_(std::move(operands)) {} - bool OperandsEqual(const NnaryExpression& other) const; - - friend struct ExpressionEqual; - ExpressionVector operands_; }; @@ -234,6 +228,8 @@ class ARROW_DS_EXPORT ComparisonExpression final std::string ToString() const override; + bool Equals(const Expression& other) const override; + // Result> Validate(const Schema& schema) const override; Result> Assume(const Expression& given) const override; @@ -305,6 +301,8 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { std::string ToString() const override; + bool Equals(const Expression& other) const override; + // Result> Validate(const Schema& schema) const override; static std::shared_ptr Make(bool value) { @@ -357,6 +355,8 @@ class ARROW_DS_EXPORT FieldExpression final : public Expression { std::string ToString() const override; + bool Equals(const Expression& other) const override; + // Result> Validate(const Schema& schema) const override; std::shared_ptr Copy() const override; From 24323fc1542f8dde7ececde9c25e758564366f76 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 11:15:40 -0400 Subject: [PATCH 18/28] implement NotExpression::ToString --- cpp/src/arrow/dataset/filter.cc | 2 ++ cpp/src/arrow/dataset/filter.h | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index ec5ffd0c4cc..dd7c8da8376 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -670,6 +670,8 @@ std::string AllExpression::ToString() const { return EulerNotation("ALL", operan std::string AnyExpression::ToString() const { return EulerNotation("ANY", operands_); } +std::string NotExpression::ToString() const { return EulerNotation("NOT", {operand_}); } + std::string ComparisonExpression::ToString() const { return EulerNotation(OperatorName(op()), {left_operand_, right_operand_}); } diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index e9e66363306..ac27774f5fe 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -135,7 +135,7 @@ class ARROW_DS_EXPORT Expression { } /// returns a debug string representing this expression - virtual std::string ToString() const { return "FIXME"; } + virtual std::string ToString() const = 0; ExpressionType::type type() const { return type_; } @@ -281,7 +281,7 @@ class ARROW_DS_EXPORT NotExpression final public: using ExpressionImpl::ExpressionImpl; - // std::string ToString() const override; + std::string ToString() const override; // Result> Validate(const Schema& schema) const override; From d54ca0600fee0839bf2129394f981628e64163a6 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 12:14:45 -0400 Subject: [PATCH 19/28] Expressions evaluate to Datums --- cpp/src/arrow/dataset/filter.cc | 154 +++++++++++++++++---------- cpp/src/arrow/dataset/filter.h | 46 ++++---- cpp/src/arrow/dataset/filter_test.cc | 38 +++++-- 3 files changed, 148 insertions(+), 90 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index dd7c8da8376..0b5471fea8e 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -38,95 +38,137 @@ using arrow::compute::Datum; using internal::checked_cast; using internal::checked_pointer_cast; -Result> ScalarExpression::Evaluate( - compute::FunctionContext* ctx, const RecordBatch& batch) const { - if (!value_->is_valid) { - std::shared_ptr mask_array; - RETURN_NOT_OK( - MakeArrayOfNull(ctx->memory_pool(), boolean(), batch.num_rows(), &mask_array)); +Result ScalarExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + return value_; +} - return checked_pointer_cast(mask_array); +Result FieldExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + auto column = batch.GetColumnByName(name_); + if (column == nullptr) { + return Datum(std::make_shared()); } + return column; +} - if (!value_->type->Equals(boolean())) { - return Status::Invalid("can't evaluate ", ToString()); +bool IsTrivialConditionDatum(const Datum& datum, BooleanScalar* condition) { + if (!datum.is_scalar()) { + return false; } - TypedBufferBuilder builder; - RETURN_NOT_OK(builder.Append(batch.num_rows(), - checked_cast(*value_).value)); + auto scalar = datum.scalar(); + if (!scalar->is_valid) { + *condition = BooleanScalar(); + return true; + } - std::shared_ptr values; - RETURN_NOT_OK(builder.Finish(&values)); + if (scalar->type->id() != Type::BOOL) { + return false; + } - return std::make_shared(batch.num_rows(), values); + *condition = checked_cast(*scalar); + return true; } -Result> NotExpression::Evaluate( - compute::FunctionContext* ctx, const RecordBatch& batch) const { +Result NotExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { ARROW_ASSIGN_OR_RAISE(auto to_invert, operand_->Evaluate(ctx, batch)); + DCHECK(to_invert.type()->Equals(boolean())); + + BooleanScalar trivial_condition; + if (IsTrivialConditionDatum(to_invert, &trivial_condition)) { + if (trivial_condition.is_valid) { + trivial_condition.value = !trivial_condition.value; + } + return Datum(std::make_shared(trivial_condition)); + } + + DCHECK(to_invert.is_array()); Datum out; RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); - return checked_pointer_cast(out.make_array()); + return out; } -template -Result> EvaluateNnary(const Nnary& nnary, - compute::FunctionContext* ctx, - const RecordBatch& batch) { - const auto& operands = nnary.operands(); +// TODO(bkietz) more reusable coallesce helper +template +bool FinishWithTrivial(const Datum& d, Datum* out) { + BooleanScalar trivial; + if (IsTrivialConditionDatum(d, &trivial)) { + if (!trivial.is_valid || trivial.value == trivial_condition) { + *out = d; + return true; + } + } + return false; +} - ARROW_ASSIGN_OR_RAISE(auto next, operands[0]->Evaluate(ctx, batch)); +Result AllExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto next, operands_[0]->Evaluate(ctx, batch)); Datum acc(next); - for (size_t i_next = 1; i_next < operands.size(); ++i_next) { - ARROW_ASSIGN_OR_RAISE(next, operands[i_next]->Evaluate(ctx, batch)); + if (FinishWithTrivial(next, &acc)) { + return acc; + } - if (std::is_same::value) { - RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); - } + for (size_t i_next = 1; i_next < operands_.size(); ++i_next) { + ARROW_ASSIGN_OR_RAISE(next, operands_[i_next]->Evaluate(ctx, batch)); - if (std::is_same::value) { - RETURN_NOT_OK(arrow::compute::Or(ctx, Datum(acc), Datum(next), &acc)); + if (FinishWithTrivial(next, &acc)) { + return acc; } - } - return checked_pointer_cast(acc.make_array()); -} + RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); + } -Result> AllExpression::Evaluate( - compute::FunctionContext* ctx, const RecordBatch& batch) const { - return EvaluateNnary(*this, ctx, batch); + return acc; } -Result> AnyExpression::Evaluate( - compute::FunctionContext* ctx, const RecordBatch& batch) const { - return EvaluateNnary(*this, ctx, batch); -} +Result AnyExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto next, operands_[0]->Evaluate(ctx, batch)); + Datum acc(next); -Result> ComparisonExpression::Evaluate( - compute::FunctionContext* ctx, const RecordBatch& batch) const { - if (left_operand_->type() != ExpressionType::FIELD) { - return Status::Invalid("left hand side of comparison must be a field reference"); + if (FinishWithTrivial(next, &acc)) { + return acc; } - if (right_operand_->type() != ExpressionType::SCALAR) { - return Status::Invalid("right hand side of comparison must be a scalar"); + for (size_t i_next = 1; i_next < operands_.size(); ++i_next) { + ARROW_ASSIGN_OR_RAISE(next, operands_[i_next]->Evaluate(ctx, batch)); + + if (FinishWithTrivial(next, &acc)) { + return acc; + } + + RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); } - const auto& lhs = checked_cast(*left_operand_); - const auto& rhs = checked_cast(*right_operand_); + return acc; +} + +Result ComparisonExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto lhs, left_operand_->Evaluate(ctx, batch)); + ARROW_ASSIGN_OR_RAISE(auto rhs, right_operand_->Evaluate(ctx, batch)); - auto lhs_array = batch.GetColumnByName(lhs.name()); - if (lhs_array == nullptr) { - // comparing a field absent from batch: return nulls - return ScalarExpression::MakeNull()->Evaluate(ctx, batch); + if (lhs.is_scalar()) { + if (!lhs.scalar()->is_valid) { + return lhs; + } + return Status::NotImplemented("comparison with scalar LHS"); + } + + if (rhs.is_scalar()) { + if (!rhs.scalar()->is_valid) { + return rhs; + } } Datum out; - RETURN_NOT_OK(arrow::compute::Compare(ctx, Datum(lhs_array), Datum(rhs.value()), - arrow::compute::CompareOptions(op_), &out)); - return checked_pointer_cast(out.make_array()); + RETURN_NOT_OK( + arrow::compute::Compare(ctx, lhs, rhs, arrow::compute::CompareOptions(op_), &out)); + return out; } std::shared_ptr ScalarExpression::Make(std::string value) { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index ac27774f5fe..4a0bd681521 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -116,8 +116,8 @@ class ARROW_DS_EXPORT Expression { /// of this expression with schema information incorporated: /// - Scalars are cast to other data types if necessary to ensure comparisons are /// between data of identical type - // virtual Result> Validate(const Schema& schema) const = - // 0; + // Result> Validate( + // const Schema& schema, std::shared_ptr* evaluated_type) const; /// Return a simplified form of this expression given some known conditions. /// For example, (a > 3).Assume(a == 5) == (true). This can be used to do less work @@ -128,11 +128,10 @@ class ARROW_DS_EXPORT Expression { return Copy(); } - // Evaluate an expression against a RecordBatch - virtual Result> Evaluate(compute::FunctionContext* ctx, - const RecordBatch& batch) const { - return Status::Invalid("can't evaluate ", ToString()); - } + // Evaluate an expression against a RecordBatch. + // Returned Datum must be of either SCALAR or ARRAY kind. + virtual Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const = 0; /// returns a debug string representing this expression virtual std::string ToString() const = 0; @@ -230,14 +229,12 @@ class ARROW_DS_EXPORT ComparisonExpression final bool Equals(const Expression& other) const override; - // Result> Validate(const Schema& schema) const override; - Result> Assume(const Expression& given) const override; compute::CompareOperator op() const { return op_; } - Result> Evaluate(compute::FunctionContext* ctx, - const RecordBatch& batch) const override; + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; private: Result> AssumeGivenComparison( @@ -253,12 +250,10 @@ class ARROW_DS_EXPORT AllExpression final std::string ToString() const override; - // Result> Validate(const Schema& schema) const override; - Result> Assume(const Expression& given) const override; - Result> Evaluate(compute::FunctionContext* ctx, - const RecordBatch& batch) const override; + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; }; class ARROW_DS_EXPORT AnyExpression final @@ -268,12 +263,10 @@ class ARROW_DS_EXPORT AnyExpression final std::string ToString() const override; - // Result> Validate(const Schema& schema) const override; - Result> Assume(const Expression& given) const override; - Result> Evaluate(compute::FunctionContext* ctx, - const RecordBatch& batch) const override; + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; }; class ARROW_DS_EXPORT NotExpression final @@ -283,12 +276,10 @@ class ARROW_DS_EXPORT NotExpression final std::string ToString() const override; - // Result> Validate(const Schema& schema) const override; - Result> Assume(const Expression& given) const override; - Result> Evaluate(compute::FunctionContext* ctx, - const RecordBatch& batch) const override; + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; }; /// Represents a scalar value; thin wrapper around arrow::Scalar @@ -303,8 +294,6 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { bool Equals(const Expression& other) const override; - // Result> Validate(const Schema& schema) const override; - static std::shared_ptr Make(bool value) { return std::make_shared(std::make_shared(value)); } @@ -335,8 +324,8 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { return std::make_shared(std::make_shared()); } - Result> Evaluate(compute::FunctionContext* ctx, - const RecordBatch& batch) const override; + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; std::shared_ptr Copy() const override; @@ -357,7 +346,8 @@ class ARROW_DS_EXPORT FieldExpression final : public Expression { bool Equals(const Expression& other) const override; - // Result> Validate(const Schema& schema) const override; + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; std::shared_ptr Copy() const override; diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 8035494516f..e45a0589eee 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -35,6 +35,7 @@ namespace arrow { namespace dataset { using string_literals::operator""_; +using internal::checked_cast; using internal::checked_pointer_cast; class ExpressionsTest : public ::testing::Test { @@ -117,9 +118,10 @@ TEST_F(ExpressionsTest, SimplificationToNull) { class FilterTest : public ::testing::Test { public: - Result> DoFilter( - const Expression& expr, std::vector> fields, - std::string batch_json, std::shared_ptr* expected_mask = nullptr) { + Result DoFilter(const Expression& expr, + std::vector> fields, + std::string batch_json, + std::shared_ptr* expected_mask = nullptr) { // expected filter result is in the "in" field fields.push_back(field("in", boolean())); @@ -137,9 +139,33 @@ class FilterTest : public ::testing::Test { void AssertFilter(const Expression& expr, std::vector> fields, std::string batch_json) { std::shared_ptr expected_mask; - auto mask = DoFilter(expr, std::move(fields), std::move(batch_json), &expected_mask); - ASSERT_OK(mask.status()); - ASSERT_ARRAYS_EQUAL(*expected_mask, *mask.ValueOrDie()); + auto mask_res = + DoFilter(expr, std::move(fields), std::move(batch_json), &expected_mask); + ASSERT_OK(mask_res.status()); + + auto mask = std::move(mask_res).ValueOrDie(); + ASSERT_TRUE(mask.type()->Equals(null()) || mask.type()->Equals(boolean())); + + if (mask.is_array()) { + ASSERT_ARRAYS_EQUAL(*expected_mask, *mask.make_array()); + return; + } + + ASSERT_TRUE(mask.is_scalar()); + auto mask_scalar = mask.scalar(); + if (!mask_scalar->is_valid) { + ASSERT_EQ(expected_mask->null_count(), expected_mask->length()); + return; + } + + TypedBufferBuilder builder; + ASSERT_OK(builder.Append(expected_mask->length(), + checked_cast(*mask_scalar).value)); + + std::shared_ptr values; + ASSERT_OK(builder.Finish(&values)); + + ASSERT_ARRAYS_EQUAL(*expected_mask, BooleanArray(expected_mask->length(), values)); } arrow::compute::FunctionContext ctx_; From 46fa3fbfe6c38398008d78d747035ec61a8c570f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 22 Aug 2019 17:00:23 -0400 Subject: [PATCH 20/28] add Expression::Validate implementations --- cpp/src/arrow/dataset/filter.cc | 76 +++++++++++++++++++++++++++++++-- cpp/src/arrow/dataset/filter.h | 23 +++++++--- 2 files changed, 89 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 0b5471fea8e..1814b32633e 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "arrow/buffer.h" @@ -253,13 +254,13 @@ struct CompareVisitor { // Compare two scalars // if either is null, return is null Result Compare(const Scalar& lhs, const Scalar& rhs) { - if (!lhs.is_valid || !rhs.is_valid) { - return Comparison::NULL_; - } if (!lhs.type->Equals(*rhs.type)) { return Status::TypeError("cannot compare scalars with differing type: ", *lhs.type, " vs ", *rhs.type); } + if (!lhs.is_valid || !rhs.is_valid) { + return Comparison::NULL_; + } CompareVisitor vis{Comparison::NULL_, lhs, rhs}; RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); return vis.result_; @@ -686,7 +687,7 @@ std::string ScalarExpression::ToString() const { value = checked_cast(*value_).value->ToString(); break; default: - value = "TODO"; + value = "TODO(bkietz)"; break; } @@ -844,5 +845,72 @@ AnyExpression operator or(const Expression& lhs, const Expression& rhs) { NotExpression operator not(const Expression& rhs) { return NotExpression(rhs.Copy()); } +Result> ComparisonExpression::Validate( + const Schema& schema) const { + if (left_operand_->type() != ExpressionType::FIELD) { + return Status::NotImplemented("comparison with non-FIELD RHS"); + } + + ARROW_ASSIGN_OR_RAISE(auto lhs_type, left_operand_->Validate(schema)); + ARROW_ASSIGN_OR_RAISE(auto rhs_type, right_operand_->Validate(schema)); + if (!lhs_type->Equals(rhs_type)) { + return Status::TypeError("cannot compare expressions of differing type, ", *lhs_type, + " vs ", *rhs_type); + } + + if (lhs_type->id() == Type::NA || rhs_type->id() == Type::NA) { + return null(); + } + + return boolean(); +} + +Status EnsureNullOrBool(const std::string& msg_prefix, + const std::shared_ptr& type) { + if (type->id() == Type::BOOL || type->id() == Type::NA) { + return Status::OK(); + } + return Status::TypeError(msg_prefix, *type); +} + +Result> ValidateNnary(const NnaryExpression& nnary, + const Schema& schema) { + auto out = boolean(); + for (const auto& operand : nnary.operands()) { + ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); + RETURN_NOT_OK( + EnsureNullOrBool("cannot combine expressions including one of type ", type)); + if (type->id() == Type::NA) { + out = null(); + } + } + return out; +} + +Result> AllExpression::Validate(const Schema& schema) const { + return ValidateNnary(*this, schema); +} + +Result> AnyExpression::Validate(const Schema& schema) const { + return ValidateNnary(*this, schema); +} + +Result> NotExpression::Validate(const Schema& schema) const { + ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); + RETURN_NOT_OK(EnsureNullOrBool("cannot invert an expression of type ", operand_type)); + return operand_type; +} + +Result> ScalarExpression::Validate(const Schema& schema) const { + return value_->type; +} + +Result> FieldExpression::Validate(const Schema& schema) const { + if (auto field = schema.GetFieldByName(name_)) { + return field->type(); + } + return null(); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 4a0bd681521..e3f07abd3da 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -112,12 +112,10 @@ class ARROW_DS_EXPORT Expression { bool Equals(const std::shared_ptr& other) const; /// Validate this expression for execution against a schema. This will check that all - /// reference fields are present and all subexpressions are executable. Returns a copy - /// of this expression with schema information incorporated: - /// - Scalars are cast to other data types if necessary to ensure comparisons are - /// between data of identical type - // Result> Validate( - // const Schema& schema, std::shared_ptr* evaluated_type) const; + /// reference fields are present (fields not in the schema will be replaced with null) + /// and all subexpressions are executable. Returns the type to which this expression + /// will evaluate. + virtual Result> Validate(const Schema& schema) const = 0; /// Return a simplified form of this expression given some known conditions. /// For example, (a > 3).Assume(a == 5) == (true). This can be used to do less work @@ -236,6 +234,8 @@ class ARROW_DS_EXPORT ComparisonExpression final Result Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const override; + Result> Validate(const Schema& schema) const override; + private: Result> AssumeGivenComparison( const ComparisonExpression& given) const; @@ -254,6 +254,8 @@ class ARROW_DS_EXPORT AllExpression final Result Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const override; + + Result> Validate(const Schema& schema) const override; }; class ARROW_DS_EXPORT AnyExpression final @@ -267,6 +269,8 @@ class ARROW_DS_EXPORT AnyExpression final Result Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const override; + + Result> Validate(const Schema& schema) const override; }; class ARROW_DS_EXPORT NotExpression final @@ -280,6 +284,8 @@ class ARROW_DS_EXPORT NotExpression final Result Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const override; + + Result> Validate(const Schema& schema) const override; }; /// Represents a scalar value; thin wrapper around arrow::Scalar @@ -298,6 +304,7 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { return std::make_shared(std::make_shared(value)); } + // FIXME(bkietz) create correct scalar type template static typename std::enable_if::value, std::shared_ptr>::type @@ -324,6 +331,8 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { return std::make_shared(std::make_shared()); } + Result> Validate(const Schema& schema) const override; + Result Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const override; @@ -346,6 +355,8 @@ class ARROW_DS_EXPORT FieldExpression final : public Expression { bool Equals(const Expression& other) const override; + Result> Validate(const Schema& schema) const override; + Result Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const override; From cab17b1a305f14b1572c2393cf2fa1827c6fc8bc Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Sun, 25 Aug 2019 12:19:46 -0400 Subject: [PATCH 21/28] amend doccomments --- cpp/src/arrow/dataset/filter.h | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index e3f07abd3da..d188a09779e 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -165,7 +165,7 @@ class ExpressionImpl : public Base { } }; -/// Represents an expression with exactly one operand; for example negation +/// Base class for an expression with exactly one operand class ARROW_DS_EXPORT UnaryExpression : public Expression { public: const std::shared_ptr& operand() const { return operand_; } @@ -179,8 +179,7 @@ class ARROW_DS_EXPORT UnaryExpression : public Expression { std::shared_ptr operand_; }; -/// Represents an expression with exactly two operands; for example a comparison of two -/// expressions +/// Base class for an expression with exactly two operands class ARROW_DS_EXPORT BinaryExpression : public Expression { public: const std::shared_ptr& left_operand() const { return left_operand_; } @@ -199,8 +198,7 @@ class ARROW_DS_EXPORT BinaryExpression : public Expression { std::shared_ptr left_operand_, right_operand_; }; -/// Represents an expression with multiple operands; for example a conjunction or -/// disjunction of other expressions +/// Base class for an expression with multiple operands class ARROW_DS_EXPORT NnaryExpression : public Expression { public: const ExpressionVector& operands() const { return operands_; } From 3ac299ba6710b405e902362b683c5e3a62b37669 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Sun, 25 Aug 2019 12:42:38 -0400 Subject: [PATCH 22/28] use strongly typed nulls --- cpp/src/arrow/dataset/filter.cc | 15 +++++++++++---- cpp/src/arrow/dataset/filter.h | 5 ++--- cpp/src/arrow/dataset/filter_test.cc | 13 +++++++------ cpp/src/arrow/scalar.cc | 26 ++++++++++++++++++++++++++ cpp/src/arrow/scalar.h | 6 ++++++ 5 files changed, 52 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 1814b32633e..719c356b2fb 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -182,6 +182,13 @@ std::shared_ptr ScalarExpression::Make(const char* value) { std::make_shared(Buffer::Wrap(value, std::strlen(value)))); } +std::shared_ptr ScalarExpression::MakeNull( + const std::shared_ptr& type) { + std::shared_ptr null; + DCHECK_OK(arrow::MakeNull(type, &null)); + return Make(std::move(null)); +} + struct Comparison { enum type { LESS, @@ -358,7 +365,7 @@ Result> ComparisonExpression::Assume( if (!scalar.is_valid) { // some subexpression of given is always null, return null - return ScalarExpression::MakeNull(); + return ScalarExpression::MakeNull(boolean()); } if (scalar.value == true) { @@ -430,7 +437,7 @@ Result> ComparisonExpression::AssumeGivenComparison( if (cmp == Comparison::NULL_) { // the RHS of e or given was null - return ScalarExpression::MakeNull(); + return ScalarExpression::MakeNull(boolean()); } static auto always = ScalarExpression::Make(true); @@ -590,7 +597,7 @@ Result> AssumeNnary(const Nnary& nnary, BooleanScalar scalar; if (operand->IsTrivialCondition(&scalar)) { if (!scalar.is_valid) { - return ScalarExpression::MakeNull(); + return ScalarExpression::MakeNull(boolean()); } if (scalar.value == trivial_condition) { @@ -636,7 +643,7 @@ Result> NotExpression::Assume(const Expression& give } if (!scalar.is_valid) { - return ScalarExpression::MakeNull(); + return ScalarExpression::MakeNull(boolean()); } return ScalarExpression::Make(!scalar.value); diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index d188a09779e..3206ea8d7f8 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -325,9 +325,8 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { return std::make_shared(std::move(value)); } - static std::shared_ptr MakeNull() { - return std::make_shared(std::make_shared()); - } + static std::shared_ptr MakeNull( + const std::shared_ptr& type); Result> Validate(const Schema& schema) const override; diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index e45a0589eee..46d716973d3 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -108,12 +108,13 @@ TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { } TEST_F(ExpressionsTest, SimplificationToNull) { - auto null = ScalarExpression::MakeNull(); + auto null = ScalarExpression::MakeNull(boolean()); + auto null64 = ScalarExpression::MakeNull(int64()); - AssertSimplifiesTo(*equal(fieldRef("b"), null), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(fieldRef("b"), null), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(fieldRef("b"), null) and "b"_ > 3, "b"_ == 3, *null); - AssertSimplifiesTo("b"_ > 3 and *not_equal(fieldRef("b"), null), "b"_ == 3, *null); + AssertSimplifiesTo(*equal(fieldRef("b"), null64), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(fieldRef("b"), null64), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(fieldRef("b"), null64) and "b"_ > 3, "b"_ == 3, *null); + AssertSimplifiesTo("b"_ > 3 and *not_equal(fieldRef("b"), null64), "b"_ == 3, *null); } class FilterTest : public ::testing::Test { @@ -192,7 +193,7 @@ TEST_F(FilterTest, Trivial) { {"a": 0, "b": 1.0, "in": 0} ])"); - AssertFilter(*ScalarExpression::MakeNull(), + AssertFilter(*ScalarExpression::MakeNull(boolean()), {field("a", int64()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": null}, {"a": 0, "b": 0.3, "in": null}, diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 7c3a4ee9a4b..df0cca247ce 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -26,6 +26,7 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" #include "arrow/util/logging.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -115,4 +116,29 @@ FixedSizeListScalar::FixedSizeListScalar(const std::shared_ptr& value, bool is_valid) : FixedSizeListScalar(value, value->type(), is_valid) {} +struct MakeNullImpl { + template + using ScalarType = typename TypeTraits::ScalarType; + + template + typename std::enable_if>::value, + Status>::type + Visit(const T&) { + *out_ = std::make_shared>(); + return Status::OK(); + } + + Status Visit(const DataType& t) { + return Status::NotImplemented("construcing null scalars of type ", t); + } + + std::shared_ptr type_; + std::shared_ptr* out_; +}; + +Status MakeNull(const std::shared_ptr& type, std::shared_ptr* null) { + MakeNullImpl impl = {type, null}; + return VisitTypeInline(*type, &impl); +} + } // namespace arrow diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 39a38f4f883..db17c6c164f 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -246,4 +246,10 @@ class ARROW_EXPORT UnionScalar : public Scalar {}; class ARROW_EXPORT DictionaryScalar : public Scalar {}; class ARROW_EXPORT ExtensionScalar : public Scalar {}; +/// \param[in] type the type of scalar to produce +/// \param[out] null output scalar with is_valid=false +/// \return Status +ARROW_EXPORT +Status MakeNull(const std::shared_ptr& type, std::shared_ptr* null); + } // namespace arrow From 7c01f761940f328dfd7adfbb39efb919298c2995 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Sun, 25 Aug 2019 12:52:37 -0400 Subject: [PATCH 23/28] construct correct scalartype --- cpp/src/arrow/compute/kernels/compare.cc | 5 ++++- cpp/src/arrow/dataset/filter.cc | 2 +- cpp/src/arrow/dataset/filter.h | 14 ++++---------- cpp/src/arrow/dataset/filter_test.cc | 22 +++++++++++----------- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/compare.cc b/cpp/src/arrow/compute/kernels/compare.cc index 0ed785f26f9..386e3b98e59 100644 --- a/cpp/src/arrow/compute/kernels/compare.cc +++ b/cpp/src/arrow/compute/kernels/compare.cc @@ -244,7 +244,10 @@ Status Compare(FunctionContext* context, const Datum& left, const Datum& right, DCHECK(out); auto type = left.type(); - DCHECK(type->Equals(right.type())); + if (!type->Equals(right.type())) { + return Status::TypeError("Cannot compare data of differing type ", *type, " vs ", + *right.type()); + } // Requires that both types are equal. auto fn = MakeCompareFunction(context, *type, options); if (fn == nullptr) { diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 719c356b2fb..572eb0028db 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -262,7 +262,7 @@ struct CompareVisitor { // if either is null, return is null Result Compare(const Scalar& lhs, const Scalar& rhs) { if (!lhs.type->Equals(*rhs.type)) { - return Status::TypeError("cannot compare scalars with differing type: ", *lhs.type, + return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, " vs ", *rhs.type); } if (!lhs.is_valid || !rhs.is_valid) { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 3206ea8d7f8..83830752598 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -302,19 +302,13 @@ class ARROW_DS_EXPORT ScalarExpression final : public Expression { return std::make_shared(std::make_shared(value)); } - // FIXME(bkietz) create correct scalar type template - static typename std::enable_if::value, + static typename std::enable_if::value || + std::is_floating_point::value, std::shared_ptr>::type Make(T value) { - return std::make_shared(std::make_shared(value)); - } - - template - static typename std::enable_if::value, - std::shared_ptr>::type - Make(T value) { - return std::make_shared(std::make_shared(value)); + using ScalarType = typename CTypeTraits::ScalarType; + return std::make_shared(std::make_shared(value)); } static std::shared_ptr Make(std::string value); diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 46d716973d3..12dd4da47fd 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -109,12 +109,12 @@ TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { TEST_F(ExpressionsTest, SimplificationToNull) { auto null = ScalarExpression::MakeNull(boolean()); - auto null64 = ScalarExpression::MakeNull(int64()); + auto null32 = ScalarExpression::MakeNull(int32()); - AssertSimplifiesTo(*equal(fieldRef("b"), null64), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(fieldRef("b"), null64), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(fieldRef("b"), null64) and "b"_ > 3, "b"_ == 3, *null); - AssertSimplifiesTo("b"_ > 3 and *not_equal(fieldRef("b"), null64), "b"_ == 3, *null); + AssertSimplifiesTo(*equal(fieldRef("b"), null32), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(fieldRef("b"), null32), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(fieldRef("b"), null32) and "b"_ > 3, "b"_ == 3, *null); + AssertSimplifiesTo("b"_ > 3 and *not_equal(fieldRef("b"), null32), "b"_ == 3, *null); } class FilterTest : public ::testing::Test { @@ -173,7 +173,7 @@ class FilterTest : public ::testing::Test { }; TEST_F(FilterTest, Trivial) { - AssertFilter(*scalar(true), {field("a", int64()), field("b", float64())}, R"([ + AssertFilter(*scalar(true), {field("a", int32()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": 1}, {"a": 0, "b": 0.3, "in": 1}, {"a": 1, "b": 0.2, "in": 1}, @@ -183,7 +183,7 @@ TEST_F(FilterTest, Trivial) { {"a": 0, "b": 1.0, "in": 1} ])"); - AssertFilter(*scalar(false), {field("a", int64()), field("b", float64())}, R"([ + AssertFilter(*scalar(false), {field("a", int32()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": 0}, {"a": 0, "b": 0.3, "in": 0}, {"a": 1, "b": 0.2, "in": 0}, @@ -194,7 +194,7 @@ TEST_F(FilterTest, Trivial) { ])"); AssertFilter(*ScalarExpression::MakeNull(boolean()), - {field("a", int64()), field("b", float64())}, R"([ + {field("a", int32()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": null}, {"a": 0, "b": 0.3, "in": null}, {"a": 1, "b": 0.2, "in": null}, @@ -207,7 +207,7 @@ TEST_F(FilterTest, Trivial) { TEST_F(FilterTest, Basics) { AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0, - {field("a", int64()), field("b", float64())}, R"([ + {field("a", int32()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": 0}, {"a": 0, "b": 0.3, "in": 1}, {"a": 1, "b": 0.2, "in": 0}, @@ -217,7 +217,7 @@ TEST_F(FilterTest, Basics) { {"a": 0, "b": 1.0, "in": 0} ])"); - AssertFilter("a"_ != 0 and "b"_ > 0.1, {field("a", int64()), field("b", float64())}, + AssertFilter("a"_ != 0 and "b"_ > 0.1, {field("a", int32()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": 0}, {"a": 0, "b": 0.3, "in": 0}, @@ -231,7 +231,7 @@ TEST_F(FilterTest, Basics) { TEST_F(FilterTest, ConditionOnAbsentColumn) { AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0 and "absent"_ == 0, - {field("a", int64()), field("b", float64())}, R"([ + {field("a", int32()), field("b", float64())}, R"([ {"a": 0, "b": -0.1, "in": null}, {"a": 0, "b": 0.3, "in": null}, {"a": 1, "b": 0.2, "in": null}, From ca155c6d9514f88c10957ac56015bc853eedb0ef Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 26 Aug 2019 11:34:21 -0400 Subject: [PATCH 24/28] add explicit std::move, msvc doesn't like defining operator and --- cpp/src/arrow/dataset/filter.cc | 32 ++++++++++++++++---------------- cpp/src/arrow/dataset/filter.h | 6 +++--- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 572eb0028db..bfcf5955f5e 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -50,7 +50,7 @@ Result FieldExpression::Evaluate(compute::FunctionContext* ctx, if (column == nullptr) { return Datum(std::make_shared()); } - return column; + return std::move(column); } bool IsTrivialConditionDatum(const Datum& datum, BooleanScalar* condition) { @@ -88,7 +88,7 @@ Result NotExpression::Evaluate(compute::FunctionContext* ctx, DCHECK(to_invert.is_array()); Datum out; RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); - return out; + return std::move(out); } // TODO(bkietz) more reusable coallesce helper @@ -110,20 +110,20 @@ Result AllExpression::Evaluate(compute::FunctionContext* ctx, Datum acc(next); if (FinishWithTrivial(next, &acc)) { - return acc; + return std::move(acc); } for (size_t i_next = 1; i_next < operands_.size(); ++i_next) { ARROW_ASSIGN_OR_RAISE(next, operands_[i_next]->Evaluate(ctx, batch)); if (FinishWithTrivial(next, &acc)) { - return acc; + return std::move(acc); } RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); } - return acc; + return std::move(acc); } Result AnyExpression::Evaluate(compute::FunctionContext* ctx, @@ -132,20 +132,20 @@ Result AnyExpression::Evaluate(compute::FunctionContext* ctx, Datum acc(next); if (FinishWithTrivial(next, &acc)) { - return acc; + return std::move(acc); } for (size_t i_next = 1; i_next < operands_.size(); ++i_next) { ARROW_ASSIGN_OR_RAISE(next, operands_[i_next]->Evaluate(ctx, batch)); if (FinishWithTrivial(next, &acc)) { - return acc; + return std::move(acc); } RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); } - return acc; + return std::move(acc); } Result ComparisonExpression::Evaluate(compute::FunctionContext* ctx, @@ -155,21 +155,21 @@ Result ComparisonExpression::Evaluate(compute::FunctionContext* ctx, if (lhs.is_scalar()) { if (!lhs.scalar()->is_valid) { - return lhs; + return std::move(lhs); } return Status::NotImplemented("comparison with scalar LHS"); } if (rhs.is_scalar()) { if (!rhs.scalar()->is_valid) { - return rhs; + return std::move(rhs); } } Datum out; RETURN_NOT_OK( arrow::compute::Compare(ctx, lhs, rhs, arrow::compute::CompareOptions(op_), &out)); - return out; + return std::move(out); } std::shared_ptr ScalarExpression::Make(std::string value) { @@ -397,7 +397,7 @@ Result> ComparisonExpression::Assume( ARROW_ASSIGN_OR_RAISE(simplified, simplified->Assume(*operand)); } - return simplified; + return std::move(simplified); } default: @@ -842,15 +842,15 @@ Out MaybeCombine(const Expression& lhs, const Expression& rhs) { return Out(std::move(operands)); } -AllExpression operator and(const Expression& lhs, const Expression& rhs) { +AllExpression operator&&(const Expression& lhs, const Expression& rhs) { return MaybeCombine(lhs, rhs); } -AnyExpression operator or(const Expression& lhs, const Expression& rhs) { +AnyExpression operator||(const Expression& lhs, const Expression& rhs) { return MaybeCombine(lhs, rhs); } -NotExpression operator not(const Expression& rhs) { return NotExpression(rhs.Copy()); } +NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } Result> ComparisonExpression::Validate( const Schema& schema) const { @@ -905,7 +905,7 @@ Result> AnyExpression::Validate(const Schema& schema) Result> NotExpression::Validate(const Schema& schema) const { ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); RETURN_NOT_OK(EnsureNullOrBool("cannot invert an expression of type ", operand_type)); - return operand_type; + return std::move(operand_type); } Result> ScalarExpression::Validate(const Schema& schema) const { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 83830752598..41e53abffaf 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -359,15 +359,15 @@ class ARROW_DS_EXPORT FieldExpression final : public Expression { ARROW_DS_EXPORT std::shared_ptr all(ExpressionVector operands); -ARROW_DS_EXPORT AllExpression operator and(const Expression& lhs, const Expression& rhs); +ARROW_DS_EXPORT AllExpression operator&&(const Expression& lhs, const Expression& rhs); ARROW_DS_EXPORT std::shared_ptr any(ExpressionVector operands); -ARROW_DS_EXPORT AnyExpression operator or(const Expression& lhs, const Expression& rhs); +ARROW_DS_EXPORT AnyExpression operator||(const Expression& lhs, const Expression& rhs); ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); -ARROW_DS_EXPORT NotExpression operator not(const Expression& rhs); +ARROW_DS_EXPORT NotExpression operator!(const Expression& rhs); #define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ inline std::shared_ptr FACTORY_NAME( \ From 0e366bb16308a6317c0231b3bf065fb2424fc568 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 28 Aug 2019 14:53:00 -0400 Subject: [PATCH 25/28] rename all/any to and/or --- cpp/src/arrow/dataset/filter.cc | 60 ++++++++++++++-------------- cpp/src/arrow/dataset/filter.h | 28 ++++++------- cpp/src/arrow/dataset/filter_test.cc | 2 +- cpp/src/arrow/scalar.cc | 3 +- cpp/src/arrow/scalar.h | 3 +- 5 files changed, 49 insertions(+), 47 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index bfcf5955f5e..26464e93922 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -104,7 +104,7 @@ bool FinishWithTrivial(const Datum& d, Datum* out) { return false; } -Result AllExpression::Evaluate(compute::FunctionContext* ctx, +Result AndExpression::Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const { ARROW_ASSIGN_OR_RAISE(auto next, operands_[0]->Evaluate(ctx, batch)); Datum acc(next); @@ -126,8 +126,8 @@ Result AllExpression::Evaluate(compute::FunctionContext* ctx, return std::move(acc); } -Result AnyExpression::Evaluate(compute::FunctionContext* ctx, - const RecordBatch& batch) const { +Result OrExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { ARROW_ASSIGN_OR_RAISE(auto next, operands_[0]->Evaluate(ctx, batch)); Datum acc(next); @@ -185,7 +185,7 @@ std::shared_ptr ScalarExpression::Make(const char* value) { std::shared_ptr ScalarExpression::MakeNull( const std::shared_ptr& type) { std::shared_ptr null; - DCHECK_OK(arrow::MakeNull(type, &null)); + DCHECK_OK(arrow::MakeNullScalar(type, &null)); return Make(std::move(null)); } @@ -312,18 +312,18 @@ Result> Invert(const Expression& op) { case ExpressionType::NOT: return checked_cast(op).operand(); - case ExpressionType::ALL: - case ExpressionType::ANY: { + case ExpressionType::AND: + case ExpressionType::OR: { ExpressionVector inverted_operands; for (auto operand : checked_cast(op).operands()) { ARROW_ASSIGN_OR_RAISE(auto inverted_operand, Invert(*operand)); inverted_operands.push_back(inverted_operand); } - if (op.type() == ExpressionType::ALL) { - return std::make_shared(std::move(inverted_operands)); + if (op.type() == ExpressionType::AND) { + return std::make_shared(std::move(inverted_operands)); } - return std::make_shared(std::move(inverted_operands)); + return std::make_shared(std::move(inverted_operands)); } case ExpressionType::COMPARISON: @@ -351,10 +351,10 @@ Result> ComparisonExpression::Assume( return Assume(*inverted.ValueOrDie()); } - case ExpressionType::ANY: { + case ExpressionType::OR: { bool simplify_to_always = true; bool simplify_to_never = true; - for (const auto& operand : checked_cast(given).operands()) { + for (const auto& operand : checked_cast(given).operands()) { ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); BooleanScalar scalar; @@ -386,9 +386,9 @@ Result> ComparisonExpression::Assume( return Copy(); } - case ExpressionType::ALL: { + case ExpressionType::AND: { auto simplified = Copy(); - for (const auto& operand : checked_cast(given).operands()) { + for (const auto& operand : checked_cast(given).operands()) { BooleanScalar value; if (simplified->IsTrivialCondition(&value)) { // FIXME(bkietz) but what if something later is null? @@ -585,9 +585,9 @@ Result> AssumeNnary(const Nnary& nnary, const Expression& given) { // if any of the operands matches trivial_condition, we can return a trivial // expression: - // anything ANY true => true - // anything ALL false => false - constexpr bool trivial_condition = std::is_same::value; + // anything OR true => true + // anything AND false => false + constexpr bool trivial_condition = std::is_same::value; bool simplify_to_trivial = false; ExpressionVector operands; @@ -626,11 +626,11 @@ Result> AssumeNnary(const Nnary& nnary, return std::make_shared(std::move(operands)); } -Result> AllExpression::Assume(const Expression& given) const { +Result> AndExpression::Assume(const Expression& given) const { return AssumeNnary(*this, given); } -Result> AnyExpression::Assume(const Expression& given) const { +Result> OrExpression::Assume(const Expression& given) const { return AssumeNnary(*this, given); } @@ -716,9 +716,9 @@ static std::string EulerNotation(std::string fn, const ExpressionVector& operand return fn; } -std::string AllExpression::ToString() const { return EulerNotation("ALL", operands_); } +std::string AndExpression::ToString() const { return EulerNotation("ALL", operands_); } -std::string AnyExpression::ToString() const { return EulerNotation("ANY", operands_); } +std::string OrExpression::ToString() const { return EulerNotation("ANY", operands_); } std::string NotExpression::ToString() const { return EulerNotation("NOT", {operand_}); } @@ -808,12 +808,12 @@ std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } -std::shared_ptr all(ExpressionVector operands) { - return std::make_shared(std::move(operands)); +std::shared_ptr and_(ExpressionVector operands) { + return std::make_shared(std::move(operands)); } -std::shared_ptr any(ExpressionVector operands) { - return std::make_shared(std::move(operands)); +std::shared_ptr or_(ExpressionVector operands) { + return std::make_shared(std::move(operands)); } std::shared_ptr not_(std::shared_ptr operand) { @@ -842,12 +842,12 @@ Out MaybeCombine(const Expression& lhs, const Expression& rhs) { return Out(std::move(operands)); } -AllExpression operator&&(const Expression& lhs, const Expression& rhs) { - return MaybeCombine(lhs, rhs); +AndExpression operator&&(const Expression& lhs, const Expression& rhs) { + return MaybeCombine(lhs, rhs); } -AnyExpression operator||(const Expression& lhs, const Expression& rhs) { - return MaybeCombine(lhs, rhs); +OrExpression operator||(const Expression& lhs, const Expression& rhs) { + return MaybeCombine(lhs, rhs); } NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } @@ -894,11 +894,11 @@ Result> ValidateNnary(const NnaryExpression& nnary, return out; } -Result> AllExpression::Validate(const Schema& schema) const { +Result> AndExpression::Validate(const Schema& schema) const { return ValidateNnary(*this, schema); } -Result> AnyExpression::Validate(const Schema& schema) const { +Result> OrExpression::Validate(const Schema& schema) const { return ValidateNnary(*this, schema); } diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 41e53abffaf..7840833c8da 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -33,7 +33,7 @@ namespace dataset { struct FilterType { enum type { /// Simple boolean predicate consisting of comparisons and boolean - /// logic (ALL, ANY, NOT) involving Schema fields + /// logic (ALL, OR, NOT) involving Schema fields EXPRESSION, /// Non decomposable filter; must be evaluated against every record batch @@ -54,7 +54,7 @@ class ARROW_DS_EXPORT Filter { }; /// Filter subclass encapsulating a simple boolean predicate consisting of comparisons -/// and boolean logic (ALL, ANY, NOT) involving Schema fields +/// and boolean logic (ALL, OR, NOT) involving Schema fields class ARROW_DS_EXPORT ExpressionFilter : public Filter { public: explicit ExpressionFilter(const std::shared_ptr& expression) @@ -84,10 +84,10 @@ struct ExpressionType { // TODO(bkietz) CAST, /// a conjunction of multiple expressions (true if all operands are true) - ALL, + AND, /// a disjunction of multiple expressions (true if any operand is true) - ANY, + OR, /// a comparison of two other expressions COMPARISON, @@ -241,8 +241,8 @@ class ARROW_DS_EXPORT ComparisonExpression final compute::CompareOperator op_; }; -class ARROW_DS_EXPORT AllExpression final - : public ExpressionImpl { +class ARROW_DS_EXPORT AndExpression final + : public ExpressionImpl { public: using ExpressionImpl::ExpressionImpl; @@ -256,8 +256,8 @@ class ARROW_DS_EXPORT AllExpression final Result> Validate(const Schema& schema) const override; }; -class ARROW_DS_EXPORT AnyExpression final - : public ExpressionImpl { +class ARROW_DS_EXPORT OrExpression final + : public ExpressionImpl { public: using ExpressionImpl::ExpressionImpl; @@ -357,13 +357,13 @@ class ARROW_DS_EXPORT FieldExpression final : public Expression { std::string name_; }; -ARROW_DS_EXPORT std::shared_ptr all(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr and_(ExpressionVector operands); -ARROW_DS_EXPORT AllExpression operator&&(const Expression& lhs, const Expression& rhs); +ARROW_DS_EXPORT AndExpression operator&&(const Expression& lhs, const Expression& rhs); -ARROW_DS_EXPORT std::shared_ptr any(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr or_(ExpressionVector operands); -ARROW_DS_EXPORT AnyExpression operator||(const Expression& lhs, const Expression& rhs); +ARROW_DS_EXPORT OrExpression operator||(const Expression& lhs, const Expression& rhs); ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); @@ -414,8 +414,8 @@ Status Expression::Accept(Visitor&& visitor) const { switch (type_) { EXPRESSION_VISIT_CASE(FIELD, Field); EXPRESSION_VISIT_CASE(SCALAR, Scalar); - EXPRESSION_VISIT_CASE(ALL, All); - EXPRESSION_VISIT_CASE(ANY, Any); + EXPRESSION_VISIT_CASE(AND, And); + EXPRESSION_VISIT_CASE(OR, Or); EXPRESSION_VISIT_CASE(NOT, Not); EXPRESSION_VISIT_CASE(COMPARISON, Comparison); default: diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 12dd4da47fd..2b76ae2c6e4 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -83,7 +83,7 @@ TEST_F(ExpressionsTest, Equality) { TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { // chained "and" expressions are flattened auto multi_and = "b"_ > 5 and "b"_ < 10 and "b"_ != 7; - AssertOperandsAre(multi_and, ExpressionType::ALL, "b"_ > 5, "b"_ < 10, "b"_ != 7); + AssertOperandsAre(multi_and, ExpressionType::AND, "b"_ > 5, "b"_ < 10, "b"_ != 7); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index df0cca247ce..10019906df1 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -136,7 +136,8 @@ struct MakeNullImpl { std::shared_ptr* out_; }; -Status MakeNull(const std::shared_ptr& type, std::shared_ptr* null) { +Status MakeNullScalar(const std::shared_ptr& type, + std::shared_ptr* null) { MakeNullImpl impl = {type, null}; return VisitTypeInline(*type, &impl); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index db17c6c164f..4c590ef6a42 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -250,6 +250,7 @@ class ARROW_EXPORT ExtensionScalar : public Scalar {}; /// \param[out] null output scalar with is_valid=false /// \return Status ARROW_EXPORT -Status MakeNull(const std::shared_ptr& type, std::shared_ptr* null); +Status MakeNullScalar(const std::shared_ptr& type, + std::shared_ptr* null); } // namespace arrow From 9494bab85afa6ca0770f9638d550aa22b075692d Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Sun, 1 Sep 2019 20:05:36 -0400 Subject: [PATCH 26/28] refactor And, Or to binary --- cpp/src/arrow/dataset/filter.cc | 425 ++++++++++++++------------- cpp/src/arrow/dataset/filter.h | 65 ++-- cpp/src/arrow/dataset/filter_test.cc | 15 +- 3 files changed, 244 insertions(+), 261 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 26464e93922..7471299fcc7 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -44,45 +44,57 @@ Result ScalarExpression::Evaluate(compute::FunctionContext* ctx, return value_; } +Datum NullDatum() { return Datum(std::make_shared()); } + Result FieldExpression::Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const { auto column = batch.GetColumnByName(name_); if (column == nullptr) { - return Datum(std::make_shared()); + return NullDatum(); } return std::move(column); } -bool IsTrivialConditionDatum(const Datum& datum, BooleanScalar* condition) { +bool IsNullDatum(const Datum& datum) { + if (datum.is_scalar()) { + auto scalar = datum.scalar(); + return !scalar->is_valid; + } + + auto array_data = datum.array(); + return array_data->GetNullCount() == array_data->length; +} + +bool IsTrivialConditionDatum(const Datum& datum, bool* condition = nullptr) { if (!datum.is_scalar()) { return false; } auto scalar = datum.scalar(); if (!scalar->is_valid) { - *condition = BooleanScalar(); - return true; + return false; } if (scalar->type->id() != Type::BOOL) { return false; } - *condition = checked_cast(*scalar); + if (condition) { + *condition = checked_cast(*scalar).value; + } return true; } Result NotExpression::Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const { ARROW_ASSIGN_OR_RAISE(auto to_invert, operand_->Evaluate(ctx, batch)); - DCHECK(to_invert.type()->Equals(boolean())); + if (IsNullDatum(to_invert)) { + return NullDatum(); + } - BooleanScalar trivial_condition; + bool trivial_condition; if (IsTrivialConditionDatum(to_invert, &trivial_condition)) { - if (trivial_condition.is_valid) { - trivial_condition.value = !trivial_condition.value; - } - return Datum(std::make_shared(trivial_condition)); + return Datum(std::make_shared(!trivial_condition)); } DCHECK(to_invert.is_array()); @@ -91,61 +103,68 @@ Result NotExpression::Evaluate(compute::FunctionContext* ctx, return std::move(out); } -// TODO(bkietz) more reusable coallesce helper -template -bool FinishWithTrivial(const Datum& d, Datum* out) { - BooleanScalar trivial; - if (IsTrivialConditionDatum(d, &trivial)) { - if (!trivial.is_valid || trivial.value == trivial_condition) { - *out = d; - return true; - } - } - return false; -} - Result AndExpression::Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const { - ARROW_ASSIGN_OR_RAISE(auto next, operands_[0]->Evaluate(ctx, batch)); - Datum acc(next); + ARROW_ASSIGN_OR_RAISE(auto lhs, left_operand_->Evaluate(ctx, batch)); + ARROW_ASSIGN_OR_RAISE(auto rhs, right_operand_->Evaluate(ctx, batch)); - if (FinishWithTrivial(next, &acc)) { - return std::move(acc); + if (IsNullDatum(lhs) || IsNullDatum(rhs)) { + return NullDatum(); } - for (size_t i_next = 1; i_next < operands_.size(); ++i_next) { - ARROW_ASSIGN_OR_RAISE(next, operands_[i_next]->Evaluate(ctx, batch)); + if (lhs.is_array() && rhs.is_array()) { + Datum out; + RETURN_NOT_OK(arrow::compute::And(ctx, lhs, rhs, &out)); + return std::move(out); + } - if (FinishWithTrivial(next, &acc)) { - return std::move(acc); - } + if (lhs.is_scalar() && rhs.is_scalar()) { + return Datum(checked_cast(*lhs.scalar()).value && + checked_cast(*rhs.scalar()).value); + } + + auto array_operand = (lhs.is_array() ? lhs : rhs).make_array(); + bool scalar_operand = + checked_cast(*(lhs.is_scalar() ? lhs : rhs).scalar()).value; - RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); + if (!scalar_operand) { + // FIXME(bkietz) this is an error if array_operand contains nulls + return Datum(false); } - return std::move(acc); + return Datum(array_operand); } Result OrExpression::Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const { - ARROW_ASSIGN_OR_RAISE(auto next, operands_[0]->Evaluate(ctx, batch)); - Datum acc(next); + ARROW_ASSIGN_OR_RAISE(auto lhs, left_operand_->Evaluate(ctx, batch)); + ARROW_ASSIGN_OR_RAISE(auto rhs, right_operand_->Evaluate(ctx, batch)); - if (FinishWithTrivial(next, &acc)) { - return std::move(acc); + if (IsNullDatum(lhs) || IsNullDatum(rhs)) { + return NullDatum(); } - for (size_t i_next = 1; i_next < operands_.size(); ++i_next) { - ARROW_ASSIGN_OR_RAISE(next, operands_[i_next]->Evaluate(ctx, batch)); + if (lhs.is_array() && rhs.is_array()) { + Datum out; + RETURN_NOT_OK(arrow::compute::Or(ctx, lhs, rhs, &out)); + return std::move(out); + } - if (FinishWithTrivial(next, &acc)) { - return std::move(acc); - } + if (lhs.is_scalar() && rhs.is_scalar()) { + return Datum(checked_cast(*lhs.scalar()).value && + checked_cast(*rhs.scalar()).value); + } + + auto array_operand = (lhs.is_array() ? lhs : rhs).make_array(); + bool scalar_operand = + checked_cast(*(lhs.is_scalar() ? lhs : rhs).scalar()).value; - RETURN_NOT_OK(arrow::compute::And(ctx, Datum(acc), Datum(next), &acc)); + if (!scalar_operand) { + // FIXME(bkietz) this is an error if array_operand contains nulls + return Datum(true); } - return std::move(acc); + return Datum(array_operand); } Result ComparisonExpression::Evaluate(compute::FunctionContext* ctx, @@ -153,17 +172,12 @@ Result ComparisonExpression::Evaluate(compute::FunctionContext* ctx, ARROW_ASSIGN_OR_RAISE(auto lhs, left_operand_->Evaluate(ctx, batch)); ARROW_ASSIGN_OR_RAISE(auto rhs, right_operand_->Evaluate(ctx, batch)); - if (lhs.is_scalar()) { - if (!lhs.scalar()->is_valid) { - return std::move(lhs); - } - return Status::NotImplemented("comparison with scalar LHS"); + if (IsNullDatum(lhs) || IsNullDatum(rhs)) { + return NullDatum(); } - if (rhs.is_scalar()) { - if (!rhs.scalar()->is_valid) { - return std::move(rhs); - } + if (lhs.is_scalar()) { + return Status::NotImplemented("comparison with scalar LHS"); } Datum out; @@ -273,61 +287,67 @@ Result Compare(const Scalar& lhs, const Scalar& rhs) { return vis.result_; } -std::shared_ptr Invert(const ComparisonExpression& comparison) { +compute::CompareOperator InvertCompareOperator(compute::CompareOperator op) { using compute::CompareOperator; - auto make_opposite = [&](CompareOperator opposite) { - return std::make_shared(opposite, comparison.left_operand(), - comparison.right_operand()); - }; - switch (comparison.op()) { + switch (op) { case CompareOperator::EQUAL: - return make_opposite(CompareOperator::NOT_EQUAL); + return CompareOperator::NOT_EQUAL; case CompareOperator::NOT_EQUAL: - return make_opposite(CompareOperator::EQUAL); + return CompareOperator::EQUAL; case CompareOperator::GREATER: - return make_opposite(CompareOperator::LESS_EQUAL); + return CompareOperator::LESS_EQUAL; case CompareOperator::GREATER_EQUAL: - return make_opposite(CompareOperator::LESS); + return CompareOperator::LESS; case CompareOperator::LESS: - return make_opposite(CompareOperator::GREATER_EQUAL); + return CompareOperator::GREATER_EQUAL; case CompareOperator::LESS_EQUAL: - return make_opposite(CompareOperator::GREATER); + return CompareOperator::GREATER; default: break; } DCHECK(false); - return nullptr; + return CompareOperator::EQUAL; } -Result> Invert(const Expression& op) { - switch (op.type()) { +template +Result> InvertBoolean(const Boolean& expr) { + ARROW_ASSIGN_OR_RAISE(auto lhs, Invert(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(auto rhs, Invert(*expr.right_operand())); + + if (std::is_same::value) { + return std::make_shared(std::move(lhs), std::move(rhs)); + } + + if (std::is_same::value) { + return std::make_shared(std::move(lhs), std::move(rhs)); + } +} + +Result> Invert(const Expression& expr) { + switch (expr.type()) { case ExpressionType::NOT: - return checked_cast(op).operand(); + return checked_cast(expr).operand(); case ExpressionType::AND: - case ExpressionType::OR: { - ExpressionVector inverted_operands; - for (auto operand : checked_cast(op).operands()) { - ARROW_ASSIGN_OR_RAISE(auto inverted_operand, Invert(*operand)); - inverted_operands.push_back(inverted_operand); - } + return InvertBoolean(checked_cast(expr)); - if (op.type() == ExpressionType::AND) { - return std::make_shared(std::move(inverted_operands)); - } - return std::make_shared(std::move(inverted_operands)); - } + case ExpressionType::OR: + return InvertBoolean(checked_cast(expr)); - case ExpressionType::COMPARISON: - return Invert(checked_cast(op)); + case ExpressionType::COMPARISON: { + const auto& comparison = checked_cast(expr); + auto inverted_op = InvertCompareOperator(comparison.op()); + return std::make_shared( + inverted_op, comparison.left_operand(), comparison.right_operand()); + } default: break; @@ -352,23 +372,28 @@ Result> ComparisonExpression::Assume( } case ExpressionType::OR: { + const auto& given_or = checked_cast(given); + bool simplify_to_always = true; bool simplify_to_never = true; - for (const auto& operand : checked_cast(given).operands()) { + for (const auto& operand : {given_or.left_operand(), given_or.right_operand()}) { ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); - BooleanScalar scalar; - if (!simplified->IsTrivialCondition(&scalar)) { - simplify_to_never = false; - simplify_to_always = false; - } - - if (!scalar.is_valid) { + if (simplified->IsNull()) { // some subexpression of given is always null, return null return ScalarExpression::MakeNull(boolean()); } - if (scalar.value == true) { + bool trivial; + if (!simplified->IsTrivialCondition(&trivial)) { + // Or cannot be simplified unless all of its operands simplify to the same + // trivial condition + simplify_to_never = false; + simplify_to_always = false; + continue; + } + + if (trivial == true) { simplify_to_never = false; } else { simplify_to_always = false; @@ -387,12 +412,16 @@ Result> ComparisonExpression::Assume( } case ExpressionType::AND: { + const auto& given_and = checked_cast(given); + auto simplified = Copy(); - for (const auto& operand : checked_cast(given).operands()) { - BooleanScalar value; - if (simplified->IsTrivialCondition(&value)) { - // FIXME(bkietz) but what if something later is null? - break; + for (const auto& operand : {given_and.left_operand(), given_and.right_operand()}) { + if (simplified->IsNull()) { + return ScalarExpression::MakeNull(boolean()); + } + + if (simplified->IsTrivialCondition()) { + return std::move(simplified); } ARROW_ASSIGN_OR_RAISE(simplified, simplified->Assume(*operand)); @@ -578,75 +607,84 @@ Result> ComparisonExpression::AssumeGivenComparison( default: return Copy(); } + return Copy(); } -template -Result> AssumeNnary(const Nnary& nnary, - const Expression& given) { - // if any of the operands matches trivial_condition, we can return a trivial - // expression: - // anything OR true => true - // anything AND false => false - constexpr bool trivial_condition = std::is_same::value; - bool simplify_to_trivial = false; - - ExpressionVector operands; - for (auto operand : nnary.operands()) { - ARROW_ASSIGN_OR_RAISE(operand, operand->Assume(given)); +Result> AndExpression::Assume(const Expression& given) const { + ARROW_ASSIGN_OR_RAISE(auto left_operand, left_operand_->Assume(given)); + ARROW_ASSIGN_OR_RAISE(auto right_operand, right_operand_->Assume(given)); - BooleanScalar scalar; - if (operand->IsTrivialCondition(&scalar)) { - if (!scalar.is_valid) { - return ScalarExpression::MakeNull(boolean()); - } + // if either operand is trivially null then so is this AND + if (left_operand->IsNull() || right_operand->IsNull()) { + return ScalarExpression::MakeNull(boolean()); + } - if (scalar.value == trivial_condition) { - simplify_to_trivial = true; - } - continue; - } + bool left_trivial, right_trivial; + bool left_is_trivial = left_operand->IsTrivialCondition(&left_trivial); + bool right_is_trivial = right_operand->IsTrivialCondition(&right_trivial); - if (!simplify_to_trivial) { - operands.push_back(operand); - } + // if neither of the operands is trivial, simply construct a new AND + if (!left_is_trivial && !right_is_trivial) { + return std::make_shared(std::move(left_operand), + std::move(right_operand)); } - if (simplify_to_trivial) { - return ScalarExpression::Make(trivial_condition); + // if either of the operands is trivially false then so is this AND + if ((left_is_trivial && left_trivial == false) || + (right_is_trivial && right_trivial == false)) { + // FIXME(bkietz) if left is false and right is a column conaining nulls, this is an + // error because we should be yielding null there rather than false + return ScalarExpression::Make(false); } - if (operands.size() == 1) { - return operands[0]; - } + // at least one of the operands is trivially true; return the other operand + return right_is_trivial ? std::move(left_operand) : std::move(right_operand); +} - if (operands.size() == 0) { - return ScalarExpression::Make(!trivial_condition); +Result> OrExpression::Assume(const Expression& given) const { + ARROW_ASSIGN_OR_RAISE(auto left_operand, left_operand_->Assume(given)); + ARROW_ASSIGN_OR_RAISE(auto right_operand, right_operand_->Assume(given)); + + // if either operand is trivially null then so is this OR + if (left_operand->IsNull() || right_operand->IsNull()) { + return ScalarExpression::MakeNull(boolean()); } - return std::make_shared(std::move(operands)); -} + bool left_trivial, right_trivial; + bool left_is_trivial = left_operand->IsTrivialCondition(&left_trivial); + bool right_is_trivial = right_operand->IsTrivialCondition(&right_trivial); -Result> AndExpression::Assume(const Expression& given) const { - return AssumeNnary(*this, given); -} + // if neither of the operands is trivial, simply construct a new OR + if (!left_is_trivial && !right_is_trivial) { + return std::make_shared(std::move(left_operand), + std::move(right_operand)); + } -Result> OrExpression::Assume(const Expression& given) const { - return AssumeNnary(*this, given); + // if either of the operands is trivially true then so is this OR + if ((left_is_trivial && left_trivial == true) || + (right_is_trivial && right_trivial == true)) { + // FIXME(bkietz) if left is true but right is a column conaining nulls, this is an + // error because we should be yielding null there rather than true + return ScalarExpression::Make(true); + } + + // at least one of the operands is trivially false; return the other operand + return right_is_trivial ? std::move(left_operand) : std::move(right_operand); } Result> NotExpression::Assume(const Expression& given) const { ARROW_ASSIGN_OR_RAISE(auto operand, operand_->Assume(given)); - BooleanScalar scalar; - if (operand->IsTrivialCondition(&scalar)) { - return Copy(); + if (operand->IsNull()) { + return ScalarExpression::MakeNull(boolean()); } - if (!scalar.is_valid) { - return ScalarExpression::MakeNull(boolean()); + bool trivial; + if (operand->IsTrivialCondition(&trivial)) { + return ScalarExpression::Make(!trivial); } - return ScalarExpression::Make(!scalar.value); + return Copy(); } std::string FieldExpression::ToString() const { @@ -684,6 +722,9 @@ std::string ScalarExpression::ToString() const { case Type::BOOL: value = checked_cast(*value_).value ? "true" : "false"; break; + case Type::INT32: + value = std::to_string(checked_cast(*value_).value); + break; case Type::INT64: value = std::to_string(checked_cast(*value_).value); break; @@ -716,9 +757,13 @@ static std::string EulerNotation(std::string fn, const ExpressionVector& operand return fn; } -std::string AndExpression::ToString() const { return EulerNotation("ALL", operands_); } +std::string AndExpression::ToString() const { + return EulerNotation("AND", {left_operand_, right_operand_}); +} -std::string OrExpression::ToString() const { return EulerNotation("ANY", operands_); } +std::string OrExpression::ToString() const { + return EulerNotation("OR", {left_operand_, right_operand_}); +} std::string NotExpression::ToString() const { return EulerNotation("NOT", {operand_}); } @@ -739,22 +784,6 @@ bool BinaryExpression::Equals(const Expression& other) const { checked_cast(other).right_operand_); } -bool NnaryExpression::Equals(const Expression& other) const { - if (type_ != other.type()) { - return false; - } - const auto& other_operands = checked_cast(other).operands_; - if (operands_.size() != other_operands.size()) { - return false; - } - for (size_t i = 0; i < operands_.size(); ++i) { - if (!operands_[i]->Equals(other_operands[i])) { - return false; - } - } - return true; -} - bool ComparisonExpression::Equals(const Expression& other) const { return BinaryExpression::Equals(other) && op_ == checked_cast(other).op_; @@ -777,25 +806,35 @@ bool Expression::Equals(const std::shared_ptr& other) const { return Equals(*other); } -bool Expression::IsTrivialCondition(BooleanScalar* out) const { +bool Expression::IsNull() const { if (type_ != ExpressionType::SCALAR) { return false; } const auto& scalar = checked_cast(*this).value(); if (!scalar->is_valid) { - if (out) { - *out = BooleanScalar(); - } return true; } + return false; +} + +bool Expression::IsTrivialCondition(bool* out) const { + if (type_ != ExpressionType::SCALAR) { + return false; + } + + const auto& scalar = checked_cast(*this).value(); + if (!scalar->is_valid) { + return false; + } + if (scalar->type->id() != Type::BOOL) { return false; } if (out) { - *out = BooleanScalar(checked_cast(*scalar).value); + *out = checked_cast(*scalar).value; } return true; } @@ -808,46 +847,26 @@ std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } -std::shared_ptr and_(ExpressionVector operands) { - return std::make_shared(std::move(operands)); +std::shared_ptr and_(std::shared_ptr lhs, + std::shared_ptr rhs) { + return std::make_shared(std::move(lhs), std::move(rhs)); } -std::shared_ptr or_(ExpressionVector operands) { - return std::make_shared(std::move(operands)); +std::shared_ptr or_(std::shared_ptr lhs, + std::shared_ptr rhs) { + return std::make_shared(std::move(lhs), std::move(rhs)); } std::shared_ptr not_(std::shared_ptr operand) { return std::make_shared(std::move(operand)); } -// flatten chains of and/or to a single OperatorExpression -template -Out MaybeCombine(const Expression& lhs, const Expression& rhs) { - if (lhs.type() != Out::expression_type && rhs.type() != Out::expression_type) { - return Out(ExpressionVector{lhs.Copy(), rhs.Copy()}); - } - - ExpressionVector operands; - for (auto side : {&lhs, &rhs}) { - if (side->type() != Out::expression_type) { - operands.emplace_back(side->Copy()); - continue; - } - - for (auto operand : checked_cast(*side).operands()) { - operands.emplace_back(std::move(operand)); - } - } - - return Out(std::move(operands)); -} - AndExpression operator&&(const Expression& lhs, const Expression& rhs) { - return MaybeCombine(lhs, rhs); + return AndExpression(lhs.Copy(), rhs.Copy()); } OrExpression operator||(const Expression& lhs, const Expression& rhs) { - return MaybeCombine(lhs, rhs); + return OrExpression(lhs.Copy(), rhs.Copy()); } NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } @@ -880,10 +899,10 @@ Status EnsureNullOrBool(const std::string& msg_prefix, return Status::TypeError(msg_prefix, *type); } -Result> ValidateNnary(const NnaryExpression& nnary, - const Schema& schema) { +Result> ValidateBoolean(const ExpressionVector& operands, + const Schema& schema) { auto out = boolean(); - for (const auto& operand : nnary.operands()) { + for (auto operand : operands) { ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); RETURN_NOT_OK( EnsureNullOrBool("cannot combine expressions including one of type ", type)); @@ -891,21 +910,19 @@ Result> ValidateNnary(const NnaryExpression& nnary, out = null(); } } - return out; + return std::move(out); } Result> AndExpression::Validate(const Schema& schema) const { - return ValidateNnary(*this, schema); + return ValidateBoolean({left_operand_, right_operand_}, schema); } Result> OrExpression::Validate(const Schema& schema) const { - return ValidateNnary(*this, schema); + return ValidateBoolean({left_operand_, right_operand_}, schema); } Result> NotExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - RETURN_NOT_OK(EnsureNullOrBool("cannot invert an expression of type ", operand_type)); - return std::move(operand_type); + return ValidateBoolean({operand_}, schema); } Result> ScalarExpression::Validate(const Schema& schema) const { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 7840833c8da..faefa800e94 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -92,8 +92,12 @@ struct ExpressionType { /// a comparison of two other expressions COMPARISON, - /// extract validity as an expression (true if operand is valid) - // TODO(bkietz) VALIDITY, + /// replace nulls with other expressions + /// currently only boolean expressions may be coalesced + // TODO(bkietz) COALESCE, + + /// extract validity as a boolean expression + // TODO(bkietz) IS_VALID, }; }; @@ -136,16 +140,16 @@ class ARROW_DS_EXPORT Expression { ExpressionType::type type() const { return type_; } - /// If true, this Expression is a ScalarExpression wrapping either a null Scalar or a - /// non-null BooleanScalar. Its value may be retrieved at the same time. - bool IsTrivialCondition(BooleanScalar* value = NULLPTR) const; + /// If true, this Expression is a ScalarExpression wrapping a null scalar. + bool IsNull() const; + + /// If true, this Expression is a ScalarExpression wrapping a + /// BooleanScalar. Its value may be retrieved at the same time. + bool IsTrivialCondition(bool* value = NULLPTR) const; /// Copy this expression into a shared pointer. virtual std::shared_ptr Copy() const = 0; - template - Status Accept(Visitor&& visitor) const; - protected: ExpressionType::type type_; }; @@ -198,20 +202,6 @@ class ARROW_DS_EXPORT BinaryExpression : public Expression { std::shared_ptr left_operand_, right_operand_; }; -/// Base class for an expression with multiple operands -class ARROW_DS_EXPORT NnaryExpression : public Expression { - public: - const ExpressionVector& operands() const { return operands_; } - - bool Equals(const Expression& other) const override; - - protected: - NnaryExpression(ExpressionType::type type, ExpressionVector operands) - : Expression(type), operands_(std::move(operands)) {} - - ExpressionVector operands_; -}; - class ARROW_DS_EXPORT ComparisonExpression final : public ExpressionImpl { @@ -242,7 +232,7 @@ class ARROW_DS_EXPORT ComparisonExpression final }; class ARROW_DS_EXPORT AndExpression final - : public ExpressionImpl { + : public ExpressionImpl { public: using ExpressionImpl::ExpressionImpl; @@ -257,7 +247,7 @@ class ARROW_DS_EXPORT AndExpression final }; class ARROW_DS_EXPORT OrExpression final - : public ExpressionImpl { + : public ExpressionImpl { public: using ExpressionImpl::ExpressionImpl; @@ -357,11 +347,13 @@ class ARROW_DS_EXPORT FieldExpression final : public Expression { std::string name_; }; -ARROW_DS_EXPORT std::shared_ptr and_(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr and_(std::shared_ptr lhs, + std::shared_ptr rhs); ARROW_DS_EXPORT AndExpression operator&&(const Expression& lhs, const Expression& rhs); -ARROW_DS_EXPORT std::shared_ptr or_(ExpressionVector operands); +ARROW_DS_EXPORT std::shared_ptr or_(std::shared_ptr lhs, + std::shared_ptr rhs); ARROW_DS_EXPORT OrExpression operator||(const Expression& lhs, const Expression& rhs); @@ -405,26 +397,5 @@ inline FieldExpression operator""_(const char* name, size_t name_length) { } } // namespace string_literals -#define EXPRESSION_VISIT_CASE(NAME, CLASS) \ - case ExpressionType::NAME: \ - return visitor.Visit(internal::checked_cast(*this)) - -template -Status Expression::Accept(Visitor&& visitor) const { - switch (type_) { - EXPRESSION_VISIT_CASE(FIELD, Field); - EXPRESSION_VISIT_CASE(SCALAR, Scalar); - EXPRESSION_VISIT_CASE(AND, And); - EXPRESSION_VISIT_CASE(OR, Or); - EXPRESSION_VISIT_CASE(NOT, Not); - EXPRESSION_VISIT_CASE(COMPARISON, Comparison); - default: - break; - } - return Status::TypeError("unknown ExpressionType"); -} - -#undef EXPRESSION_VISIT_CASE - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 2b76ae2c6e4..401f9a73284 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -52,16 +52,11 @@ class ExpressionsTest : public ::testing::Test { } } - template - void AssertOperandsAre(const NnaryExpression& expr, ExpressionType::type type, - T... expected_operands) { + void AssertOperandsAre(const BinaryExpression& expr, ExpressionType::type type, + const Expression& lhs, const Expression& rhs) { ASSERT_EQ(expr.type(), type); - ASSERT_EQ(expr.operands().size(), sizeof...(T)); - std::shared_ptr expected_operand_ptrs[] = {expected_operands.Copy()...}; - - for (size_t i = 0; i < sizeof...(T); ++i) { - ASSERT_TRUE(expr.operands()[i]->Equals(expected_operand_ptrs[i])); - } + ASSERT_TRUE(expr.left_operand()->Equals(lhs)); + ASSERT_TRUE(expr.right_operand()->Equals(rhs)); } std::shared_ptr always = ScalarExpression::Make(true); @@ -83,7 +78,7 @@ TEST_F(ExpressionsTest, Equality) { TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { // chained "and" expressions are flattened auto multi_and = "b"_ > 5 and "b"_ < 10 and "b"_ != 7; - AssertOperandsAre(multi_and, ExpressionType::AND, "b"_ > 5, "b"_ < 10, "b"_ != 7); + AssertOperandsAre(multi_and, ExpressionType::AND, ("b"_ > 5 and "b"_ < 10), "b"_ != 7); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); From 539c2871944d5c02be06e267505e0e99c2d03387 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 2 Sep 2019 09:25:53 -0400 Subject: [PATCH 27/28] rename fieldRef to field_ref, comments --- cpp/src/arrow/dataset/filter.cc | 2 ++ cpp/src/arrow/dataset/filter.h | 9 ++++++--- cpp/src/arrow/dataset/filter_test.cc | 8 ++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 7471299fcc7..e8985dbeb7e 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -329,6 +329,8 @@ Result> InvertBoolean(const Boolean& expr) { if (std::is_same::value) { return std::make_shared(std::move(lhs), std::move(rhs)); } + + return Status::Invalid("unknown boolean expression ", expr.ToString()); } Result> Invert(const Expression& expr) { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index faefa800e94..9acece6f2d9 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -130,8 +130,11 @@ class ARROW_DS_EXPORT Expression { return Copy(); } - // Evaluate an expression against a RecordBatch. - // Returned Datum must be of either SCALAR or ARRAY kind. + /// Evaluate this expression against each row of a RecordBatch. + /// Returned Datum will be of either SCALAR or ARRAY kind. + /// A return value of ARRAY kind will have length == batch.num_rows() + /// An return value of SCALAR kind is equivalent to an array of the same type whose + /// slots contain a single repeated value. virtual Result Evaluate(compute::FunctionContext* ctx, const RecordBatch& batch) const = 0; @@ -387,7 +390,7 @@ auto scalar(T&& value) -> decltype(ScalarExpression::Make(std::forward(value) return ScalarExpression::Make(std::forward(value)); } -inline std::shared_ptr fieldRef(std::string name) { +inline std::shared_ptr field_ref(std::string name) { return std::make_shared(std::move(name)); } diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 401f9a73284..2e884ebb08f 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -106,10 +106,10 @@ TEST_F(ExpressionsTest, SimplificationToNull) { auto null = ScalarExpression::MakeNull(boolean()); auto null32 = ScalarExpression::MakeNull(int32()); - AssertSimplifiesTo(*equal(fieldRef("b"), null32), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(fieldRef("b"), null32), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(fieldRef("b"), null32) and "b"_ > 3, "b"_ == 3, *null); - AssertSimplifiesTo("b"_ > 3 and *not_equal(fieldRef("b"), null32), "b"_ == 3, *null); + AssertSimplifiesTo(*equal(field_ref("b"), null32), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(field_ref("b"), null32), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 3, "b"_ == 3, *null); + AssertSimplifiesTo("b"_ > 3 and *not_equal(field_ref("b"), null32), "b"_ == 3, *null); } class FilterTest : public ::testing::Test { From fda47422fc467f433556ec12e23165bb1d2792a1 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 9 Sep 2019 10:44:16 -0400 Subject: [PATCH 28/28] give MSVC a little help to avoid instantiating impossible constructors --- cpp/src/arrow/dataset/filter.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 9acece6f2d9..9cf31f1ed59 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -163,9 +163,9 @@ class ExpressionImpl : public Base { public: static constexpr ExpressionType::type expression_type = E; - template - explicit ExpressionImpl(A&&... args) - : Base(expression_type, std::forward(args)...) {} + template + explicit ExpressionImpl(A0&& arg0, A&&... args) + : Base(expression_type, std::forward(arg0), std::forward(args)...) {} std::shared_ptr Copy() const override { return std::make_shared(internal::checked_cast(*this));