diff --git a/cpp/build-support/run_cpplint.py b/cpp/build-support/run_cpplint.py index 9ee28af727e..bf2b51b8b2f 100755 --- a/cpp/build-support/run_cpplint.py +++ b/cpp/build-support/run_cpplint.py @@ -35,6 +35,7 @@ -whitespace/comments -readability/casting -readability/todo +-readability/alt_tokens -build/header_guard -build/c++11 -runtime/references diff --git a/cpp/src/arrow/compute/kernels/compare.cc b/cpp/src/arrow/compute/kernels/compare.cc index 0ed785f26f9..386e3b98e59 100644 --- a/cpp/src/arrow/compute/kernels/compare.cc +++ b/cpp/src/arrow/compute/kernels/compare.cc @@ -244,7 +244,10 @@ Status Compare(FunctionContext* context, const Datum& left, const Datum& right, DCHECK(out); auto type = left.type(); - DCHECK(type->Equals(right.type())); + if (!type->Equals(right.type())) { + return Status::TypeError("Cannot compare data of differing type ", *type, " vs ", + *right.type()); + } // Requires that both types are equal. auto fn = MakeCompareFunction(context, *type, options); if (fn == nullptr) { diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt index 9f102ac1601..6f94297e272 100644 --- a/cpp/src/arrow/dataset/CMakeLists.txt +++ b/cpp/src/arrow/dataset/CMakeLists.txt @@ -23,7 +23,7 @@ arrow_install_all_headers("arrow/dataset") # pkg-config support arrow_add_pkg_config("arrow-dataset") -set(ARROW_DATASET_SRCS dataset.cc file_base.cc scanner.cc) +set(ARROW_DATASET_SRCS dataset.cc file_base.cc filter.cc scanner.cc) set(ARROW_DATASET_LINK_STATIC arrow_static) set(ARROW_DATASET_LINK_SHARED arrow_shared) @@ -79,22 +79,23 @@ function(ADD_ARROW_DATASET_TEST REL_TEST_NAME) set(LABELS "arrow_dataset") endif() - if(NOT WIN32) - add_arrow_test(${REL_TEST_NAME} - EXTRA_LINK_LIBS - ${ARROW_DATASET_TEST_LINK_LIBS} - PREFIX - ${PREFIX} - LABELS - ${LABELS} - ${ARG_UNPARSED_ARGUMENTS}) - endif() + add_arrow_test(${REL_TEST_NAME} + EXTRA_LINK_LIBS + ${ARROW_DATASET_TEST_LINK_LIBS} + PREFIX + ${PREFIX} + LABELS + ${LABELS} + 
${ARG_UNPARSED_ARGUMENTS}) endfunction() -add_arrow_dataset_test(dataset_test) -add_arrow_dataset_test(file_test) -add_arrow_dataset_test(scanner_test) +if(NOT WIN32) + add_arrow_dataset_test(dataset_test) + add_arrow_dataset_test(file_test) + add_arrow_dataset_test(filter_test) + add_arrow_dataset_test(scanner_test) -if(ARROW_PARQUET) - add_arrow_dataset_test(file_parquet_test) + if(ARROW_PARQUET) + add_arrow_dataset_test(file_parquet_test) + endif() endif() diff --git a/cpp/src/arrow/dataset/api.h b/cpp/src/arrow/dataset/api.h index 18622f3a448..f9e49f20994 100644 --- a/cpp/src/arrow/dataset/api.h +++ b/cpp/src/arrow/dataset/api.h @@ -22,4 +22,5 @@ #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_csv.h" #include "arrow/dataset/file_feather.h" +#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc new file mode 100644 index 00000000000..e8985dbeb7e --- /dev/null +++ b/cpp/src/arrow/dataset/filter.cc @@ -0,0 +1,942 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/dataset/filter.h" + +#include +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/buffer_builder.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/boolean.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/record_batch.h" +#include "arrow/util/logging.h" +#include "arrow/visitor_inline.h" + +namespace arrow { +namespace dataset { + +using arrow::compute::Datum; +using internal::checked_cast; +using internal::checked_pointer_cast; + +Result ScalarExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + return value_; +} + +Datum NullDatum() { return Datum(std::make_shared()); } + +Result FieldExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + auto column = batch.GetColumnByName(name_); + if (column == nullptr) { + return NullDatum(); + } + return std::move(column); +} + +bool IsNullDatum(const Datum& datum) { + if (datum.is_scalar()) { + auto scalar = datum.scalar(); + return !scalar->is_valid; + } + + auto array_data = datum.array(); + return array_data->GetNullCount() == array_data->length; +} + +bool IsTrivialConditionDatum(const Datum& datum, bool* condition = nullptr) { + if (!datum.is_scalar()) { + return false; + } + + auto scalar = datum.scalar(); + if (!scalar->is_valid) { + return false; + } + + if (scalar->type->id() != Type::BOOL) { + return false; + } + + if (condition) { + *condition = checked_cast(*scalar).value; + } + return true; +} + +Result NotExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto to_invert, operand_->Evaluate(ctx, batch)); + if (IsNullDatum(to_invert)) { + return NullDatum(); + } + + bool trivial_condition; + if (IsTrivialConditionDatum(to_invert, &trivial_condition)) { + return Datum(std::make_shared(!trivial_condition)); + } + + DCHECK(to_invert.is_array()); + Datum out; + 
RETURN_NOT_OK(arrow::compute::Invert(ctx, Datum(to_invert), &out)); + return std::move(out); +} + +Result AndExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto lhs, left_operand_->Evaluate(ctx, batch)); + ARROW_ASSIGN_OR_RAISE(auto rhs, right_operand_->Evaluate(ctx, batch)); + + if (IsNullDatum(lhs) || IsNullDatum(rhs)) { + return NullDatum(); + } + + if (lhs.is_array() && rhs.is_array()) { + Datum out; + RETURN_NOT_OK(arrow::compute::And(ctx, lhs, rhs, &out)); + return std::move(out); + } + + if (lhs.is_scalar() && rhs.is_scalar()) { + return Datum(checked_cast(*lhs.scalar()).value && + checked_cast(*rhs.scalar()).value); + } + + auto array_operand = (lhs.is_array() ? lhs : rhs).make_array(); + bool scalar_operand = + checked_cast(*(lhs.is_scalar() ? lhs : rhs).scalar()).value; + + if (!scalar_operand) { + // FIXME(bkietz) this is an error if array_operand contains nulls + return Datum(false); + } + + return Datum(array_operand); +} + +Result OrExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto lhs, left_operand_->Evaluate(ctx, batch)); + ARROW_ASSIGN_OR_RAISE(auto rhs, right_operand_->Evaluate(ctx, batch)); + + if (IsNullDatum(lhs) || IsNullDatum(rhs)) { + return NullDatum(); + } + + if (lhs.is_array() && rhs.is_array()) { + Datum out; + RETURN_NOT_OK(arrow::compute::Or(ctx, lhs, rhs, &out)); + return std::move(out); + } + + if (lhs.is_scalar() && rhs.is_scalar()) { + return Datum(checked_cast(*lhs.scalar()).value && + checked_cast(*rhs.scalar()).value); + } + + auto array_operand = (lhs.is_array() ? lhs : rhs).make_array(); + bool scalar_operand = + checked_cast(*(lhs.is_scalar() ? 
lhs : rhs).scalar()).value; + + if (!scalar_operand) { + // FIXME(bkietz) this is an error if array_operand contains nulls + return Datum(true); + } + + return Datum(array_operand); +} + +Result ComparisonExpression::Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto lhs, left_operand_->Evaluate(ctx, batch)); + ARROW_ASSIGN_OR_RAISE(auto rhs, right_operand_->Evaluate(ctx, batch)); + + if (IsNullDatum(lhs) || IsNullDatum(rhs)) { + return NullDatum(); + } + + if (lhs.is_scalar()) { + return Status::NotImplemented("comparison with scalar LHS"); + } + + Datum out; + RETURN_NOT_OK( + arrow::compute::Compare(ctx, lhs, rhs, arrow::compute::CompareOptions(op_), &out)); + return std::move(out); +} + +std::shared_ptr ScalarExpression::Make(std::string value) { + return std::make_shared( + std::make_shared(Buffer::FromString(std::move(value)))); +} + +std::shared_ptr ScalarExpression::Make(const char* value) { + return std::make_shared( + std::make_shared(Buffer::Wrap(value, std::strlen(value)))); +} + +std::shared_ptr ScalarExpression::MakeNull( + const std::shared_ptr& type) { + std::shared_ptr null; + DCHECK_OK(arrow::MakeNullScalar(type, &null)); + return Make(std::move(null)); +} + +struct Comparison { + enum type { + LESS, + EQUAL, + GREATER, + NULL_, + }; +}; + +struct CompareVisitor { + template + using ScalarType = typename TypeTraits::ScalarType; + + Status Visit(const NullType&) { + result_ = Comparison::NULL_; + return Status::OK(); + } + + Status Visit(const BooleanType&) { return CompareValues(); } + + template + enable_if_number Visit(const T&) { + return CompareValues(); + } + + template + enable_if_binary_like Visit(const T&) { + auto lhs = checked_cast&>(lhs_).value; + auto rhs = checked_cast&>(rhs_).value; + auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); + if (cmp == 0) { + return CompareValues(lhs->size(), rhs->size()); + } + return CompareValues(cmp, 0); + } + + 
Status Visit(const Decimal128Type&) { return CompareValues(); } + + // explicit because both integral and floating point conditions match half float + Status Visit(const HalfFloatType&) { + // TODO(bkietz) whenever we vendor a float16, this can be implemented + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + + Status Visit(const DataType&) { + return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + } + + // defer comparison to ScalarType::value + template + Status CompareValues() { + auto lhs = checked_cast&>(lhs_).value; + auto rhs = checked_cast&>(rhs_).value; + return CompareValues(lhs, rhs); + } + + // defer comparison to explicit values + template + Status CompareValues(Value lhs, Value rhs) { + result_ = lhs < rhs ? Comparison::LESS + : lhs == rhs ? Comparison::EQUAL : Comparison::GREATER; + return Status::OK(); + } + + Comparison::type result_; + const Scalar& lhs_; + const Scalar& rhs_; +}; + +// Compare two scalars +// if either is null, return is null +Result Compare(const Scalar& lhs, const Scalar& rhs) { + if (!lhs.type->Equals(*rhs.type)) { + return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, + " vs ", *rhs.type); + } + if (!lhs.is_valid || !rhs.is_valid) { + return Comparison::NULL_; + } + CompareVisitor vis{Comparison::NULL_, lhs, rhs}; + RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); + return vis.result_; +} + +compute::CompareOperator InvertCompareOperator(compute::CompareOperator op) { + using compute::CompareOperator; + + switch (op) { + case CompareOperator::EQUAL: + return CompareOperator::NOT_EQUAL; + + case CompareOperator::NOT_EQUAL: + return CompareOperator::EQUAL; + + case CompareOperator::GREATER: + return CompareOperator::LESS_EQUAL; + + case CompareOperator::GREATER_EQUAL: + return CompareOperator::LESS; + + case CompareOperator::LESS: + return CompareOperator::GREATER_EQUAL; + + case CompareOperator::LESS_EQUAL: + return CompareOperator::GREATER; 
+ + default: + break; + } + + DCHECK(false); + return CompareOperator::EQUAL; +} + +template +Result> InvertBoolean(const Boolean& expr) { + ARROW_ASSIGN_OR_RAISE(auto lhs, Invert(*expr.left_operand())); + ARROW_ASSIGN_OR_RAISE(auto rhs, Invert(*expr.right_operand())); + + if (std::is_same::value) { + return std::make_shared(std::move(lhs), std::move(rhs)); + } + + if (std::is_same::value) { + return std::make_shared(std::move(lhs), std::move(rhs)); + } + + return Status::Invalid("unknown boolean expression ", expr.ToString()); +} + +Result> Invert(const Expression& expr) { + switch (expr.type()) { + case ExpressionType::NOT: + return checked_cast(expr).operand(); + + case ExpressionType::AND: + return InvertBoolean(checked_cast(expr)); + + case ExpressionType::OR: + return InvertBoolean(checked_cast(expr)); + + case ExpressionType::COMPARISON: { + const auto& comparison = checked_cast(expr); + auto inverted_op = InvertCompareOperator(comparison.op()); + return std::make_shared( + inverted_op, comparison.left_operand(), comparison.right_operand()); + } + + default: + break; + } + return Status::NotImplemented("can't invert this expression"); +} + +Result> ComparisonExpression::Assume( + const Expression& given) const { + switch (given.type()) { + case ExpressionType::COMPARISON: { + return AssumeGivenComparison(checked_cast(given)); + } + + case ExpressionType::NOT: { + const auto& to_invert = checked_cast(given).operand(); + auto inverted = Invert(*to_invert); + if (!inverted.ok()) { + return Copy(); + } + return Assume(*inverted.ValueOrDie()); + } + + case ExpressionType::OR: { + const auto& given_or = checked_cast(given); + + bool simplify_to_always = true; + bool simplify_to_never = true; + for (const auto& operand : {given_or.left_operand(), given_or.right_operand()}) { + ARROW_ASSIGN_OR_RAISE(auto simplified, Assume(*operand)); + + if (simplified->IsNull()) { + // some subexpression of given is always null, return null + return 
ScalarExpression::MakeNull(boolean()); + } + + bool trivial; + if (!simplified->IsTrivialCondition(&trivial)) { + // Or cannot be simplified unless all of its operands simplify to the same + // trivial condition + simplify_to_never = false; + simplify_to_always = false; + continue; + } + + if (trivial == true) { + simplify_to_never = false; + } else { + simplify_to_always = false; + } + } + + if (simplify_to_always) { + return ScalarExpression::Make(true); + } + + if (simplify_to_never) { + return ScalarExpression::Make(false); + } + + return Copy(); + } + + case ExpressionType::AND: { + const auto& given_and = checked_cast(given); + + auto simplified = Copy(); + for (const auto& operand : {given_and.left_operand(), given_and.right_operand()}) { + if (simplified->IsNull()) { + return ScalarExpression::MakeNull(boolean()); + } + + if (simplified->IsTrivialCondition()) { + return std::move(simplified); + } + + ARROW_ASSIGN_OR_RAISE(simplified, simplified->Assume(*operand)); + } + return std::move(simplified); + } + + default: + break; + } + + return Copy(); +} + +// Try to simplify one comparison against another comparison. +// For example, +// (x > 3) is a subset of (x > 2), so (x > 2).Assume(x > 3) == (true) +// (x < 0) is disjoint with (x > 2), so (x > 2).Assume(x < 0) == (false) +// If simplification to (true) or (false) is not possible, pass e through unchanged. 
+Result> ComparisonExpression::AssumeGivenComparison( + const ComparisonExpression& given) const { + for (auto comparison : {this, &given}) { + if (comparison->left_operand_->type() != ExpressionType::FIELD) { + return Status::Invalid("left hand side of comparison must be a field reference"); + } + + if (comparison->right_operand_->type() != ExpressionType::SCALAR) { + return Status::Invalid("right hand side of comparison must be a scalar"); + } + } + + const auto& this_lhs = checked_cast(*left_operand_); + const auto& given_lhs = checked_cast(*given.left_operand_); + if (this_lhs.name() != given_lhs.name()) { + return Copy(); + } + + const auto& this_rhs = checked_cast(*right_operand_).value(); + const auto& given_rhs = + checked_cast(*given.right_operand_).value(); + ARROW_ASSIGN_OR_RAISE(auto cmp, Compare(*this_rhs, *given_rhs)); + + if (cmp == Comparison::NULL_) { + // the RHS of e or given was null + return ScalarExpression::MakeNull(boolean()); + } + + static auto always = ScalarExpression::Make(true); + static auto never = ScalarExpression::Make(false); + + using compute::CompareOperator; + + if (cmp == Comparison::GREATER) { + // the rhs of e is greater than that of given + switch (op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + return never; + default: + return Copy(); + } + case CompareOperator::NOT_EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + return always; + default: + return Copy(); + } + default: + return Copy(); + } + } + + if (cmp == Comparison::LESS) { + // the rhs of e is less than that of given + switch (op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + switch 
(given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return never; + default: + return Copy(); + } + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return always; + default: + return Copy(); + } + default: + return Copy(); + } + } + + DCHECK_EQ(cmp, Comparison::EQUAL); + + // the rhs of the comparisons are equal + switch (op_) { + case CompareOperator::EQUAL: + switch (given.op()) { + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: + return never; + case CompareOperator::EQUAL: + return always; + default: + return Copy(); + } + case CompareOperator::NOT_EQUAL: + switch (given.op()) { + case CompareOperator::EQUAL: + return never; + case CompareOperator::NOT_EQUAL: + case CompareOperator::GREATER: + case CompareOperator::LESS: + return always; + default: + return Copy(); + } + case CompareOperator::GREATER: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::LESS_EQUAL: + case CompareOperator::LESS: + return never; + case CompareOperator::GREATER: + return always; + default: + return Copy(); + } + case CompareOperator::GREATER_EQUAL: + switch (given.op()) { + case CompareOperator::LESS: + return never; + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return always; + default: + return Copy(); + } + case CompareOperator::LESS: + switch (given.op()) { + case CompareOperator::EQUAL: + case CompareOperator::GREATER: + case CompareOperator::GREATER_EQUAL: + return never; + case CompareOperator::LESS: + return always; + default: + return Copy(); + } + case CompareOperator::LESS_EQUAL: + switch (given.op()) { + case CompareOperator::GREATER: + return never; + case CompareOperator::EQUAL: + case 
CompareOperator::LESS: + case CompareOperator::LESS_EQUAL: + return always; + default: + return Copy(); + } + default: + return Copy(); + } + return Copy(); +} + +Result> AndExpression::Assume(const Expression& given) const { + ARROW_ASSIGN_OR_RAISE(auto left_operand, left_operand_->Assume(given)); + ARROW_ASSIGN_OR_RAISE(auto right_operand, right_operand_->Assume(given)); + + // if either operand is trivially null then so is this AND + if (left_operand->IsNull() || right_operand->IsNull()) { + return ScalarExpression::MakeNull(boolean()); + } + + bool left_trivial, right_trivial; + bool left_is_trivial = left_operand->IsTrivialCondition(&left_trivial); + bool right_is_trivial = right_operand->IsTrivialCondition(&right_trivial); + + // if neither of the operands is trivial, simply construct a new AND + if (!left_is_trivial && !right_is_trivial) { + return std::make_shared(std::move(left_operand), + std::move(right_operand)); + } + + // if either of the operands is trivially false then so is this AND + if ((left_is_trivial && left_trivial == false) || + (right_is_trivial && right_trivial == false)) { + // FIXME(bkietz) if left is false and right is a column containing nulls, this is an + // error because we should be yielding null there rather than false + return ScalarExpression::Make(false); + } + + // at least one of the operands is trivially true; return the other operand + return right_is_trivial ?
std::move(left_operand) : std::move(right_operand); +} + +Result> OrExpression::Assume(const Expression& given) const { + ARROW_ASSIGN_OR_RAISE(auto left_operand, left_operand_->Assume(given)); + ARROW_ASSIGN_OR_RAISE(auto right_operand, right_operand_->Assume(given)); + + // if either operand is trivially null then so is this OR + if (left_operand->IsNull() || right_operand->IsNull()) { + return ScalarExpression::MakeNull(boolean()); + } + + bool left_trivial, right_trivial; + bool left_is_trivial = left_operand->IsTrivialCondition(&left_trivial); + bool right_is_trivial = right_operand->IsTrivialCondition(&right_trivial); + + // if neither of the operands is trivial, simply construct a new OR + if (!left_is_trivial && !right_is_trivial) { + return std::make_shared(std::move(left_operand), + std::move(right_operand)); + } + + // if either of the operands is trivially true then so is this OR + if ((left_is_trivial && left_trivial == true) || + (right_is_trivial && right_trivial == true)) { + // FIXME(bkietz) if left is true but right is a column containing nulls, this is an + // error because we should be yielding null there rather than true + return ScalarExpression::Make(true); + } + + // at least one of the operands is trivially false; return the other operand + return right_is_trivial ?
std::move(left_operand) : std::move(right_operand); +} + +Result> NotExpression::Assume(const Expression& given) const { + ARROW_ASSIGN_OR_RAISE(auto operand, operand_->Assume(given)); + + if (operand->IsNull()) { + return ScalarExpression::MakeNull(boolean()); + } + + bool trivial; + if (operand->IsTrivialCondition(&trivial)) { + return ScalarExpression::Make(!trivial); + } + + return Copy(); +} + +std::string FieldExpression::ToString() const { + return std::string("field(") + name_ + ")"; +} + +std::string OperatorName(compute::CompareOperator op) { + using compute::CompareOperator; + switch (op) { + case CompareOperator::EQUAL: + return "EQUAL"; + case CompareOperator::NOT_EQUAL: + return "NOT_EQUAL"; + case CompareOperator::LESS: + return "LESS"; + case CompareOperator::LESS_EQUAL: + return "LESS_EQUAL"; + case CompareOperator::GREATER: + return "GREATER"; + case CompareOperator::GREATER_EQUAL: + return "GREATER_EQUAL"; + default: + DCHECK(false); + } + return ""; +} + +std::string ScalarExpression::ToString() const { + if (!value_->is_valid) { + return "scalar<" + value_->type->ToString() + ", null>()"; + } + + std::string value; + switch (value_->type->id()) { + case Type::BOOL: + value = checked_cast(*value_).value ? 
"true" : "false"; + break; + case Type::INT32: + value = std::to_string(checked_cast(*value_).value); + break; + case Type::INT64: + value = std::to_string(checked_cast(*value_).value); + break; + case Type::DOUBLE: + value = std::to_string(checked_cast(*value_).value); + break; + case Type::STRING: + value = checked_cast(*value_).value->ToString(); + break; + default: + value = "TODO(bkietz)"; + break; + } + + return "scalar<" + value_->type->ToString() + ">(" + value + ")"; +} + +static std::string EulerNotation(std::string fn, const ExpressionVector& operands) { + fn += "("; + bool comma = false; + for (const auto& operand : operands) { + if (comma) { + fn += ", "; + } else { + comma = true; + } + fn += operand->ToString(); + } + fn += ")"; + return fn; +} + +std::string AndExpression::ToString() const { + return EulerNotation("AND", {left_operand_, right_operand_}); +} + +std::string OrExpression::ToString() const { + return EulerNotation("OR", {left_operand_, right_operand_}); +} + +std::string NotExpression::ToString() const { return EulerNotation("NOT", {operand_}); } + +std::string ComparisonExpression::ToString() const { + return EulerNotation(OperatorName(op()), {left_operand_, right_operand_}); +} + +bool UnaryExpression::Equals(const Expression& other) const { + return type_ == other.type() && + operand_->Equals(checked_cast(other).operand_); +} + +bool BinaryExpression::Equals(const Expression& other) const { + return type_ == other.type() && + left_operand_->Equals( + checked_cast(other).left_operand_) && + right_operand_->Equals( + checked_cast(other).right_operand_); +} + +bool ComparisonExpression::Equals(const Expression& other) const { + return BinaryExpression::Equals(other) && + op_ == checked_cast(other).op_; +} + +bool ScalarExpression::Equals(const Expression& other) const { + return other.type() == ExpressionType::SCALAR && + value_->Equals(checked_cast(other).value_); +} + +bool FieldExpression::Equals(const Expression& other) const { + 
return other.type() == ExpressionType::FIELD && + name_ == checked_cast(other).name_; +} + +bool Expression::Equals(const std::shared_ptr& other) const { + if (other == NULLPTR) { + return false; + } + return Equals(*other); +} + +bool Expression::IsNull() const { + if (type_ != ExpressionType::SCALAR) { + return false; + } + + const auto& scalar = checked_cast(*this).value(); + if (!scalar->is_valid) { + return true; + } + + return false; +} + +bool Expression::IsTrivialCondition(bool* out) const { + if (type_ != ExpressionType::SCALAR) { + return false; + } + + const auto& scalar = checked_cast(*this).value(); + if (!scalar->is_valid) { + return false; + } + + if (scalar->type->id() != Type::BOOL) { + return false; + } + + if (out) { + *out = checked_cast(*scalar).value; + } + return true; +} + +std::shared_ptr FieldExpression::Copy() const { + return std::make_shared(*this); +} + +std::shared_ptr ScalarExpression::Copy() const { + return std::make_shared(*this); +} + +std::shared_ptr and_(std::shared_ptr lhs, + std::shared_ptr rhs) { + return std::make_shared(std::move(lhs), std::move(rhs)); +} + +std::shared_ptr or_(std::shared_ptr lhs, + std::shared_ptr rhs) { + return std::make_shared(std::move(lhs), std::move(rhs)); +} + +std::shared_ptr not_(std::shared_ptr operand) { + return std::make_shared(std::move(operand)); +} + +AndExpression operator&&(const Expression& lhs, const Expression& rhs) { + return AndExpression(lhs.Copy(), rhs.Copy()); +} + +OrExpression operator||(const Expression& lhs, const Expression& rhs) { + return OrExpression(lhs.Copy(), rhs.Copy()); +} + +NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } + +Result> ComparisonExpression::Validate( + const Schema& schema) const { + if (left_operand_->type() != ExpressionType::FIELD) { + return Status::NotImplemented("comparison with non-FIELD LHS"); + } + + ARROW_ASSIGN_OR_RAISE(auto lhs_type, left_operand_->Validate(schema)); + ARROW_ASSIGN_OR_RAISE(auto
rhs_type, right_operand_->Validate(schema)); + if (!lhs_type->Equals(rhs_type)) { + return Status::TypeError("cannot compare expressions of differing type, ", *lhs_type, + " vs ", *rhs_type); + } + + if (lhs_type->id() == Type::NA || rhs_type->id() == Type::NA) { + return null(); + } + + return boolean(); +} + +Status EnsureNullOrBool(const std::string& msg_prefix, + const std::shared_ptr& type) { + if (type->id() == Type::BOOL || type->id() == Type::NA) { + return Status::OK(); + } + return Status::TypeError(msg_prefix, *type); +} + +Result> ValidateBoolean(const ExpressionVector& operands, + const Schema& schema) { + auto out = boolean(); + for (auto operand : operands) { + ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); + RETURN_NOT_OK( + EnsureNullOrBool("cannot combine expressions including one of type ", type)); + if (type->id() == Type::NA) { + out = null(); + } + } + return std::move(out); +} + +Result> AndExpression::Validate(const Schema& schema) const { + return ValidateBoolean({left_operand_, right_operand_}, schema); +} + +Result> OrExpression::Validate(const Schema& schema) const { + return ValidateBoolean({left_operand_, right_operand_}, schema); +} + +Result> NotExpression::Validate(const Schema& schema) const { + return ValidateBoolean({operand_}, schema); +} + +Result> ScalarExpression::Validate(const Schema& schema) const { + return value_->type; +} + +Result> FieldExpression::Validate(const Schema& schema) const { + if (auto field = schema.GetFieldByName(name_)) { + return field->type(); + } + return null(); +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index a727b1ce4b8..9cf31f1ed59 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -18,23 +18,387 @@ #pragma once #include +#include +#include +#include "arrow/compute/kernels/compare.h" +#include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" +#include 
"arrow/result.h" +#include "arrow/scalar.h" namespace arrow { namespace dataset { -class ARROW_DS_EXPORT Filter { - public: +struct FilterType { enum type { /// Simple boolean predicate consisting of comparisons and boolean - /// logic (AND, OR, NOT) involving Schema fields + /// logic (ALL, OR, NOT) involving Schema fields EXPRESSION, - /// + /// Non decomposable filter; must be evaluated against every record batch GENERIC }; }; +class ARROW_DS_EXPORT Filter { + public: + explicit Filter(FilterType::type type) : type_(type) {} + + virtual ~Filter() = default; + + FilterType::type type() const { return type_; } + + private: + FilterType::type type_; +}; + +/// Filter subclass encapsulating a simple boolean predicate consisting of comparisons +/// and boolean logic (ALL, OR, NOT) involving Schema fields +class ARROW_DS_EXPORT ExpressionFilter : public Filter { + public: + explicit ExpressionFilter(const std::shared_ptr& expression) + : Filter(FilterType::EXPRESSION), expression_(std::move(expression)) {} + + const std::shared_ptr& expression() const { return expression_; } + + private: + std::shared_ptr expression_; +}; + +struct ExpressionType { + enum type { + /// a reference to a column within a record batch, will evaluate to an array + FIELD, + + /// a literal singular value encapsulated in a Scalar + SCALAR, + + /// a literal Array + // TODO(bkietz) ARRAY, + + /// an inversion of another expression + NOT, + + /// cast an expression to a given DataType + // TODO(bkietz) CAST, + + /// a conjunction of multiple expressions (true if all operands are true) + AND, + + /// a disjunction of multiple expressions (true if any operand is true) + OR, + + /// a comparison of two other expressions + COMPARISON, + + /// replace nulls with other expressions + /// currently only boolean expressions may be coalesced + // TODO(bkietz) COALESCE, + + /// extract validity as a boolean expression + // TODO(bkietz) IS_VALID, + }; +}; + +/// Represents an expression tree +class
ARROW_DS_EXPORT Expression { + public: + explicit Expression(ExpressionType::type type) : type_(type) {} + + virtual ~Expression() = default; + + /// Returns true iff the expressions are identical; does not check for equivalence. + /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to + /// (false). + virtual bool Equals(const Expression& other) const = 0; + + bool Equals(const std::shared_ptr& other) const; + + /// Validate this expression for execution against a schema. This will check that all + /// reference fields are present (fields not in the schema will be replaced with null) + /// and all subexpressions are executable. Returns the type to which this expression + /// will evaluate. + virtual Result> Validate(const Schema& schema) const = 0; + + /// Return a simplified form of this expression given some known conditions. + /// For example, (a > 3).Assume(a == 5) == (true). This can be used to do less work + /// in ExpressionFilter when partition conditions guarantee some of this expression. + /// In the example above, *no* filtering need be done on record batches in the + /// partition since (a == 5). + virtual Result> Assume(const Expression& given) const { + return Copy(); + } + + /// Evaluate this expression against each row of a RecordBatch. + /// Returned Datum will be of either SCALAR or ARRAY kind. + /// A return value of ARRAY kind will have length == batch.num_rows() + /// A return value of SCALAR kind is equivalent to an array of the same type whose + /// slots contain a single repeated value. + virtual Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const = 0; + + /// returns a debug string representing this expression + virtual std::string ToString() const = 0; + + ExpressionType::type type() const { return type_; } + + /// If true, this Expression is a ScalarExpression wrapping a null scalar.
+ bool IsNull() const; + + /// If true, this Expression is a ScalarExpression wrapping a + /// BooleanScalar. Its value may be retrieved at the same time. + bool IsTrivialCondition(bool* value = NULLPTR) const; + + /// Copy this expression into a shared pointer. + virtual std::shared_ptr Copy() const = 0; + + protected: + ExpressionType::type type_; +}; + +/// Helper class which implements Copy and forwards construction +template +class ExpressionImpl : public Base { + public: + static constexpr ExpressionType::type expression_type = E; + + template + explicit ExpressionImpl(A0&& arg0, A&&... args) + : Base(expression_type, std::forward(arg0), std::forward(args)...) {} + + std::shared_ptr Copy() const override { + return std::make_shared(internal::checked_cast(*this)); + } +}; + +/// Base class for an expression with exactly one operand +class ARROW_DS_EXPORT UnaryExpression : public Expression { + public: + const std::shared_ptr& operand() const { return operand_; } + + bool Equals(const Expression& other) const override; + + protected: + UnaryExpression(ExpressionType::type type, std::shared_ptr operand) + : Expression(type), operand_(std::move(operand)) {} + + std::shared_ptr operand_; +}; + +/// Base class for an expression with exactly two operands +class ARROW_DS_EXPORT BinaryExpression : public Expression { + public: + const std::shared_ptr& left_operand() const { return left_operand_; } + + const std::shared_ptr& right_operand() const { return right_operand_; } + + bool Equals(const Expression& other) const override; + + protected: + BinaryExpression(ExpressionType::type type, std::shared_ptr left_operand, + std::shared_ptr right_operand) + : Expression(type), + left_operand_(std::move(left_operand)), + right_operand_(std::move(right_operand)) {} + + std::shared_ptr left_operand_, right_operand_; +}; + +class ARROW_DS_EXPORT ComparisonExpression final + : public ExpressionImpl { + public: + ComparisonExpression(compute::CompareOperator op, + std::shared_ptr 
left_operand, + std::shared_ptr right_operand) + : ExpressionImpl(std::move(left_operand), std::move(right_operand)), op_(op) {} + + std::string ToString() const override; + + bool Equals(const Expression& other) const override; + + Result> Assume(const Expression& given) const override; + + compute::CompareOperator op() const { return op_; } + + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; + + Result> Validate(const Schema& schema) const override; + + private: + Result> AssumeGivenComparison( + const ComparisonExpression& given) const; + + compute::CompareOperator op_; +}; + +class ARROW_DS_EXPORT AndExpression final + : public ExpressionImpl { + public: + using ExpressionImpl::ExpressionImpl; + + std::string ToString() const override; + + Result> Assume(const Expression& given) const override; + + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; + + Result> Validate(const Schema& schema) const override; +}; + +class ARROW_DS_EXPORT OrExpression final + : public ExpressionImpl { + public: + using ExpressionImpl::ExpressionImpl; + + std::string ToString() const override; + + Result> Assume(const Expression& given) const override; + + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; + + Result> Validate(const Schema& schema) const override; +}; + +class ARROW_DS_EXPORT NotExpression final + : public ExpressionImpl { + public: + using ExpressionImpl::ExpressionImpl; + + std::string ToString() const override; + + Result> Assume(const Expression& given) const override; + + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; + + Result> Validate(const Schema& schema) const override; +}; + +/// Represents a scalar value; thin wrapper around arrow::Scalar +class ARROW_DS_EXPORT ScalarExpression final : public Expression { + public: + explicit ScalarExpression(const std::shared_ptr& value) + : 
Expression(ExpressionType::SCALAR), value_(std::move(value)) {} + + const std::shared_ptr& value() const { return value_; } + + std::string ToString() const override; + + bool Equals(const Expression& other) const override; + + static std::shared_ptr Make(bool value) { + return std::make_shared(std::make_shared(value)); + } + + template + static typename std::enable_if::value || + std::is_floating_point::value, + std::shared_ptr>::type + Make(T value) { + using ScalarType = typename CTypeTraits::ScalarType; + return std::make_shared(std::make_shared(value)); + } + + static std::shared_ptr Make(std::string value); + + static std::shared_ptr Make(const char* value); + + static std::shared_ptr Make(std::shared_ptr value) { + return std::make_shared(std::move(value)); + } + + static std::shared_ptr MakeNull( + const std::shared_ptr& type); + + Result> Validate(const Schema& schema) const override; + + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; + + std::shared_ptr Copy() const override; + + private: + std::shared_ptr value_; +}; + +/// Represents a reference to a field. 
Stores only the field's name (type and other +/// information is known only when a Schema is provided) +class ARROW_DS_EXPORT FieldExpression final : public Expression { + public: + explicit FieldExpression(std::string name) + : Expression(ExpressionType::FIELD), name_(std::move(name)) {} + + std::string name() const { return name_; } + + std::string ToString() const override; + + bool Equals(const Expression& other) const override; + + Result> Validate(const Schema& schema) const override; + + Result Evaluate(compute::FunctionContext* ctx, + const RecordBatch& batch) const override; + + std::shared_ptr Copy() const override; + + private: + std::string name_; +}; + +ARROW_DS_EXPORT std::shared_ptr and_(std::shared_ptr lhs, + std::shared_ptr rhs); + +ARROW_DS_EXPORT AndExpression operator&&(const Expression& lhs, const Expression& rhs); + +ARROW_DS_EXPORT std::shared_ptr or_(std::shared_ptr lhs, + std::shared_ptr rhs); + +ARROW_DS_EXPORT OrExpression operator||(const Expression& lhs, const Expression& rhs); + +ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); + +ARROW_DS_EXPORT NotExpression operator!(const Expression& rhs); + +#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ + inline std::shared_ptr FACTORY_NAME( \ + const std::shared_ptr& lhs, \ + const std::shared_ptr& rhs) { \ + return std::make_shared(compute::CompareOperator::NAME, lhs, \ + rhs); \ + } \ + \ + template \ + ComparisonExpression operator OP(const FieldExpression& lhs, T&& rhs) { \ + return ComparisonExpression(compute::CompareOperator::NAME, lhs.Copy(), \ + ScalarExpression::Make(std::forward(rhs))); \ + } +COMPARISON_FACTORY(EQUAL, equal, ==) +COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) +COMPARISON_FACTORY(GREATER, greater, >) +COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) +COMPARISON_FACTORY(LESS, less, <) +COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) +#undef COMPARISON_FACTORY + +template +auto scalar(T&& value) -> 
decltype(ScalarExpression::Make(std::forward(value))) { + return ScalarExpression::Make(std::forward(value)); +} + +inline std::shared_ptr field_ref(std::string name) { + return std::make_shared(std::move(name)); +} + +inline namespace string_literals { +inline FieldExpression operator""_(const char* name, size_t name_length) { + return FieldExpression({name, name_length}); +} +} // namespace string_literals + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc new file mode 100644 index 00000000000..2e884ebb08f --- /dev/null +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/dataset/filter.h" + +#include +#include +#include +#include + +#include +#include + +#include "arrow/compute/api.h" +#include "arrow/dataset/test_util.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { +namespace dataset { + +using string_literals::operator""_; +using internal::checked_cast; +using internal::checked_pointer_cast; + +class ExpressionsTest : public ::testing::Test { + public: + void AssertSimplifiesTo(const Expression& expr, const Expression& given, + const Expression& expected) { + auto simplified = expr.Assume(given); + ASSERT_OK(simplified.status()); + if (!simplified.ValueOrDie()->Equals(expected)) { + FAIL() << " simplification of: " << expr.ToString() << std::endl + << " given: " << given.ToString() << std::endl + << " expected: " << expected.ToString() << std::endl + << " was: " << simplified.ValueOrDie()->ToString(); + } + } + + void AssertOperandsAre(const BinaryExpression& expr, ExpressionType::type type, + const Expression& lhs, const Expression& rhs) { + ASSERT_EQ(expr.type(), type); + ASSERT_TRUE(expr.left_operand()->Equals(lhs)); + ASSERT_TRUE(expr.right_operand()->Equals(rhs)); + } + + std::shared_ptr always = ScalarExpression::Make(true); + std::shared_ptr never = ScalarExpression::Make(false); +}; + +TEST_F(ExpressionsTest, Equality) { + ASSERT_TRUE("a"_.Equals("a"_)); + ASSERT_FALSE("a"_.Equals("b"_)); + + ASSERT_TRUE(("b"_ == 3).Equals("b"_ == 3)); + ASSERT_FALSE(("b"_ == 3).Equals("b"_ < 3)); + ASSERT_FALSE(("b"_ == 3).Equals("b"_)); + + // ordering matters + ASSERT_FALSE(("b"_ > 2 and "b"_ < 3).Equals("b"_ < 3 and "b"_ > 2)); +} + +TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { + // chained "and" expressions nest to the left (they are not flattened) + auto multi_and = "b"_ > 5 and "b"_ < 10 and "b"_ != 7; + AssertOperandsAre(multi_and, ExpressionType::AND, ("b"_ > 5 and "b"_ < 10), "b"_ != 7); + + AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, 
*never); + AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); + + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 6, *never); + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ == 3, *always); + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); + AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); + + AssertSimplifiesTo("b"_ > 0.5 and "b"_ < 1.5, not("b"_ < 0.0 or "b"_ > 1.0), + "b"_ > 0.5); + + AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); + + AssertSimplifiesTo("a"_ == 3 or "b"_ == 4, "a"_ == 0, "b"_ == 4); +} + +TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { + AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); + AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); + AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); +} + +TEST_F(ExpressionsTest, SimplificationToNull) { + auto null = ScalarExpression::MakeNull(boolean()); + auto null32 = ScalarExpression::MakeNull(int32()); + + AssertSimplifiesTo(*equal(field_ref("b"), null32), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(field_ref("b"), null32), "b"_ == 3, *null); + AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 3, "b"_ == 3, *null); + AssertSimplifiesTo("b"_ > 3 and *not_equal(field_ref("b"), null32), "b"_ == 3, *null); +} + +class FilterTest : public ::testing::Test { + public: + Result DoFilter(const Expression& expr, + std::vector> fields, + std::string batch_json, + std::shared_ptr* expected_mask = nullptr) { + // expected filter result is in the "in" field + fields.push_back(field("in", boolean())); + + auto batch_array = ArrayFromJSON(struct_(std::move(fields)), std::move(batch_json)); + std::shared_ptr batch; + RETURN_NOT_OK(RecordBatch::FromStructArray(batch_array, &batch)); + + if (expected_mask) { + *expected_mask = checked_pointer_cast(batch->GetColumnByName("in")); + } + + return expr.Evaluate(&ctx_, *batch); + } + + void AssertFilter(const 
Expression& expr, std::vector> fields, + std::string batch_json) { + std::shared_ptr expected_mask; + auto mask_res = + DoFilter(expr, std::move(fields), std::move(batch_json), &expected_mask); + ASSERT_OK(mask_res.status()); + + auto mask = std::move(mask_res).ValueOrDie(); + ASSERT_TRUE(mask.type()->Equals(null()) || mask.type()->Equals(boolean())); + + if (mask.is_array()) { + ASSERT_ARRAYS_EQUAL(*expected_mask, *mask.make_array()); + return; + } + + ASSERT_TRUE(mask.is_scalar()); + auto mask_scalar = mask.scalar(); + if (!mask_scalar->is_valid) { + ASSERT_EQ(expected_mask->null_count(), expected_mask->length()); + return; + } + + TypedBufferBuilder builder; + ASSERT_OK(builder.Append(expected_mask->length(), + checked_cast(*mask_scalar).value)); + + std::shared_ptr values; + ASSERT_OK(builder.Finish(&values)); + + ASSERT_ARRAYS_EQUAL(*expected_mask, BooleanArray(expected_mask->length(), values)); + } + + arrow::compute::FunctionContext ctx_; +}; + +TEST_F(FilterTest, Trivial) { + AssertFilter(*scalar(true), {field("a", int32()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": 1}, + {"a": 0, "b": 0.3, "in": 1}, + {"a": 1, "b": 0.2, "in": 1}, + {"a": 2, "b": -0.1, "in": 1}, + {"a": 0, "b": 0.1, "in": 1}, + {"a": 0, "b": null, "in": 1}, + {"a": 0, "b": 1.0, "in": 1} + ])"); + + AssertFilter(*scalar(false), {field("a", int32()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.3, "in": 0}, + {"a": 1, "b": 0.2, "in": 0}, + {"a": 2, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.1, "in": 0}, + {"a": 0, "b": null, "in": 0}, + {"a": 0, "b": 1.0, "in": 0} + ])"); + + AssertFilter(*ScalarExpression::MakeNull(boolean()), + {field("a", int32()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": null}, + {"a": 0, "b": 0.3, "in": null}, + {"a": 1, "b": 0.2, "in": null}, + {"a": 2, "b": -0.1, "in": null}, + {"a": 0, "b": 0.1, "in": null}, + {"a": 0, "b": null, "in": null}, + {"a": 0, "b": 1.0, "in": null} + ])"); +} + +TEST_F(FilterTest, 
Basics) { + AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0, + {field("a", int32()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.3, "in": 1}, + {"a": 1, "b": 0.2, "in": 0}, + {"a": 2, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.1, "in": 1}, + {"a": 0, "b": null, "in": null}, + {"a": 0, "b": 1.0, "in": 0} + ])"); + + AssertFilter("a"_ != 0 and "b"_ > 0.1, {field("a", int32()), field("b", float64())}, + R"([ + {"a": 0, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.3, "in": 0}, + {"a": 1, "b": 0.2, "in": 1}, + {"a": 2, "b": -0.1, "in": 0}, + {"a": 0, "b": 0.1, "in": 0}, + {"a": 0, "b": null, "in": null}, + {"a": 0, "b": 1.0, "in": 0} + ])"); +} + +TEST_F(FilterTest, ConditionOnAbsentColumn) { + AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0 and "absent"_ == 0, + {field("a", int32()), field("b", float64())}, R"([ + {"a": 0, "b": -0.1, "in": null}, + {"a": 0, "b": 0.3, "in": null}, + {"a": 1, "b": 0.2, "in": null}, + {"a": 2, "b": -0.1, "in": null}, + {"a": 0, "b": 0.1, "in": null}, + {"a": 0, "b": null, "in": null}, + {"a": 0, "b": 1.0, "in": null} + ])"); +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 8e3824625ed..4f195334e2f 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -25,6 +25,12 @@ namespace arrow { +namespace compute { + +class FunctionContext; + +} // namespace compute + namespace fs { class FileSystem; @@ -50,6 +56,15 @@ class FileWriteOptions; class Filter; using FilterVector = std::vector>; +class Expression; +class ComparisonExpression; +class AndExpression; +class OrExpression; +class NotExpression; +class ScalarExpression; +class FieldExpression; +using ExpressionVector = std::vector>; + class Partition; class PartitionKey; class PartitionScheme; diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 7ae865af5f0..1e4ad45a3d2 100644 --- 
a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -198,6 +198,17 @@ std::shared_ptr RecordBatch::Make( return std::make_shared(schema, num_rows, columns); } +Status RecordBatch::FromStructArray(const std::shared_ptr& array, + std::shared_ptr* out) { + if (array->type_id() != Type::STRUCT) { + return Status::Invalid("Cannot construct record batch from array of type ", + *array->type()); + } + *out = Make(arrow::schema(array->type()->children()), array->length(), + array->data()->child_data); + return Status::OK(); +} + const std::string& RecordBatch::column_name(int i) const { return schema_->field(i)->name(); } diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 6c67f363c64..c3b003edf68 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -71,6 +71,9 @@ class ARROW_EXPORT RecordBatch { const std::shared_ptr& schema, int64_t num_rows, const std::vector>& columns); + static Status FromStructArray(const std::shared_ptr& array, + std::shared_ptr* out); + /// \brief Determine if two record batches are exactly equal /// \return true if batches are equal bool Equals(const RecordBatch& other) const; diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 7c3a4ee9a4b..10019906df1 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -26,6 +26,7 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" #include "arrow/util/logging.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -115,4 +116,30 @@ FixedSizeListScalar::FixedSizeListScalar(const std::shared_ptr& value, bool is_valid) : FixedSizeListScalar(value, value->type(), is_valid) {} +struct MakeNullImpl { + template + using ScalarType = typename TypeTraits::ScalarType; + + template + typename std::enable_if>::value, + Status>::type + Visit(const T&) { + *out_ = std::make_shared>(); + return Status::OK(); + } + + Status Visit(const DataType& t) { + return Status::NotImplemented("construcing null 
scalars of type ", t); + } + + std::shared_ptr type_; + std::shared_ptr* out_; +}; + +Status MakeNullScalar(const std::shared_ptr& type, + std::shared_ptr* null) { + MakeNullImpl impl = {type, null}; + return VisitTypeInline(*type, &impl); +} + } // namespace arrow diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 09192896ba3..4c590ef6a42 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -76,6 +76,7 @@ struct ARROW_EXPORT BooleanScalar : public internal::PrimitiveScalar { bool value; explicit BooleanScalar(bool value, bool is_valid = true) : internal::PrimitiveScalar{boolean(), is_valid}, value(value) {} + BooleanScalar() : BooleanScalar(false, false) {} }; template @@ -86,6 +87,8 @@ struct NumericScalar : public internal::PrimitiveScalar { explicit NumericScalar(T value, bool is_valid = true) : NumericScalar(value, TypeTraits::type_singleton(), is_valid) {} + NumericScalar() : NumericScalar(0, false) {} + protected: explicit NumericScalar(T value, const std::shared_ptr& type, bool is_valid) : internal::PrimitiveScalar{type, is_valid}, value(value) {} @@ -105,6 +108,8 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { explicit BinaryScalar(const std::shared_ptr& value, bool is_valid = true) : BaseBinaryScalar(value, binary(), is_valid) {} + BinaryScalar() : BinaryScalar(NULLPTR, false) {} + protected: using BaseBinaryScalar::BaseBinaryScalar; }; @@ -112,12 +117,16 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { struct ARROW_EXPORT StringScalar : public BinaryScalar { explicit StringScalar(const std::shared_ptr& value, bool is_valid = true) : BinaryScalar(value, utf8(), is_valid) {} + + StringScalar() : StringScalar(NULLPTR, false) {} }; struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { explicit LargeBinaryScalar(const std::shared_ptr& value, bool is_valid = true) : BaseBinaryScalar(value, large_binary(), is_valid) {} + LargeBinaryScalar() : LargeBinaryScalar(NULLPTR, false) {} + 
protected: using BaseBinaryScalar::BaseBinaryScalar; }; @@ -125,6 +134,8 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { explicit LargeStringScalar(const std::shared_ptr& value, bool is_valid = true) : LargeBinaryScalar(value, utf8(), is_valid) {} + + LargeStringScalar() : LargeStringScalar(NULLPTR, false) {} }; struct ARROW_EXPORT FixedSizeBinaryScalar : public BinaryScalar { @@ -235,4 +246,11 @@ class ARROW_EXPORT UnionScalar : public Scalar {}; class ARROW_EXPORT DictionaryScalar : public Scalar {}; class ARROW_EXPORT ExtensionScalar : public Scalar {}; +/// \param[in] type the type of scalar to produce +/// \param[out] null output scalar with is_valid=false +/// \return Status +ARROW_EXPORT +Status MakeNullScalar(const std::shared_ptr& type, + std::shared_ptr* null); + } // namespace arrow