From 619fa1347d659a3a430f4780b39ff2090a495d4e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 26 Oct 2020 12:46:07 -0400 Subject: [PATCH 01/31] ARROW-10322: [C++][Dataset] Minimize dataset::Expression --- cpp/src/arrow/array/array_struct_test.cc | 36 + cpp/src/arrow/array/util.cc | 17 +- cpp/src/arrow/chunked_array.h | 3 + cpp/src/arrow/compute/cast.cc | 80 +- cpp/src/arrow/compute/cast.h | 14 +- cpp/src/arrow/compute/exec.h | 2 +- cpp/src/arrow/compute/exec_internal.h | 4 +- cpp/src/arrow/compute/function.h | 6 +- .../arrow/compute/kernels/scalar_cast_test.cc | 43 + cpp/src/arrow/dataset/CMakeLists.txt | 2 + cpp/src/arrow/dataset/expression.cc | 931 ++++++++++++++++++ cpp/src/arrow/dataset/expression.h | 217 ++++ cpp/src/arrow/dataset/expression_internal.h | 110 +++ cpp/src/arrow/dataset/expression_test.cc | 729 ++++++++++++++ cpp/src/arrow/dataset/partition.cc | 11 - cpp/src/arrow/dataset/partition_test.cc | 4 - cpp/src/arrow/dataset/test_util.h | 1 + cpp/src/arrow/dataset/type_fwd.h | 3 + cpp/src/arrow/datum.cc | 15 + cpp/src/arrow/datum.h | 32 +- cpp/src/arrow/result.h | 22 + cpp/src/arrow/type.cc | 33 +- cpp/src/arrow/type.h | 25 +- cpp/src/arrow/type_fwd.h | 2 + 24 files changed, 2296 insertions(+), 46 deletions(-) create mode 100644 cpp/src/arrow/dataset/expression.cc create mode 100644 cpp/src/arrow/dataset/expression.h create mode 100644 cpp/src/arrow/dataset/expression_internal.h create mode 100644 cpp/src/arrow/dataset/expression_test.cc diff --git a/cpp/src/arrow/array/array_struct_test.cc b/cpp/src/arrow/array/array_struct_test.cc index f54b43465e9..aef0076d0d3 100644 --- a/cpp/src/arrow/array/array_struct_test.cc +++ b/cpp/src/arrow/array/array_struct_test.cc @@ -24,6 +24,7 @@ #include "arrow/array.h" #include "arrow/array/builder_nested.h" +#include "arrow/chunked_array.h" #include "arrow/status.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" @@ -582,4 +583,39 @@ TEST_F(TestStructBuilder, TestSlice) { 
ASSERT_EQ(list_field->null_count(), 1); } +TEST(TestFieldRef, GetChildren) { + auto struct_array = ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])"); + + ASSERT_OK_AND_ASSIGN(auto a, FieldRef("a").GetOne(*struct_array)); + auto expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); + AssertArraysEqual(*a, *expected_a); + + auto ToChunked = [struct_array](int64_t midpoint) { + return ChunkedArray( + ArrayVector{ + struct_array->Slice(0, midpoint), + struct_array->Slice(midpoint), + }, + struct_array->type()); + }; + AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); + + // more nested: + struct_array = + ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ + {"a": {"a": 6.125}}, + {"a": {"a": 0.0}}, + {"a": {"a": -1}} + ])"); + + ASSERT_OK_AND_ASSIGN(a, FieldRef("a", "a").GetOne(*struct_array)); + expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); + AssertArraysEqual(*a, *expected_a); + AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index 70167954e5a..0d498931d42 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -297,8 +297,9 @@ class RepeatedArrayFactory { return out_; } - Status Visit(const DataType& type) { - return Status::NotImplemented("construction from scalar of type ", *scalar_.type); + Status Visit(const NullType& type) { + DCHECK(false); // already forwarded to MakeArrayOfNull + return Status::OK(); } Status Visit(const BooleanType&) { @@ -403,6 +404,18 @@ class RepeatedArrayFactory { return Status::OK(); } + Status Visit(const ExtensionType& type) { + return Status::NotImplemented("construction from scalar of type ", *scalar_.type); + } + + Status Visit(const DenseUnionType& type) { + return Status::NotImplemented("construction from scalar of type ", *scalar_.type); + } + + Status Visit(const SparseUnionType& type) { + return 
Status::NotImplemented("construction from scalar of type ", *scalar_.type); + } + template Status CreateOffsetsBuffer(OffsetType value_length, std::shared_ptr* out) { TypedBufferBuilder builder(pool_); diff --git a/cpp/src/arrow/chunked_array.h b/cpp/src/arrow/chunked_array.h index b1f8d6cd6f5..5c0dda91850 100644 --- a/cpp/src/arrow/chunked_array.h +++ b/cpp/src/arrow/chunked_array.h @@ -72,6 +72,9 @@ class ARROW_EXPORT ChunkedArray { /// data type. explicit ChunkedArray(ArrayVector chunks); + ChunkedArray(ChunkedArray&&) = default; + ChunkedArray& operator=(ChunkedArray&&) = default; + /// \brief Construct a chunked array from a single Array explicit ChunkedArray(std::shared_ptr chunk) : ChunkedArray(ArrayVector{std::move(chunk)}) {} diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 29a80f73241..b55dfb38c15 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -118,8 +118,86 @@ class CastMetaFunction : public MetaFunction { } // namespace +const FunctionDoc struct_doc{"Wrap Arrays into a StructArray", + ("Names of the StructArray's fields are\n" + "specified through StructOptions."), + {}, + "StructOptions"}; + +Result StructResolve(KernelContext* ctx, + const std::vector& descrs) { + const auto& names = OptionsWrapper::Get(ctx).field_names; + if (names.size() != descrs.size()) { + return Status::Invalid("Struct() was passed ", names.size(), " field ", "names but ", + descrs.size(), " arguments"); + } + + size_t i = 0; + FieldVector fields(descrs.size()); + + ValueDescr::Shape shape = ValueDescr::SCALAR; + for (const ValueDescr& descr : descrs) { + if (descr.shape != ValueDescr::SCALAR) { + shape = ValueDescr::ARRAY; + } else { + switch (descr.type->id()) { + case Type::EXTENSION: + case Type::DENSE_UNION: + case Type::SPARSE_UNION: + return Status::NotImplemented("Broadcasting scalars of type ", *descr.type); + default: + break; + } + } + + fields[i] = field(names[i], descr.type); + ++i; + } + + return 
ValueDescr{struct_(std::move(fields)), shape}; +} + +void StructExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, StructResolve(ctx, batch.GetDescriptors())); + + if (descr.shape == ValueDescr::SCALAR) { + ScalarVector scalars(batch.num_values()); + for (int i = 0; i < batch.num_values(); ++i) { + scalars[i] = batch[i].scalar(); + } + + *out = + Datum(std::make_shared(std::move(scalars), std::move(descr.type))); + return; + } + + ArrayVector arrays(batch.num_values()); + for (int i = 0; i < batch.num_values(); ++i) { + if (batch[i].is_array()) { + arrays[i] = batch[i].make_array(); + continue; + } + + KERNEL_ASSIGN_OR_RAISE( + arrays[i], ctx, + MakeArrayFromScalar(*batch[i].scalar(), batch.length, ctx->memory_pool())); + } + + *out = std::make_shared(descr.type, batch.length, std::move(arrays)); +} + void RegisterScalarCast(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::make_shared())); + + auto struct_function = + std::make_shared("struct", Arity::VarArgs(), &struct_doc); + ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{StructResolve}, + /*is_varargs=*/true), + StructExec, OptionsWrapper::Init}; + kernel.null_handling = NullHandling::OUTPUT_NOT_NULL; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(struct_function->AddKernel(std::move(kernel))); + DCHECK_OK(registry->AddFunction(std::move(struct_function))); } } // namespace internal @@ -135,7 +213,7 @@ CastFunction::CastFunction(std::string name, Type::type out_type) impl_->out_type = out_type; } -CastFunction::~CastFunction() {} +CastFunction::~CastFunction() = default; Type::type CastFunction::out_type_id() const { return impl_->out_type; } diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 43392ce99bf..6808d1d86f3 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -83,7 +83,7 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions { 
class CastFunction : public ScalarFunction { public: CastFunction(std::string name, Type::type out_type); - ~CastFunction(); + ~CastFunction() override; Type::type out_type_id() const; @@ -157,5 +157,17 @@ Result Cast(const Datum& value, std::shared_ptr to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); +/// \addtogroup compute-concrete-options +/// @{ + +struct ARROW_EXPORT StructOptions : public FunctionOptions { + explicit StructOptions(std::vector n) : field_names(std::move(n)) {} + + /// Names for wrapped columns + std::vector field_names; +}; + +/// @} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index 142e149bccc..f491489ed8a 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -165,7 +165,7 @@ class ARROW_EXPORT SelectionVector { /// than is desirable for this class. Microbenchmarks would help determine for /// sure. See ARROW-8928. struct ExecBatch { - ExecBatch() {} + ExecBatch() = default; ExecBatch(std::vector values, int64_t length) : values(std::move(values)), length(length) {} diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h index 8bad135e40d..a74e5c8d8fa 100644 --- a/cpp/src/arrow/compute/exec_internal.h +++ b/cpp/src/arrow/compute/exec_internal.h @@ -84,14 +84,14 @@ class ARROW_EXPORT ExecListener { class DatumAccumulator : public ExecListener { public: - DatumAccumulator() {} + DatumAccumulator() = default; Status OnResult(Datum value) override { values_.emplace_back(value); return Status::OK(); } - std::vector values() const { return values_; } + std::vector values() { return std::move(values_); } private: std::vector values_; diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index a71dbe40292..e8e732027c9 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -41,7 +41,9 @@ namespace compute { /// \brief 
Base class for specifying options configuring a function's behavior, /// such as error handling. -struct ARROW_EXPORT FunctionOptions {}; +struct ARROW_EXPORT FunctionOptions { + virtual ~FunctionOptions() = default; +}; /// \brief Contains the number of required arguments for the function. /// @@ -96,7 +98,7 @@ struct ARROW_EXPORT FunctionDoc { /// \brief Name of the options class, if any. std::string options_class; - FunctionDoc() {} + FunctionDoc() = default; FunctionDoc(std::string summary, std::string description, std::vector arg_names, std::string options_class = "") diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index ec612e497f4..566a89f4b47 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1892,5 +1892,48 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { ASSERT_OK(UnregisterExtensionType("smallint")); } +class TestStruct : public TestBase { + public: + Result Struct(std::vector args) { + StructOptions opts{field_names}; + return CallFunction("struct", args, &opts); + } + + std::vector field_names; +}; + +TEST_F(TestStruct, Scalar) { + std::shared_ptr expected(new StructScalar{{}, struct_({})}); + ASSERT_OK_AND_EQ(Datum(expected), Struct({})); + + auto i32 = MakeScalar(1); + auto f64 = MakeScalar(2.5); + auto str = MakeScalar("yo"); + + expected.reset(new StructScalar{ + {i32, f64, str}, + struct_({field("i", i32->type), field("f", f64->type), field("s", str->type)})}); + field_names = {"i", "f", "s"}; + ASSERT_OK_AND_EQ(Datum(expected), Struct({i32, f64, str})); + + // Three field names but one input value + ASSERT_RAISES(Invalid, Struct({str})); +} + +TEST_F(TestStruct, Array) { + field_names = {"i", "s"}; + auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]"); + auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])"); + ASSERT_OK_AND_ASSIGN(Datum expected, StructArray::Make({i32, str}, field_names)); + + 
ASSERT_OK_AND_EQ(expected, Struct({i32, str})); + + // Scalars are broadcast to the length of the arrays + ASSERT_OK_AND_EQ(expected, Struct({i32, MakeScalar("aa")})); + + // Array length mismatch + ASSERT_RAISES(Invalid, Struct({i32->Slice(1), str})); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt index b693c48049c..ae6bb3656dc 100644 --- a/cpp/src/arrow/dataset/CMakeLists.txt +++ b/cpp/src/arrow/dataset/CMakeLists.txt @@ -22,6 +22,7 @@ arrow_install_all_headers("arrow/dataset") set(ARROW_DATASET_SRCS dataset.cc discovery.cc + expression.cc file_base.cc file_ipc.cc filter.cc @@ -106,6 +107,7 @@ endfunction() add_arrow_dataset_test(dataset_test) add_arrow_dataset_test(discovery_test) +add_arrow_dataset_test(expression_test) add_arrow_dataset_test(file_ipc_test) add_arrow_dataset_test(file_test) add_arrow_dataset_test(filter_test) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc new file mode 100644 index 00000000000..e0a44a34ab8 --- /dev/null +++ b/cpp/src/arrow/dataset/expression.cc @@ -0,0 +1,931 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/dataset/expression.h" + +#include +#include + +#include "arrow/chunked_array.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/dataset/expression_internal.h" +#include "arrow/util/logging.h" +#include "arrow/util/optional.h" +#include "arrow/util/string.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace dataset { + +const Expression2::Call* Expression2::call() const { + return util::get_if(impl_.get()); +} + +const Datum* Expression2::literal() const { return util::get_if(impl_.get()); } + +const FieldRef* Expression2::field_ref() const { + return util::get_if(impl_.get()); +} + +std::string Expression2::ToString() const { + if (auto lit = literal()) { + if (lit->is_scalar()) { + return lit->scalar()->ToString(); + } + return lit->ToString(); + } + + if (auto ref = field_ref()) { + if (auto name = ref->name()) { + return "FieldRef(" + *name + ")"; + } + if (auto path = ref->field_path()) { + return path->ToString(); + } + return ref->ToString(); + } + + auto call = CallNotNull(*this); + + // FIXME represent FunctionOptions + std::string out = call->function + "("; + for (const auto& arg : call->arguments) { + out += arg.ToString() + ","; + } + out.back() = ')'; + return out; +} + +void PrintTo(const Expression2& expr, std::ostream* os) { *os << expr.ToString(); } + +bool Expression2::Equals(const Expression2& other) const { + if (Identical(*this, other)) return true; + + if (impl_->index() != other.impl_->index()) { + return false; + } + + if (auto lit = literal()) { + return lit->Equals(*other.literal()); + } + + if (auto ref = field_ref()) { + return ref->Equals(*other.field_ref()); + } + + auto call = CallNotNull(*this); + auto other_call = CallNotNull(other); + + if (call->function != other_call->function || call->kernel != other_call->kernel) { + return false; + } + + // FIXME compare FunctionOptions for equality + for (size_t i = 0; i 
< call->arguments.size(); ++i) { + if (!call->arguments[i].Equals(other_call->arguments[i])) { + return false; + } + } + + return true; +} + +size_t Expression2::hash() const { + if (auto lit = literal()) { + if (lit->is_scalar()) { + return Scalar::Hash::hash(*lit->scalar()); + } + return 0; + } + + if (auto ref = field_ref()) { + return ref->hash(); + } + + auto call = CallNotNull(*this); + + size_t out = std::hash{}(call->function); + for (const auto& arg : call->arguments) { + out ^= arg.hash(); + } + return out; +} + +bool Expression2::IsBound() const { + if (descr_.type == nullptr) return false; + + if (auto lit = literal()) return true; + + if (auto ref = field_ref()) return true; + + auto call = CallNotNull(*this); + + for (const Expression2& arg : call->arguments) { + if (!arg.IsBound()) return false; + } + + return call->kernel != nullptr; +} + +bool Expression2::IsScalarExpression() const { + if (auto lit = literal()) { + return lit->is_scalar(); + } + + // FIXME handle case where a list's item field is referenced + if (auto ref = field_ref()) return true; + + auto call = CallNotNull(*this); + + for (const Expression2& arg : call->arguments) { + if (!arg.IsScalarExpression()) return false; + } + + if (call->kernel == nullptr) { + // this expression is not bound; make a best guess based on + // the default function registry + if (auto function = compute::GetFunctionRegistry() + ->GetFunction(call->function) + .ValueOr(nullptr)) { + return function->kind() == compute::Function::SCALAR; + } + + // unknown function or other error; conservatively return false + return false; + } + + return call->function_kind == compute::Function::SCALAR; +} + +bool Expression2::IsNullLiteral() const { + if (auto lit = literal()) { + if (lit->null_count() == lit->length()) { + return true; + } + } + return false; +} + +bool Expression2::IsSatisfiable() const { + if (auto lit = literal()) { + if (lit->null_count() == lit->length()) { + return false; + } + if (lit->is_scalar() 
&& lit->type()->id() == Type::BOOL) { + return lit->scalar_as().value; + } + } + return true; +} + +bool KernelStateIsImmutable(const std::string& function) { + // XXX maybe just add Kernel::state_is_immutable or so? + + // known functions with non-null but nevertheless immutable KernelState + static std::unordered_set immutable_state = { + "is_in", "index_in", "cast", "struct", "strptime", + }; + + return immutable_state.find(function) != immutable_state.end(); +} + +Result> InitKernelState( + const Expression2::Call& call, compute::ExecContext* exec_context) { + if (!call.kernel->init) return nullptr; + + compute::KernelContext kernel_context(exec_context); + auto kernel_state = call.kernel->init( + &kernel_context, {call.kernel, GetDescriptors(call.arguments), call.options.get()}); + + RETURN_NOT_OK(kernel_context.status()); + return std::move(kernel_state); +} + +Result> ExpressionState::Clone( + const std::shared_ptr& state, const Expression2& expr, + compute::ExecContext* exec_context) { + if (!expr.IsBound()) { + return Status::Invalid("Cannot clone State against an unbound expression."); + } + + if (state == nullptr) return nullptr; + + auto call = CallNotNull(expr); + auto call_state = checked_cast(state.get()); + + CallState clone; + clone.argument_states.resize(call_state->argument_states.size()); + + bool recursively_share = true; + for (size_t i = 0; i < call->arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE( + clone.argument_states[i], + Clone(call_state->argument_states[i], call->arguments[i], exec_context)); + + if (clone.argument_states[i] != call_state->argument_states[i]) { + recursively_share = false; + } + } + + if (call_state->kernel_state == nullptr || KernelStateIsImmutable(call->function)) { + // The kernel's state is immutable so it's safe to just + // share a pointer between threads + if (recursively_share) { + return state; + } + clone.kernel_state = call_state->kernel_state; + } else { + // The kernel's state must be re-initialized. 
+ ARROW_ASSIGN_OR_RAISE(clone.kernel_state, InitKernelState(*call, exec_context)); + } + + return std::make_shared(std::move(clone)); +} + +Result Expression2::Bind( + ValueDescr in, compute::ExecContext* exec_context) const { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return Bind(std::move(in), &exec_context); + } + + if (literal()) return StateAndBound{*this, nullptr}; + + if (auto ref = field_ref()) { + ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); + auto bound = *this; + bound.descr_ = + field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); + return StateAndBound{std::move(bound), nullptr}; + } + + auto bound_call = *CallNotNull(*this); + + std::shared_ptr function; + if (bound_call.function == "cast") { + // XXX this special case is strange; why not make "cast" a ScalarFunction? + const auto& to_type = + checked_cast(*bound_call.options).to_type; + ARROW_ASSIGN_OR_RAISE(function, compute::GetCastFunction(to_type)); + } else { + ARROW_ASSIGN_OR_RAISE( + function, exec_context->func_registry()->GetFunction(bound_call.function)); + } + bound_call.function_kind = function->kind(); + + std::vector> argument_states( + bound_call.arguments.size()); + for (size_t i = 0; i < argument_states.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_states[i]), + bound_call.arguments[i].Bind(in, exec_context)); + } + + auto descrs = GetDescriptors(bound_call.arguments); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); + + auto call_state = std::make_shared(); + ARROW_ASSIGN_OR_RAISE(call_state->kernel_state, + InitKernelState(bound_call, exec_context)); + call_state->argument_states = std::move(argument_states); + + compute::KernelContext kernel_context(exec_context); + kernel_context.SetState(call_state->kernel_state.get()); + + auto bound = Expression2(std::make_shared(std::move(bound_call))); + ARROW_ASSIGN_OR_RAISE(bound.descr_, 
bound_call.kernel->signature->out_type().Resolve( + &kernel_context, descrs)); + return StateAndBound{std::move(bound), std::move(call_state)}; +} + +Result Expression2::Bind( + const Schema& in_schema, compute::ExecContext* exec_context) const { + return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); +} + +Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* state, + const Datum& input, + compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return ExecuteScalarExpression(expr, state, input, &exec_context); + } + + if (!expr.IsBound()) { + return Status::Invalid("Cannot Execute unbound expression."); + } + + if (!expr.IsScalarExpression()) { + return Status::Invalid( + "ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString()); + } + + if (auto lit = expr.literal()) return *lit; + + if (auto ref = expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(Datum field, GetDatumField(*ref, input)); + + if (field.descr() != expr.descr()) { + // Refernced field was present but didn't have the expected type. + // Should we just error here? For now, pay dispatch cost and just cast. 
+ ARROW_ASSIGN_OR_RAISE( + field, compute::Cast(field, expr.descr().type, compute::CastOptions::Safe(), + exec_context)); + } + + return field; + } + + auto call = CallNotNull(expr); + auto call_state = checked_cast(state); + + std::vector arguments(call->arguments.size()); + for (size_t i = 0; i < arguments.size(); ++i) { + auto argument_state = call_state->argument_states[i].get(); + ARROW_ASSIGN_OR_RAISE( + arguments[i], + ExecuteScalarExpression(call->arguments[i], argument_state, input, exec_context)); + } + + auto executor = compute::detail::KernelExecutor::MakeScalar(); + + compute::KernelContext kernel_context(exec_context); + kernel_context.SetState(call_state->kernel_state.get()); + + auto kernel = call->kernel; + auto inputs = GetDescriptors(call->arguments); + auto options = call->options.get(); + RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, inputs, options})); + + auto listener = std::make_shared(); + RETURN_NOT_OK(executor->Execute(arguments, listener.get())); + return executor->WrapResults(arguments, listener->values()); +} + +std::array, 2> +ArgumentsAndFlippedArguments(const Expression2::Call& call) { + DCHECK_EQ(call.arguments.size(), 2); + return {std::pair{call.arguments[0], + call.arguments[1]}, + std::pair{call.arguments[1], + call.arguments[0]}}; +} + +util::optional GetNullHandling( + const Expression2::Call& call) { + if (call.function_kind == compute::Function::SCALAR) { + return static_cast(call.kernel)->null_handling; + } + return util::nullopt; +} + +bool DefinitelyNotNull(const Expression2& expr) { + DCHECK(expr.IsBound()); + + if (expr.literal()) { + return !expr.IsNullLiteral(); + } + + if (expr.field_ref()) return false; + + auto call = CallNotNull(expr); + if (auto null_handling = GetNullHandling(*call)) { + if (null_handling == compute::NullHandling::OUTPUT_NOT_NULL) { + return true; + } + if (null_handling == compute::NullHandling::INTERSECTION) { + return std::all_of(call->arguments.begin(), call->arguments.end(), + 
DefinitelyNotNull); + } + } + + return false; +} + +template +Result Modify(Expression2 expr, const PreVisit& pre, + const PostVisitCall& post_call) { + ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); + + auto call = expr.call(); + if (!call) return expr; + + bool at_least_one_modified = false; + auto modified_call = *call; + auto modified_argument = modified_call.arguments.begin(); + + for (const auto& argument : call->arguments) { + ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); + + if (!Identical(*modified_argument, argument)) { + at_least_one_modified = true; + } + ++modified_argument; + } + + if (at_least_one_modified) { + // reconstruct the call expression with the modified arguments + auto modified_expr = Expression2( + std::make_shared(std::move(modified_call)), expr.descr()); + return post_call(std::move(modified_expr), &expr); + } + + return post_call(std::move(expr), nullptr); +} + +Result FoldConstants(Expression2 expr, ExpressionState* state) { + DCHECK(expr.IsBound()); + + struct StateAndIndex { + CallState* state; + int index; + }; + std::vector stack; + + auto root_state = checked_cast(state); + return Modify( + std::move(expr), + [&stack, root_state](Expression2 expr) { + auto call = expr.call(); + + if (stack.empty()) { + stack = {{root_state, 0}}; + return expr; + } + + int i = stack.back().index++; + + if (!call) return expr; + auto next_state = stack.back().state->argument_states[i].get(); + stack.push_back({checked_cast(next_state), 0}); + + return expr; + }, + [&stack](Expression2 expr, ...) 
-> Result { + auto state = stack.back().state; + stack.pop_back(); + + auto call = CallNotNull(expr); + if (std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression2& argument) { return argument.literal(); })) { + // all arguments are literal; we can evaluate this subexpression *now* + static const Datum ignored_input; + ARROW_ASSIGN_OR_RAISE(Datum constant, + ExecuteScalarExpression(expr, state, ignored_input)); + return literal(std::move(constant)); + } + + // XXX the following should probably be in a registry of passes instead of inline + + if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { + // kernels which always produce intersected validity can be resolved to null + // *now* if any of their inputs is a null literal + for (const auto& argument : call->arguments) { + if (argument.IsNullLiteral()) return argument; + } + } + + if (call->function == "and_kleene") { + for (auto args : ArgumentsAndFlippedArguments(*call)) { + // true and x == x + if (args.first == literal(true)) return args.second; + + // false and x == false + if (args.first == literal(false)) return args.first; + } + return expr; + } + + if (call->function == "or_kleene") { + for (auto args : ArgumentsAndFlippedArguments(*call)) { + // false or x == x + if (args.first == literal(false)) return args.second; + + // true or x == true + if (args.first == literal(true)) return args.first; + } + return expr; + } + + return expr; + }); + + return expr; +} + +struct FlattenedAssociativeChain { + bool was_left_folded = true; + std::vector exprs, fringe; + + explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} { + auto call = CallNotNull(exprs.back()); + fringe = call->arguments; + + auto it = fringe.begin(); + + while (it != fringe.end()) { + auto sub_call = it->call(); + if (!sub_call || sub_call->function != call->function) { + ++it; + continue; + } + + if (it != fringe.begin()) { + was_left_folded = false; + } + + 
exprs.push_back(std::move(*it)); + it = fringe.erase(it); + it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end()); + // NB: no increment so we hit sub_call's first argument next iteration + } + + DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) { + return CallNotNull(expr)->options == nullptr; + })); + } +}; + +inline std::vector GuaranteeConjunctionMembers( + const Expression2& guaranteed_true_predicate) { + auto guarantee = guaranteed_true_predicate.call(); + if (!guarantee || guarantee->function != "and_kleene") { + return {guaranteed_true_predicate}; + } + return FlattenedAssociativeChain(guaranteed_true_predicate).fringe; +} + +// Conjunction members which are represented in known_values are erased from +// conjunction_members +Status ExtractKnownFieldValuesImpl( + std::vector* conjunction_members, + std::unordered_map* known_values) { + for (auto&& member : *conjunction_members) { + ARROW_ASSIGN_OR_RAISE(member, Canonicalize(std::move(member))); + } + + auto unconsumed_end = + std::partition(conjunction_members->begin(), conjunction_members->end(), + [](const Expression2& expr) { + // search for an equality conditions between a field and a literal + auto call = expr.call(); + if (!call) return true; + + if (call->function == "equal") { + auto ref = call->arguments[0].field_ref(); + auto lit = call->arguments[1].literal(); + return !(ref && lit); + } + + return true; + }); + + for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) { + auto call = CallNotNull(*it); + + auto ref = call->arguments[0].field_ref(); + auto lit = call->arguments[1].literal(); + + auto it_success = known_values->emplace(*ref, *lit); + if (it_success.second) continue; + + // A value was already known for ref; check it + auto ref_lit = it_success.first; + if (*lit != ref_lit->second) { + return Status::Invalid("Conflicting guarantees: (", ref->ToString(), + " == ", lit->ToString(), ") vs (", ref->ToString(), + " == ", 
ref_lit->second.ToString()); + } + } + + conjunction_members->erase(unconsumed_end, conjunction_members->end()); + + return Status::OK(); +} + +Result> ExtractKnownFieldValues( + const Expression2& guaranteed_true_predicate) { + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); + std::unordered_map known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); + return known_values; +} + +Result ReplaceFieldsWithKnownValues( + const std::unordered_map& known_values, + Expression2 expr) { + return Modify( + std::move(expr), + [&known_values](Expression2 expr) { + if (auto ref = expr.field_ref()) { + auto it = known_values.find(*ref); + if (it != known_values.end()) { + return literal(it->second); + } + } + return expr; + }, + [](Expression2 expr, ...) { return expr; }); +} + +template ::value_type> +util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { + if (begin == end) return util::nullopt; + + Out folded = std::move(*begin++); + while (begin != end) { + folded = bin_op(std::move(folded), std::move(*begin++)); + } + return folded; +} + +inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { + static std::unordered_set binary_associative_commutative{ + "and", "or", "and_kleene", "or_kleene", "xor", + "multiply", "add", "multiply_checked", "add_checked"}; + + auto it = binary_associative_commutative.find(call.function); + return it != binary_associative_commutative.end(); +} + +struct Comparison { + enum type { + NA = 0, + EQUAL = 1, + LESS = 2, + GREATER = 4, + NOT_EQUAL = LESS | GREATER, + LESS_EQUAL = LESS | EQUAL, + GREATER_EQUAL = GREATER | EQUAL, + }; + + static const type* Get(const std::string& function) { + static std::unordered_map flipped_comparisons{ + {"equal", EQUAL}, {"not_equal", NOT_EQUAL}, + {"less", LESS}, {"less_equal", LESS_EQUAL}, + {"greater", GREATER}, {"greater_equal", GREATER_EQUAL}, + }; + + auto it = flipped_comparisons.find(function); + 
return it != flipped_comparisons.end() ? &it->second : nullptr; + } + + static const type* Get(const Expression2& expr) { + if (auto call = expr.call()) { + return Comparison::Get(call->function); + } + return nullptr; + } + + static Result Execute(Datum l, Datum r) { + if (!l.is_scalar() || !r.is_scalar()) { + return Status::Invalid("Cannot Execute Comparison on non-scalars"); + } + std::vector arguments{std::move(l), std::move(r)}; + + ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); + + if (!equal.scalar()->is_valid) return NA; + if (equal.scalar_as().value) return EQUAL; + + ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments)); + + if (!less.scalar()->is_valid) return NA; + return less.scalar_as().value ? LESS : GREATER; + } + + static type GetFlipped(type op) { + switch (op) { + case NA: + return NA; + case EQUAL: + return EQUAL; + case LESS: + return GREATER; + case GREATER: + return LESS; + case NOT_EQUAL: + return NOT_EQUAL; + case LESS_EQUAL: + return GREATER_EQUAL; + case GREATER_EQUAL: + return LESS_EQUAL; + }; + DCHECK(false); + return NA; + } + + static std::string GetName(type op) { + switch (op) { + case NA: + DCHECK(false) << "unreachable"; + break; + case EQUAL: + return "equal"; + case LESS: + return "less"; + case GREATER: + return "greater"; + case NOT_EQUAL: + return "not_equal"; + case LESS_EQUAL: + return "less_equal"; + case GREATER_EQUAL: + return "greater_equal"; + }; + DCHECK(false); + return "na"; + } +}; + +Result Canonicalize(Expression2 expr) { + // If potentially reconstructing more deeply than a call's immediate arguments + // (for example, when reorganizing an associative chain), add expressions to this set to + // avoid unnecessary work + struct { + std::unordered_set set_; + + bool operator()(const Expression2& expr) const { + return set_.find(expr) != set_.end(); + } + + void Add(std::vector exprs) { + std::move(exprs.begin(), exprs.end(), std::inserter(set_, set_.end())); + } + 
} AlreadyCanonicalized; + + return Modify( + std::move(expr), + [&AlreadyCanonicalized](Expression2 expr) -> Result { + auto call = expr.call(); + if (!call) return expr; + + if (AlreadyCanonicalized(expr)) return expr; + + if (IsBinaryAssociativeCommutative(*call)) { + struct { + int Priority(const Expression2& operand) const { + // order literals first, starting with nulls + if (operand.IsNullLiteral()) return 0; + if (operand.literal()) return 1; + return 2; + } + bool operator()(const Expression2& l, const Expression2& r) const { + return Priority(l) < Priority(r); + } + } CanonicalOrdering; + + FlattenedAssociativeChain chain(expr); + if (chain.was_left_folded && + std::is_sorted(chain.fringe.begin(), chain.fringe.end(), + CanonicalOrdering)) { + AlreadyCanonicalized.Add(std::move(chain.exprs)); + return expr; + } + + std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering); + + // fold the chain back up + const auto& descr = expr.descr(); + auto folded = FoldLeft( + chain.fringe.begin(), chain.fringe.end(), + [call, &descr, &AlreadyCanonicalized](Expression2 l, Expression2 r) { + auto ret = *call; + ret.arguments = {std::move(l), std::move(r)}; + Expression2 expr(std::make_shared(std::move(ret)), + descr); + AlreadyCanonicalized.Add({expr}); + return expr; + }); + return std::move(*folded); + } + + if (auto cmp = Comparison::Get(call->function)) { + if (call->arguments[0].literal() && !call->arguments[1].literal()) { + // ensure that literals are on comparisons' RHS + auto flipped_call = *call; + for (auto&& argument : flipped_call.arguments) { + ARROW_ASSIGN_OR_RAISE(argument, Canonicalize(std::move(argument))); + } + flipped_call.function = Comparison::GetName(Comparison::GetFlipped(*cmp)); + std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); + return Expression2( + std::make_shared(std::move(flipped_call))); + } + } + + return expr; + }, + [](Expression2 expr, ...) 
{ return expr; }); +} + +Result DirectComparisonSimplification(Expression2 expr, + const Expression2::Call& guarantee) { + return Modify( + std::move(expr), [](Expression2 expr) { return expr; }, + [&guarantee](Expression2 expr, ...) -> Result { + auto call = expr.call(); + if (!call) return expr; + + // Ensure both calls are comparisons with equal LHS and scalar RHS + auto cmp = Comparison::Get(expr); + auto cmp_guarantee = Comparison::Get(guarantee.function); + if (!cmp || !cmp_guarantee) return expr; + + if (call->arguments[0] != guarantee.arguments[0]) return expr; + + auto rhs = call->arguments[1].literal(); + auto guarantee_rhs = guarantee.arguments[1].literal(); + if (!rhs || !guarantee_rhs) return expr; + + if (!rhs->is_scalar() || !guarantee_rhs->is_scalar()) { + return expr; + } + + ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs, + Comparison::Execute(*rhs, *guarantee_rhs)); + DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA); + + if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) { + // RHS of filter is equal to RHS of guarantee + + if ((*cmp_guarantee & *cmp) == *cmp_guarantee) { + // guarantee is a subset of filter, so all data will be included + return literal(true); + } + + if ((*cmp_guarantee & *cmp) == 0) { + // guarantee disjoint with filter, so all data will be excluded + return literal(false); + } + + return expr; + } + + if (*cmp_guarantee & cmp_rhs_guarantee_rhs) { + // unusable guarantee + return expr; + } + + if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) { + // x > 1, x >= 1, x != 1 guaranteed by x >= 3 + return literal(true); + } else { + // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3 + return literal(false); + } + }); +} + +Result SimplifyWithGuarantee(Expression2 expr, ExpressionState* state, + const Expression2& guaranteed_true_predicate) { + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); + + std::unordered_map known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, 
&known_values)); + + ARROW_ASSIGN_OR_RAISE(expr, + ReplaceFieldsWithKnownValues(known_values, std::move(expr))); + + auto CanonicalizeAndFoldConstants = [&expr, state] { + ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); + ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr), state)); + return Status::OK(); + }; + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + + for (const auto& guarantee : conjunction_members) { + if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { + ARROW_ASSIGN_OR_RAISE( + auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); + + if (Identical(simplified, expr)) continue; + + expr = std::move(simplified); + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + } + } + + return expr; +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h new file mode 100644 index 00000000000..4da41b8b9c4 --- /dev/null +++ b/cpp/src/arrow/dataset/expression.h @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/chunked_array.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/scalar.h" +#include "arrow/type_fwd.h" +#include "arrow/util/variant.h" + +namespace arrow { +namespace dataset { + +/// An unbound expression which maps a single Datum to another Datum. +/// An expression is one of +/// - A literal Datum. +/// - A reference to a single (potentially nested) field of the input Datum. +/// - A call to a compute function, with arguments specified by other Expressions. +/// - Optionally, an explicitly selected Kernel of the function. If provided, +/// execution will skip function lookup and kernel dispatch and use that Kernel +/// directly. +class ARROW_DS_EXPORT Expression2 { + public: + struct Call { + std::string function; + std::vector arguments; + std::shared_ptr options; + + // post-Bind properties: + const compute::Kernel* kernel = NULLPTR; + compute::Function::Kind function_kind; + }; + + std::string ToString() const; + bool Equals(const Expression2& other) const; + size_t hash() const; + struct Hash { + size_t operator()(const Expression2& expr) const { return expr.hash(); } + }; + + /// Bind this expression to the given input type, looking up Kernels and field types. + /// Some expression simplification may be performed and implicit casts will be inserted. + /// Any state necessary for execution will be initialized and returned. + using StateAndBound = std::pair>; + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, + compute::ExecContext* = NULLPTR) const; + + /// Return true if all of an expression's field references have explicit ValueDescr and + /// all of its functions' kernels are looked up.
+ bool IsBound() const; + + /// Return true if this expression is composed only of Scalar literals, field + /// references, and calls to ScalarFunctions. + bool IsScalarExpression() const; + + /// Return true if this expression is literal and entirely null. + bool IsNullLiteral() const; + + /// Return true if this expression could evaluate to true. + bool IsSatisfiable() const; + + // XXX someday + // Result GetPipelines(); + + /// Accessors for the alternatives of an expression. Each returns null if this + /// expression is not of the corresponding kind. + const Call* call() const; + const Datum* literal() const; + const FieldRef* field_ref() const; + + const ValueDescr& descr() const { return descr_; } + + using Impl = util::variant; + + explicit Expression2(std::shared_ptr impl, ValueDescr descr = {}) + : impl_(std::move(impl)), descr_(std::move(descr)) {} + + private: + std::shared_ptr impl_; + ValueDescr descr_; + // XXX someday + // NullGeneralization::type evaluates_to_null_; + + friend bool Identical(const Expression2& l, const Expression2& r); + + ARROW_EXPORT friend void PrintTo(const Expression2&, std::ostream*); +}; + +struct ExpressionState { + virtual ~ExpressionState() = default; + + // Produce another instance of ExpressionState which may be safely used from a + // different thread + static Result> Clone( + const std::shared_ptr&, const Expression2&, compute::ExecContext*); +}; + +inline bool operator==(const Expression2& l, const Expression2& r) { return l.Equals(r); } +inline bool operator!=(const Expression2& l, const Expression2& r) { + return !l.Equals(r); +} + +// Factories + +inline Expression2 call(std::string function, std::vector arguments, + std::shared_ptr options = NULLPTR) { + Expression2::Call call; + call.function = std::move(function); + call.arguments = std::move(arguments); + call.options = std::move(options); + return Expression2(std::make_shared(std::move(call))); +} + +template ::value>::type> +Expression2 call(std::string function, std::vector arguments, + Options options) { + return call(std::move(function), std::move(arguments), +
std::make_shared(std::move(options))); +} + +inline Expression2 project(std::vector names, + std::vector values) { + return call("struct", std::move(values), compute::StructOptions{std::move(names)}); +} + +template +Expression2 field_ref(Args&&... args) { + return Expression2( + std::make_shared(FieldRef(std::forward(args)...))); +} + +template +Expression2 literal(Arg&& arg) { + Datum lit(std::forward(arg)); + ValueDescr descr = lit.descr(); + return Expression2(std::make_shared(std::move(lit)), + std::move(descr)); +} + +// Simplification passes + +/// Weak canonicalization which establishes guarantees for subsequent passes. Even +/// equivalent Expressions may result in different canonicalized expressions. +/// TODO this could be a strong canonicalization +ARROW_DS_EXPORT +Result Canonicalize(Expression2 expr); + +/// Simplify Expressions based on literal arguments (for example, add(null, x) will always +/// be null so replace the call with a null literal). Includes early evaluation of all +/// calls whose arguments are entirely literal. +ARROW_DS_EXPORT +Result FoldConstants(Expression2 expr, ExpressionState*); + +ARROW_DS_EXPORT +Result> ExtractKnownFieldValues( + const Expression2& guaranteed_true_predicate); + +ARROW_DS_EXPORT +Result ReplaceFieldsWithKnownValues( + const std::unordered_map& known_values, + Expression2 expr); + +/// Simplify an expression by replacing subexpressions based on a guarantee: +/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is +/// used to remove redundant function calls from a filter expression or to replace a +/// reference to a constant-value field with a literal. +ARROW_DS_EXPORT +Result SimplifyWithGuarantee(Expression2 expr, ExpressionState*, + const Expression2& guaranteed_true_predicate); + +// Execution + +/// Execute a scalar expression against the provided state and input Datum. This +/// expression must be bound. +ARROW_DS_EXPORT
+Result ExecuteScalarExpression(const Expression2&, ExpressionState*, + const Datum& input, + compute::ExecContext* exec_context = NULLPTR); + +// Serialization + +ARROW_DS_EXPORT +Result> Serialize(const Expression2&); + +ARROW_DS_EXPORT +Result Deserialize(const Buffer&); + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h new file mode 100644 index 00000000000..1998a565ec3 --- /dev/null +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#pragma once + +#include "arrow/dataset/expression.h" + +#include +#include +#include + +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/util/logging.h" + +namespace arrow { + +using internal::checked_cast; + +namespace dataset { + +// NB: everything below is defined in a header, so free functions must be inline to +// avoid multiple definitions when more than one translation unit includes it. + +// True only if l and r share the same Impl instance (pointer identity, not Equals()). +inline bool Identical(const Expression2& l, const Expression2& r) { return l.impl_ == r.impl_; } + +// Returns expr.call(), DCHECKing that expr is in fact a call. +inline const Expression2::Call* CallNotNull(const Expression2& expr) { + auto call = expr.call(); + DCHECK_NE(call, nullptr); + return call; +} + +// Recursively collect every FieldRef referenced by expr into refs. +inline void GetAllFieldRefs(const Expression2& expr, + std::unordered_set* refs) { + if (auto lit = expr.literal()) return; + + if (auto ref = expr.field_ref()) { + refs->emplace(*ref); + return; + } + + for (const Expression2& arg : CallNotNull(expr)->arguments) { + GetAllFieldRefs(arg, refs); + } +} + +// Gather descr() of each expression; every expression is DCHECKed to be bound. +inline std::vector GetDescriptors(const std::vector& exprs) { + std::vector descrs(exprs.size()); + for (size_t i = 0; i < exprs.size(); ++i) { + DCHECK(exprs[i].IsBound()); + descrs[i] = exprs[i].descr(); + } + return descrs; +} + +struct CallState : ExpressionState { + std::vector> argument_states; + std::shared_ptr kernel_state; +}; + +struct FieldPathGetDatumImpl { + template ()))> + Result operator()(const std::shared_ptr& ptr) { + return path_.Get(*ptr).template As(); + } + + template + Result operator()(const T&) { + return Status::NotImplemented("FieldPath::Get() into Datum ", datum_.ToString()); + } + + const Datum& datum_; + const FieldPath& path_; +}; + +// Look up ref in input, resolving it against input's type if it has one, otherwise +// against the schema of a RecordBatch or Table. A result which is still the +// default-constructed Datum is replaced by a scalar before returning. +inline Result GetDatumField(const FieldRef& ref, const Datum& input) { + Datum field; + + FieldPath path; + if (auto type = input.type()) { + ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.type())); + } else if (input.kind() == Datum::RECORD_BATCH) { + ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.record_batch()->schema())); + } else if (input.kind() == Datum::TABLE) { + ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.table()->schema())); + } + + if (path) { + ARROW_ASSIGN_OR_RAISE(field, +
util::visit(FieldPathGetDatumImpl{input, path}, input.value)); + } + + if (field == Datum{}) { + field = Datum(std::make_shared()); + } + + return field; +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc new file mode 100644 index 00000000000..7cc98a1e151 --- /dev/null +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -0,0 +1,729 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#include "arrow/dataset/expression.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "arrow/compute/registry.h" +#include "arrow/dataset/expression_internal.h" +#include "arrow/testing/gtest_util.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace dataset { + +const Schema kBoringSchema{{ + field("i32", int32()), + field("f32", float32()), +}}; + +#define EXPECT_OK ARROW_EXPECT_OK + +TEST(Expression2, ToString) { + EXPECT_EQ(field_ref("alpha").ToString(), "FieldRef(alpha)"); + + EXPECT_EQ(literal(3).ToString(), "3"); + + EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), + "add(3,FieldRef(beta))"); + + EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}).ToString(), + "add(3,index_in(FieldRef(beta)))"); +} + +TEST(Expression2, Equality) { + EXPECT_EQ(literal(1), literal(1)); + + EXPECT_EQ(field_ref("a"), field_ref("a")); + + EXPECT_EQ(call("add", {literal(3), field_ref("a")}), + call("add", {literal(3), field_ref("a")})); + + EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}), + call("add", {literal(3), call("index_in", {field_ref("beta")})})); +} + +TEST(Expression2, Hash) { + std::unordered_set set; + + EXPECT_TRUE(set.emplace(field_ref("alpha")).second); + EXPECT_TRUE(set.emplace(field_ref("beta")).second); + EXPECT_FALSE(set.emplace(field_ref("beta")).second) << "already inserted"; + EXPECT_TRUE(set.emplace(literal(1)).second); + EXPECT_FALSE(set.emplace(literal(1)).second) << "already inserted"; + EXPECT_TRUE(set.emplace(literal(3)).second); + + // NB: no validation on construction; we couldn't execute + // add with zero arguments + EXPECT_TRUE(set.emplace(call("add", {})).second); + EXPECT_FALSE(set.emplace(call("add", {})).second) << "already inserted"; + + // NB: unbound expressions don't check for availability in any registry + 
EXPECT_TRUE(set.emplace(call("widgetify", {})).second); + + EXPECT_EQ(set.size(), 6); +} + +TEST(Expression2, BindLiteral) { + for (Datum dat : { + Datum(3), + Datum(3.5), + Datum(ArrayFromJSON(int32(), "[1,2,3]")), + }) { + // literals are always considered bound + auto expr = literal(dat); + EXPECT_EQ(expr.descr(), dat.descr()); + EXPECT_TRUE(expr.IsBound()); + } +} + +TEST(Expression2, BindFieldRef) { + // an unbound field_ref does not have the output ValueDescr set + { + auto expr = field_ref("alpha"); + EXPECT_EQ(expr.descr(), ValueDescr{}); + EXPECT_FALSE(expr.IsBound()); + } + + { + auto expr = field_ref("alpha"); + // binding a field_ref looks up that field's type in the input Schema + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("alpha", int32())}))); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); + } + + { + // if the field is not found, a null scalar will be emitted + auto expr = field_ref("alpha"); + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(Schema({}))); + EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); + EXPECT_TRUE(expr.IsBound()); + } + + { + // referencing a field by name is not supported if that name is not unique + // in the input schema + auto expr = field_ref("alpha"); + ASSERT_RAISES( + Invalid, expr.Bind(Schema({field("alpha", int32()), field("alpha", float32())}))); + } + + { + // referencing nested fields is supported + auto expr = field_ref("a", "b"); + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", struct_({field("b", int32())}))}))); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); + } +} + +TEST(Expression2, BindCall) { + auto expr = call("add", {field_ref("a"), field_ref("b")}); + EXPECT_FALSE(expr.IsBound()); + + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", int32()), field("b", int32())}))); + EXPECT_EQ(expr.descr(), 
ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); + + expr = call("add", {field_ref("a"), literal(3.5)}); + ASSERT_RAISES(NotImplemented, + expr.Bind(Schema({field("a", int32()), field("b", int32())}))); +} + +TEST(Expression2, BindNestedCall) { + auto expr = + call("add", {field_ref("a"), + call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}), + field_ref("d")})}); + EXPECT_FALSE(expr.IsBound()); + + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", int32()), field("b", int32()), + field("c", int32()), field("d", int32())}))); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); +} + +TEST(Expression2, ExecuteFieldRef) { + auto AssertRefIs = [](FieldRef ref, Datum in, Datum expected) { + auto expr = field_ref(ref); + + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(Datum actual, + ExecuteScalarExpression(expr, /*state=*/nullptr, in)); + + AssertDatumsEqual(actual, expected, /*verbose=*/true); + }; + + AssertRefIs("a", ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])"), + ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); + + // more nested: + AssertRefIs(FieldRef{"a", "a"}, + ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ + {"a": {"a": 6.125}}, + {"a": {"a": 0.0}}, + {"a": {"a": -1}} + ])"), + ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); + + // absent fields are resolved as a null scalar: + AssertRefIs(FieldRef{"b"}, ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])"), + MakeNullScalar(float64())); +} + +Result NaiveExecuteScalarExpression(const Expression2& expr, + ExpressionState* state, const Datum& input) { + auto call = expr.call(); + if (call == nullptr) { + // already tested execution of field_ref, execution of literal is trivial + return ExecuteScalarExpression(expr, state, input); + } + 
+ auto call_state = checked_cast(state); + + std::vector arguments(call->arguments.size()); + for (size_t i = 0; i < arguments.size(); ++i) { + auto argument_state = call_state->argument_states[i].get(); + ARROW_ASSIGN_OR_RAISE(arguments[i], NaiveExecuteScalarExpression( + call->arguments[i], argument_state, input)); + + EXPECT_EQ(call->arguments[i].descr(), arguments[i].descr()); + } + + ARROW_ASSIGN_OR_RAISE(auto function, + compute::GetFunctionRegistry()->GetFunction(call->function)); + ARROW_ASSIGN_OR_RAISE(auto expected_kernel, + function->DispatchExact(GetDescriptors(call->arguments))); + + EXPECT_EQ(call->kernel, expected_kernel); + + compute::ExecContext exec_context; + return function->Execute(arguments, call->options.get(), &exec_context); +} + +void AssertExecute(Expression2 expr, Datum in) { + std::shared_ptr state; + ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(in.descr())); + + ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, state.get(), in)); + + ASSERT_OK_AND_ASSIGN(Datum expected, + NaiveExecuteScalarExpression(expr, state.get(), in)); + + AssertDatumsEqual(actual, expected, /*verbose=*/true); +} + +TEST(Expression2, ExecuteCall) { + AssertExecute(call("add", {field_ref("a"), literal(3.5)}), + ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])")); + + AssertExecute( + call("add", {field_ref("a"), call("subtract", {literal(3.5), field_ref("b")})}), + ArrayFromJSON(struct_({field("a", float64()), field("b", float64())}), R"([ + {"a": 6.125, "b": 3.375}, + {"a": 0.0, "b": 1}, + {"a": -1, "b": 4.75} + ])")); + + AssertExecute(call("strptime", {field_ref("a")}, + compute::StrptimeOptions("%m/%d/%Y", TimeUnit::MICRO)), + ArrayFromJSON(struct_({field("a", utf8())}), R"([ + {"a": "5/1/2020"}, + {"a": null}, + {"a": "12/11/1900"} + ])")); + + AssertExecute(project({"a + 3.5"}, {call("add", {field_ref("a"), literal(3.5)})}), + ArrayFromJSON(struct_({field("a", float64())}), R"([ + 
{"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])")); +} + +struct { + void operator()(Expression2 expr, Expression2 expected) { + std::shared_ptr state; + ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto actual, FoldConstants(expr, state.get())); + EXPECT_EQ(actual, expected); + if (actual == expr) { + // no change -> must be identical + EXPECT_TRUE(Identical(actual, expr)); + } + } +} ExpectFoldsTo; + +TEST(Expression2, FoldConstants) { + // literals are unchanged + ExpectFoldsTo(literal(3), literal(3)); + + // field_refs are unchanged + ExpectFoldsTo(field_ref("i32"), field_ref("i32")); + + // call against literals (3 + 2 == 5) + ExpectFoldsTo(call("add", {literal(3), literal(2)}), literal(5)); + + // call against literal and field_ref + ExpectFoldsTo(call("add", {literal(3), field_ref("i32")}), + call("add", {literal(3), field_ref("i32")})); + + // nested call against literals ((8 - (2 * 3)) + 2 == 4) + ExpectFoldsTo(call("add", + { + call("subtract", + { + literal(8), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + }), + literal(4)); + + // nested call against literals with one field_ref + // (i32 - (2 * 3)) + 2 == (i32 - 6) + 2 + // NB this could be improved further by using associativity of addition; another pass + ExpectFoldsTo(call("add", + { + call("subtract", + { + field_ref("i32"), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + }), + call("add", { + call("subtract", + { + field_ref("i32"), + literal(6), + }), + literal(2), + })); +} + +TEST(Expression2, FoldConstantsBoolean) { + // test and_kleene/or_kleene-specific optimizations + auto one = literal(1); + auto two = literal(2); + auto whatever = call("equal", {call("add", {one, field_ref("i32")}), two}); + + auto true_ = literal(true); + auto false_ = literal(false); + + ExpectFoldsTo(call("and_kleene", {false_, whatever}), false_); 
+ ExpectFoldsTo(call("and_kleene", {true_, whatever}), whatever); + + ExpectFoldsTo(call("or_kleene", {true_, whatever}), true_); + ExpectFoldsTo(call("or_kleene", {false_, whatever}), whatever); +} + +TEST(Expression2, ExtractKnownFieldValues) { + struct { + void operator()(Expression2 guarantee, + std::unordered_map expected) { + ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); + EXPECT_THAT(actual, UnorderedElementsAreArray(expected)); + } + } ExpectKnown; + + ExpectKnown(call("equal", {field_ref("a"), literal(3)}), {{"a", Datum(3)}}); + + ExpectKnown(call("greater", {field_ref("a"), literal(3)}), {}); + + // FIXME known null should be expressed with is_null rather than equality + auto null_int32 = std::make_shared(); + ExpectKnown(call("equal", {field_ref("a"), literal(null_int32)}), + {{"a", Datum(null_int32)}}); + + ExpectKnown(call("and_kleene", {call("equal", {field_ref("a"), literal(3)}), + call("equal", {literal(1), field_ref("b")})}), + {{"a", Datum(3)}, {"b", Datum(1)}}); + + ExpectKnown(call("and_kleene", + {call("equal", {field_ref("a"), literal(3)}), + call("and_kleene", {call("equal", {field_ref("b"), literal(2)}), + call("equal", {literal(1), field_ref("c")})})}), + {{"a", Datum(3)}, {"b", Datum(2)}, {"c", Datum(1)}}); + + ExpectKnown(call("and_kleene", + {call("or_kleene", {call("equal", {field_ref("a"), literal(3)}), + call("equal", {field_ref("a"), literal(4)})}), + call("equal", {literal(1), field_ref("b")})}), + {{"b", Datum(1)}}); + + ExpectKnown( + call("and_kleene", + {call("equal", {field_ref("a"), literal(3)}), + call("and_kleene", {call("and_kleene", {field_ref("b"), field_ref("d")}), + call("equal", {literal(1), field_ref("c")})})}), + {{"a", Datum(3)}, {"c", Datum(1)}}); +} + +TEST(Expression2, ReplaceFieldsWithKnownValues) { + auto ExpectSimplifiesTo = + [](Expression2 expr, + std::unordered_map known_values, + Expression2 expected) { + ASSERT_OK_AND_ASSIGN(auto actual, + 
ReplaceFieldsWithKnownValues(known_values, expr)); + EXPECT_EQ(actual, expected); + + if (actual == expr) { + // no change -> must be identical + EXPECT_TRUE(Identical(actual, expr)); + } + }; + + std::unordered_map a_is_3{{"a", Datum(3)}}; + + ExpectSimplifiesTo(literal(1), a_is_3, literal(1)); + + ExpectSimplifiesTo(field_ref("a"), a_is_3, literal(3)); + + ExpectSimplifiesTo(field_ref("b"), a_is_3, field_ref("b")); + + ExpectSimplifiesTo(call("equal", {field_ref("a"), literal(1)}), a_is_3, + call("equal", {literal(3), literal(1)})); + + ExpectSimplifiesTo(call("add", + { + call("subtract", + { + field_ref("a"), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + }), + a_is_3, + call("add", { + call("subtract", + { + literal(3), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + })); +} + +struct { + void operator()(Expression2 expr, Expression2 expected) const { + ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(expr)); + EXPECT_EQ(actual, expected); + + if (expr.IsBound()) { + EXPECT_TRUE(actual.IsBound()); + } + + if (actual == expr) { + // no change -> must be identical + EXPECT_TRUE(Identical(actual, expr)); + } + } +} ExpectCanonicalizesTo; + +TEST(Expression2, CanonicalizeTrivial) { + ExpectCanonicalizesTo(literal(1), literal(1)); + + ExpectCanonicalizesTo(field_ref("b"), field_ref("b")); + + ExpectCanonicalizesTo(call("equal", {field_ref("a"), field_ref("b")}), + call("equal", {field_ref("a"), field_ref("b")})); +} + +TEST(Expression2, CanonicalizeAnd) { + // some aliases for brevity: + auto true_ = literal(true); + auto null_ = literal(std::make_shared()); + + auto a = field_ref("a"); + auto c = call("equal", {literal(1), literal(2)}); + + auto and_ = [](Expression2 l, Expression2 r) { return call("and_kleene", {l, r}); }; + + // no change possible: + ExpectCanonicalizesTo(and_(a, c), and_(a, c)); + + // literals are placed innermost + ExpectCanonicalizesTo(and_(a, true_), and_(true_, a)); + ExpectCanonicalizesTo(and_(true_, 
a), and_(true_, a)); + + ExpectCanonicalizesTo(and_(a, and_(true_, c)), and_(and_(true_, a), c)); + ExpectCanonicalizesTo(and_(a, and_(and_(true_, true_), c)), + and_(and_(and_(true_, true_), a), c)); + ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), c)), + and_(and_(and_(null_, true_), a), c)); + ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), and_(c, null_))), + and_(and_(and_(and_(null_, null_), true_), a), c)); + + // catches and_kleene even when it's a subexpression + ExpectCanonicalizesTo(call("is_valid", {and_(a, true_)}), + call("is_valid", {and_(true_, a)})); +} + +TEST(Expression2, CanonicalizeComparison) { + // some aliases for brevity: + auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); }; + auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); }; + auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); }; + + ExpectCanonicalizesTo(equal(literal(1), field_ref("a")), + equal(field_ref("a"), literal(1))); + + ExpectCanonicalizesTo(equal(field_ref("a"), literal(1)), + equal(field_ref("a"), literal(1))); + + ExpectCanonicalizesTo(less(literal(1), field_ref("a")), + greater(field_ref("a"), literal(1))); + + ExpectCanonicalizesTo(less(field_ref("a"), literal(1)), + less(field_ref("a"), literal(1))); +} + +struct Simplify { + Expression2 filter; + + struct Expectable { + Expression2 filter, guarantee; + + void Expect(Expression2 expected) { + std::shared_ptr state; + ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema)); + + ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto simplified, + SimplifyWithGuarantee(filter, state.get(), guarantee)); + EXPECT_EQ(simplified, expected) << " original: " << filter.ToString() << "\n" + << " guarantee: " << guarantee.ToString() << "\n" + << (simplified == filter ? 
" (no change)\n" : ""); + + if (simplified == filter) { + EXPECT_TRUE(Identical(simplified, filter)); + } + } + void ExpectUnchanged() { Expect(filter); } + void Expect(bool constant) { Expect(literal(constant)); } + }; + + Expectable WithGuarantee(Expression2 guarantee) { return {filter, guarantee}; } +}; + +TEST(Expression2, SingleComparisonGuarantees) { + // some aliases for brevity: + auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); }; + auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); }; + auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); }; + auto not_equal = [](Expression2 l, Expression2 r) { return call("not_equal", {l, r}); }; + auto less_equal = [](Expression2 l, Expression2 r) { + return call("less_equal", {l, r}); + }; + auto greater_equal = [](Expression2 l, Expression2 r) { + return call("greater_equal", {l, r}); + }; + auto i32 = field_ref("i32"); + + // i32 is guaranteed equal to 3, so the projection can just materialize that constant + // and need not incur IO + Simplify{project({"i32 + 1"}, {call("add", {i32, literal(1)})})} + .WithGuarantee(equal(i32, literal(3))) + .Expect(literal( + std::make_shared(ScalarVector{std::make_shared(4)}, + struct_({field("i32 + 1", int32())})))); + + // i32 is guaranteed equal to 5 everywhere, so filtering i32==5 is redundant and the + // filter can be simplified to true (== select everything) + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + less_equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + less(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(3))) + .Expect(true); + + Simplify{ + greater_equal(i32, literal(5)), + } + .WithGuarantee(greater(i32, literal(5))) + .Expect(true); + + // i32 is 
guaranteed less than 3 everywhere, so filtering i32==5 is redundant and the + // filter can be simplified to false (== select nothing) + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(less(i32, literal(3))) + .Expect(false); + + Simplify{ + less(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(false); + + Simplify{ + less_equal(i32, literal(3)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(false); + + // no simplification possible: + Simplify{ + not_equal(i32, literal(3)), + } + .WithGuarantee(less(i32, literal(5))) + .ExpectUnchanged(); + + // exhaustive coverage of all single comparison simplifications + for (std::string filter_op : + {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { + for (auto filter_rhs : {literal(5), literal(3), literal(7)}) { + auto filter = call(filter_op, {i32, filter_rhs}); + for (std::string guarantee_op : + {"equal", "less", "less_equal", "greater", "greater_equal"}) { + auto guarantee = call(guarantee_op, {i32, literal(5)}); + + // generate data which satisfies the guarantee + static std::unordered_map satisfying_i32{ + {"equal", "[5]"}, + {"less", "[4, 3, 2, 1]"}, + {"less_equal", "[5, 4, 3, 2, 1]"}, + {"greater", "[6, 7, 8, 9]"}, + {"greater_equal", "[5, 6, 7, 8, 9]"}, + }; + + ASSERT_OK_AND_ASSIGN( + Datum input, + StructArray::Make({ArrayFromJSON(int32(), satisfying_i32[guarantee_op])}, + {"i32"})); + + std::shared_ptr state; + ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum evaluated, + ExecuteScalarExpression(filter, state.get(), input)); + + // ensure that the simplified filter is as simplified as it could be + // (this is always possible for single comparisons) + bool all = true, none = true; + for (int64_t i = 0; i < input.length(); ++i) { + if (evaluated.array_as()->Value(i)) { + none = false; + } else { + all = false; + } + } + Simplify{filter}.WithGuarantee(guarantee).Expect( + all ? 
literal(true) : none ? literal(false) : filter); + } + } + } +} + +TEST(Expression2, SimplifyWithGuarantee) { + // drop both members of a conjunctive filter + Simplify{call("and_kleene", + { + call("equal", {field_ref("i32"), literal(2)}), + call("equal", {field_ref("f32"), literal(3.5F)}), + })} + .WithGuarantee(call("and_kleene", + { + call("greater_equal", {field_ref("i32"), literal(0)}), + call("less_equal", {field_ref("i32"), literal(1)}), + })) + .Expect(false); + + // drop one member of a conjunctive filter + Simplify{call("and_kleene", + { + call("equal", {field_ref("i32"), literal(0)}), + call("equal", {field_ref("f32"), literal(3.5F)}), + })} + .WithGuarantee(call("equal", {field_ref("i32"), literal(0)})) + .Expect(call("equal", {field_ref("f32"), literal(3.5F)})); + + // drop both members of a disjunctive filter + Simplify{call("or_kleene", + { + call("equal", {field_ref("i32"), literal(0)}), + call("equal", {field_ref("f32"), literal(3.5F)}), + })} + .WithGuarantee(call("equal", {field_ref("i32"), literal(0)})) + .Expect(true); + + // drop one member of a disjunctive filter + Simplify{call("or_kleene", + { + call("equal", {field_ref("i32"), literal(0)}), + call("equal", {field_ref("i32"), literal(3)}), + })} + .WithGuarantee(call("and_kleene", + { + call("greater_equal", {field_ref("i32"), literal(0)}), + call("less_equal", {field_ref("i32"), literal(1)}), + })) + .Expect(call("equal", {field_ref("i32"), literal(0)})); +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 159e0ac0331..a39a8adaae0 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -18,31 +18,20 @@ #include "arrow/dataset/partition.h" #include -#include #include -#include -#include #include #include #include "arrow/array/array_base.h" #include "arrow/array/array_nested.h" -#include "arrow/array/builder_binary.h" #include "arrow/array/builder_dict.h" #include 
"arrow/compute/api_scalar.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/file_base.h" #include "arrow/dataset/filter.h" -#include "arrow/dataset/scanner.h" -#include "arrow/dataset/scanner_internal.h" -#include "arrow/filesystem/filesystem.h" #include "arrow/filesystem/path_util.h" #include "arrow/scalar.h" -#include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" -#include "arrow/util/range.h" -#include "arrow/util/sort.h" #include "arrow/util/string_view.h" namespace arrow { diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index f49103a585a..b47f5ba9926 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -21,20 +21,16 @@ #include #include -#include #include #include #include #include -#include "arrow/dataset/file_base.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" -#include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" -#include "arrow/util/io_util.h" namespace arrow { using internal::checked_pointer_cast; diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index f504305a996..9cfba41a624 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 0ff77de0102..522def49818 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -76,6 +76,9 @@ class ExpressionEvaluator; ARROW_DS_EXPORT const std::shared_ptr& scalar(bool); +struct ExpressionState; +class Expression2; + class Partitioning; class PartitioningFactory; class PartitioningOrFactory; diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index 
5feed556207..c1666adc4c5 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -68,6 +68,8 @@ Datum::Datum(int64_t value) : value(std::make_shared(value)) {} Datum::Datum(uint64_t value) : value(std::make_shared(value)) {} Datum::Datum(float value) : value(std::make_shared(value)) {} Datum::Datum(double value) : value(std::make_shared(value)) {} +Datum::Datum(std::string value) + : value(std::make_shared(std::move(value))) {} Datum::Datum(const ChunkedArray& value) : value(std::make_shared(value.chunks(), value.type())) {} @@ -238,4 +240,17 @@ ValueDescr::Shape GetBroadcastShape(const std::vector& args) { return ValueDescr::SCALAR; } +void PrintTo(const Datum& datum, std::ostream* os) { + switch (datum.kind()) { + case Datum::SCALAR: + *os << datum.scalar()->ToString(); + break; + case Datum::ARRAY: + *os << datum.make_array()->ToString(); + break; + default: + *os << datum.ToString(); + } +} + } // namespace arrow diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 09dc870f687..5f80551f337 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -81,7 +81,9 @@ struct ARROW_EXPORT ValueDescr { } bool operator==(const ValueDescr& other) const { - return this->shape == other.shape && this->type->Equals(*other.type); + if (shape != other.shape) return false; + if (type == other.type) return true; + return type && type->Equals(other.type); } bool operator!=(const ValueDescr& other) const { return !(*this == other); } @@ -114,6 +116,11 @@ struct ARROW_EXPORT Datum { /// \brief Empty datum, to be populated elsewhere Datum() = default; + Datum(const Datum& other) noexcept = default; + Datum& operator=(const Datum& other) noexcept = default; + Datum(Datum&& other) noexcept = default; + Datum& operator=(Datum&& other) noexcept = default; + Datum(std::shared_ptr value) // NOLINT implicit conversion : value(std::move(value)) {} @@ -153,21 +160,7 @@ struct ARROW_EXPORT Datum { explicit Datum(uint64_t value); explicit Datum(float value); 
explicit Datum(double value); - - Datum(const Datum& other) noexcept { this->value = other.value; } - - Datum& operator=(const Datum& other) noexcept { - value = other.value; - return *this; - } - - // Define move constructor and move assignment, for better performance - Datum(Datum&& other) noexcept : value(std::move(other.value)) {} - - Datum& operator=(Datum&& other) noexcept { - value = std::move(other.value); - return *this; - } + explicit Datum(std::string value); Datum::Kind kind() const { switch (this->value.index()) { @@ -218,6 +211,11 @@ struct ARROW_EXPORT Datum { return util::get>(this->value); } + template + std::shared_ptr array_as() const { + return internal::checked_pointer_cast(this->make_array()); + } + template const ExactType& scalar_as() const { return internal::checked_cast(*this->scalar()); @@ -267,6 +265,8 @@ struct ARROW_EXPORT Datum { bool operator!=(const Datum& other) const { return !Equals(other); } std::string ToString() const; + + ARROW_EXPORT friend void PrintTo(const Datum&, std::ostream*); }; } // namespace arrow diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h index eac08919286..09dfd59c8d2 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -401,6 +401,28 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return std::forward(m)(ValueUnsafe()); } + /// Cast the internally stored value to produce a new result or propagate the stored + /// error. + template ::value>::type> + Result As() && { + if (!ok()) { + return status(); + } + return U(MoveValueUnsafe()); + } + + /// Cast the internally stored value to produce a new result or propagate the stored + /// error. 
+ template ::value>::type> + Result As() const& { + if (!ok()) { + return status(); + } + return U(ValueUnsafe()); + } + const T& ValueUnsafe() const& { return *internal::launder(reinterpret_cast(&data_)); } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 5e6fec82cb3..5604c5e2717 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -885,12 +885,15 @@ size_t FieldPath::hash() const { } std::string FieldPath::ToString() const { + if (this->indices().empty()) { + return "FieldPath(empty)"; + } + std::string repr = "FieldPath("; for (auto index : this->indices()) { repr += std::to_string(index) + " "; } - repr.resize(repr.size() - 1); - repr += ")"; + repr.back() = ')'; return repr; } @@ -1039,13 +1042,31 @@ Result> FieldPath::Get(const FieldVector& fields) const { Result> FieldPath::Get(const RecordBatch& batch) const { ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data())); - return MakeArray(data); + return MakeArray(std::move(data)); } Result> FieldPath::Get(const Table& table) const { return FieldPathGetImpl::Get(this, table.columns()); } +Result> FieldPath::Get(const Array& array) const { + ARROW_ASSIGN_OR_RAISE(auto data, Get(*array.data())); + return MakeArray(std::move(data)); +} + +Result> FieldPath::Get(const ArrayData& data) const { + return FieldPathGetImpl::Get(this, data.child_data); +} + +Result> FieldPath::Get(const ChunkedArray& array) const { + FieldPath prefixed_with_0 = *this; + prefixed_with_0.indices_.insert(prefixed_with_0.indices_.begin(), 0); + + ChunkedArrayVector vec; + vec.emplace_back(const_cast(&array), [](...) 
{}); + return FieldPathGetImpl::Get(&prefixed_with_0, vec); +} + FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) { DCHECK_GT(util::get(impl_).indices().size(), 0); } @@ -1291,6 +1312,10 @@ std::vector FieldRef::FindAll(const FieldVector& fields) const { return util::visit(Visitor{fields}, impl_); } +std::vector FieldRef::FindAll(const ArrayData& array) const { + return FindAll(*array.type); +} + std::vector FieldRef::FindAll(const Array& array) const { return FindAll(*array.type()); } @@ -1333,7 +1358,7 @@ Schema::Schema(std::vector> fields, Schema::Schema(const Schema& schema) : detail::Fingerprintable(), impl_(new Impl(*schema.impl_)) {} -Schema::~Schema() {} +Schema::~Schema() = default; int Schema::num_fields() const { return static_cast(impl_->fields_.size()); } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index ed504c3afe5..3594c248e10 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1402,6 +1402,9 @@ class ARROW_EXPORT FieldPath { std::string ToString() const; size_t hash() const; + struct Hash { + size_t operator()(const FieldPath& path) const { return path.hash(); } + }; explicit operator bool() const { return !indices_.empty(); } bool operator!() const { return indices_.empty(); } @@ -1423,10 +1426,14 @@ class ARROW_EXPORT FieldPath { Result> Get(const RecordBatch& batch) const; Result> Get(const Table& table) const; - /// \brief Retrieve the referenced child Array from an Array or ChunkedArray + /// \brief Retrieve the referenced child from an Array, ArrayData, or ChunkedArray Result> Get(const Array& array) const; + Result> Get(const ArrayData& data) const; Result> Get(const ChunkedArray& array) const; + /// \brief Retrieve the reference child from a Datum + Result Get(const Datum& datum) const; + private: std::vector indices_; }; @@ -1517,6 +1524,12 @@ class ARROW_EXPORT FieldRef { std::string ToString() const; size_t hash() const; + struct Hash { + size_t operator()(const FieldRef& ref) const { return 
ref.hash(); } + }; + + explicit operator bool() const { return Equals(FieldPath{}); } + bool operator!() const { return !Equals(FieldPath{}); } bool IsFieldPath() const { return util::holds_alternative(impl_); } bool IsName() const { return util::holds_alternative(impl_); } @@ -1526,6 +1539,13 @@ class ARROW_EXPORT FieldRef { return true; } + const FieldPath* field_path() const { + return IsFieldPath() ? &util::get(impl_) : NULLPTR; + } + const std::string* name() const { + return IsName() ? &util::get(impl_) : NULLPTR; + } + /// \brief Retrieve FieldPath of every child field which matches this FieldRef. std::vector FindAll(const Schema& schema) const; std::vector FindAll(const Field& field) const; @@ -1533,6 +1553,7 @@ class ARROW_EXPORT FieldRef { std::vector FindAll(const FieldVector& fields) const; /// \brief Convenience function which applies FindAll to arg's type or schema. + std::vector FindAll(const ArrayData& array) const; std::vector FindAll(const Array& array) const; std::vector FindAll(const ChunkedArray& array) const; std::vector FindAll(const RecordBatch& batch) const; @@ -1609,7 +1630,7 @@ class ARROW_EXPORT FieldRef { if (match) { return match.Get(root).ValueOrDie(); } - return NULLPTR; + return GetType(NULLPTR); } private: diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index cf90953527e..f1000d1fe7f 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -71,6 +71,8 @@ class RecordBatch; class RecordBatchReader; class Table; +struct Datum; + using ChunkedArrayVector = std::vector>; using RecordBatchVector = std::vector>; using RecordBatchIterator = Iterator>; From 1152d2c2da6e001ca8bf3dbc704b2f6101135add Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 1 Dec 2020 13:28:17 -0500 Subject: [PATCH 02/31] replace filtering with Expression2 --- .../arrow/dataset-parquet-scan-example.cc | 11 +- .../compute/kernels/scalar_cast_internal.cc | 4 +- cpp/src/arrow/dataset/dataset.cc | 37 +- 
cpp/src/arrow/dataset/dataset.h | 34 +- cpp/src/arrow/dataset/dataset_internal.h | 8 +- cpp/src/arrow/dataset/dataset_test.cc | 3 +- cpp/src/arrow/dataset/discovery.cc | 2 +- cpp/src/arrow/dataset/discovery.h | 8 +- cpp/src/arrow/dataset/discovery_test.cc | 10 +- cpp/src/arrow/dataset/expression.cc | 804 ++++++++++++------ cpp/src/arrow/dataset/expression.h | 100 ++- cpp/src/arrow/dataset/expression_internal.h | 318 ++++++- cpp/src/arrow/dataset/expression_test.cc | 593 ++++++++----- cpp/src/arrow/dataset/file_base.cc | 36 +- cpp/src/arrow/dataset/file_base.h | 17 +- cpp/src/arrow/dataset/file_csv.cc | 14 +- cpp/src/arrow/dataset/file_ipc_test.cc | 4 +- cpp/src/arrow/dataset/file_parquet.cc | 78 +- cpp/src/arrow/dataset/file_parquet.h | 14 +- cpp/src/arrow/dataset/file_parquet_test.cc | 57 +- cpp/src/arrow/dataset/file_test.cc | 122 ++- cpp/src/arrow/dataset/filter.cc | 41 +- cpp/src/arrow/dataset/filter.h | 55 -- cpp/src/arrow/dataset/filter_test.cc | 236 +---- cpp/src/arrow/dataset/partition.cc | 127 ++- cpp/src/arrow/dataset/partition.h | 33 +- cpp/src/arrow/dataset/partition_test.cc | 107 ++- cpp/src/arrow/dataset/scanner.cc | 48 +- cpp/src/arrow/dataset/scanner.h | 16 +- cpp/src/arrow/dataset/scanner_internal.h | 96 +-- cpp/src/arrow/dataset/scanner_test.cc | 45 +- cpp/src/arrow/dataset/test_util.h | 122 +-- cpp/src/arrow/dataset/type_fwd.h | 6 - cpp/src/arrow/datum.cc | 3 + cpp/src/arrow/datum.h | 3 + cpp/src/arrow/type.cc | 2 +- cpp/src/arrow/type.h | 2 +- cpp/src/arrow/util/variant.h | 16 + 38 files changed, 1962 insertions(+), 1270 deletions(-) diff --git a/cpp/examples/arrow/dataset-parquet-scan-example.cc b/cpp/examples/arrow/dataset-parquet-scan-example.cc index 3cdd298b8fd..5b933d3ca62 100644 --- a/cpp/examples/arrow/dataset-parquet-scan-example.cc +++ b/cpp/examples/arrow/dataset-parquet-scan-example.cc @@ -37,8 +37,6 @@ namespace fs = arrow::fs; namespace ds = arrow::dataset; -using ds::string_literals::operator"" _; - #define ABORT_ON_FAILURE(expr) 
\ do { \ arrow::Status status_ = (expr); \ @@ -62,7 +60,8 @@ struct Configuration { // Indicates the filter by which rows will be filtered. This optimization can // make use of partition information and/or file metadata if possible. - std::shared_ptr filter = ("total_amount"_ > 1000.0f).Copy(); + ds::Expression2 filter = + ds::greater(ds::field_ref("total_amount"), ds::literal(1000.0f)); ds::InspectOptions inspect_options{}; ds::FinishOptions finish_options{}; @@ -147,7 +146,7 @@ std::shared_ptr GetDatasetFromPath( std::shared_ptr GetScannerFromDataset(std::shared_ptr dataset, std::vector columns, - std::shared_ptr filter, + ds::Expression2 filter, bool use_threads) { auto scanner_builder = dataset->NewScan().ValueOrDie(); @@ -155,9 +154,7 @@ std::shared_ptr GetScannerFromDataset(std::shared_ptr ABORT_ON_FAILURE(scanner_builder->Project(columns)); } - if (filter != nullptr) { - ABORT_ON_FAILURE(scanner_builder->Filter(filter)); - } + ABORT_ON_FAILURE(scanner_builder->Filter(filter)); ABORT_ON_FAILURE(scanner_builder->UseThreads(use_threads)); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index cd33de672a6..67f0820402a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -191,6 +191,8 @@ void CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out) { } void CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].is_scalar()) return; + ArrayData* output = out->mutable_array(); std::shared_ptr nulls; Status s = MakeArrayOfNull(output->type, batch.length).Value(&nulls); @@ -251,7 +253,7 @@ static bool CanCastFromDictionary(Type::type type_id) { void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func) { // From null to this type - DCHECK_OK(func->AddKernel(Type::NA, {InputType::Array(null())}, out_ty, CastFromNull)); + 
DCHECK_OK(func->AddKernel(Type::NA, {null()}, out_ty, CastFromNull)); // From dictionary to this type if (CanCastFromDictionary(out_type_id)) { diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 71755aaf566..56c901f959d 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -32,12 +32,10 @@ namespace arrow { namespace dataset { -Fragment::Fragment(std::shared_ptr partition_expression, +Fragment::Fragment(Expression2 partition_expression, std::shared_ptr physical_schema) : partition_expression_(std::move(partition_expression)), - physical_schema_(std::move(physical_schema)) { - DCHECK_NE(partition_expression_, nullptr); -} + physical_schema_(std::move(physical_schema)) {} Result> Fragment::ReadPhysicalSchema() { { @@ -61,14 +59,14 @@ Result> InMemoryFragment::ReadPhysicalSchemaImpl() { InMemoryFragment::InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - std::shared_ptr partition_expression) + Expression2 partition_expression) : Fragment(std::move(partition_expression), std::move(schema)), record_batches_(std::move(record_batches)) { DCHECK_NE(physical_schema_, nullptr); } InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches, - std::shared_ptr partition_expression) + Expression2 partition_expression) : InMemoryFragment(record_batches.empty() ? 
schema({}) : record_batches[0]->schema(), std::move(record_batches), std::move(partition_expression)) {} @@ -95,11 +93,9 @@ Result InMemoryFragment::Scan(std::shared_ptr opt return MakeMapIterator(fn, std::move(batches_it)); } -Dataset::Dataset(std::shared_ptr schema, - std::shared_ptr partition_expression) - : schema_(std::move(schema)), partition_expression_(std::move(partition_expression)) { - DCHECK_NE(partition_expression_, nullptr); -} +Dataset::Dataset(std::shared_ptr schema, Expression2 partition_expression) + : schema_(std::move(schema)), + partition_expression_(std::move(partition_expression)) {} Result> Dataset::NewScan( std::shared_ptr context) { @@ -110,10 +106,16 @@ Result> Dataset::NewScan() { return NewScan(std::make_shared()); } -FragmentIterator Dataset::GetFragments(std::shared_ptr predicate) { - predicate = predicate->Assume(*partition_expression_); - return predicate->IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) - : MakeEmptyIterator>(); +Result Dataset::GetFragments() { + ARROW_ASSIGN_OR_RAISE(auto predicate, literal(true).Bind(*schema_)); + return GetFragments(std::move(predicate)); +} + +Result Dataset::GetFragments(Expression2::BoundWithState predicate) { + ARROW_ASSIGN_OR_RAISE( + predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); + return predicate.first.IsSatisfiable() ? 
GetFragmentsImpl(std::move(predicate)) + : MakeEmptyIterator>(); } struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator { @@ -153,7 +155,7 @@ Result> InMemoryDataset::ReplaceSchema( return std::make_shared(std::move(schema), get_batches_); } -FragmentIterator InMemoryDataset::GetFragmentsImpl(std::shared_ptr) { +Result InMemoryDataset::GetFragmentsImpl(Expression2::BoundWithState) { auto schema = this->schema(); auto create_fragment = @@ -194,7 +196,8 @@ Result> UnionDataset::ReplaceSchema( new UnionDataset(std::move(schema), std::move(children))); } -FragmentIterator UnionDataset::GetFragmentsImpl(std::shared_ptr predicate) { +Result UnionDataset::GetFragmentsImpl( + Expression2::BoundWithState predicate) { return GetFragmentsFromDatasets(children_, predicate); } diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index be138dd277d..35c91b79fed 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/dataset/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/util/macros.h" @@ -71,21 +72,19 @@ class ARROW_DS_EXPORT Fragment { /// \brief An expression which evaluates to true for all data viewed by this /// Fragment. 
- const std::shared_ptr& partition_expression() const { - return partition_expression_; - } + const Expression2& partition_expression() const { return partition_expression_; } virtual ~Fragment() = default; protected: Fragment() = default; - explicit Fragment(std::shared_ptr partition_expression, + explicit Fragment(Expression2 partition_expression, std::shared_ptr physical_schema); virtual Result> ReadPhysicalSchemaImpl() = 0; util::Mutex physical_schema_mutex_; - std::shared_ptr partition_expression_ = scalar(true); + Expression2 partition_expression_ = literal(true); std::shared_ptr physical_schema_; }; @@ -94,9 +93,9 @@ class ARROW_DS_EXPORT Fragment { class ARROW_DS_EXPORT InMemoryFragment : public Fragment { public: InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - std::shared_ptr = scalar(true)); + Expression2 = literal(true)); explicit InMemoryFragment(RecordBatchVector record_batches, - std::shared_ptr = scalar(true)); + Expression2 = literal(true)); Result Scan(std::shared_ptr options, std::shared_ptr context) override; @@ -123,15 +122,14 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. - FragmentIterator GetFragments(std::shared_ptr predicate = scalar(true)); + Result GetFragments(Expression2::BoundWithState predicate); + Result GetFragments(); const std::shared_ptr& schema() const { return schema_; } /// \brief An expression which evaluates to true for all data viewed by this Dataset. /// May be null, which indicates no information is available. 
- const std::shared_ptr& partition_expression() const { - return partition_expression_; - } + const Expression2& partition_expression() const { return partition_expression_; } /// \brief The name identifying the kind of Dataset virtual std::string type_name() const = 0; @@ -148,13 +146,13 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { protected: explicit Dataset(std::shared_ptr schema) : schema_(std::move(schema)) {} - Dataset(std::shared_ptr schema, - std::shared_ptr partition_expression); + Dataset(std::shared_ptr schema, Expression2 partition_expression); - virtual FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) = 0; + virtual Result GetFragmentsImpl( + Expression2::BoundWithState predicate) = 0; std::shared_ptr schema_; - std::shared_ptr partition_expression_ = scalar(true); + Expression2 partition_expression_ = literal(true); }; /// \brief A Source which yields fragments wrapping a stream of record batches. @@ -183,7 +181,8 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { std::shared_ptr schema) const override; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl( + Expression2::BoundWithState predicate) override; std::shared_ptr get_batches_; }; @@ -207,7 +206,8 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { std::shared_ptr schema) const override; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl( + Expression2::BoundWithState predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 489339e7907..a6ee8a117bc 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -35,18 +35,18 @@ namespace dataset { /// \brief GetFragmentsFromDatasets 
transforms a vector into a /// flattened FragmentIterator. -inline FragmentIterator GetFragmentsFromDatasets(const DatasetVector& datasets, - std::shared_ptr predicate) { +inline Result GetFragmentsFromDatasets( + const DatasetVector& datasets, Expression2::BoundWithState predicate) { // Iterator auto datasets_it = MakeVectorIterator(datasets); // Dataset -> Iterator - auto fn = [predicate](std::shared_ptr dataset) -> FragmentIterator { + auto fn = [predicate](std::shared_ptr dataset) -> Result { return dataset->GetFragments(predicate); }; // Iterator> - auto fragments_it = MakeMapIterator(fn, std::move(datasets_it)); + auto fragments_it = MakeMaybeMapIterator(fn, std::move(datasets_it)); // Iterator return MakeFlattenIterator(std::move(fragments_it)); diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index 0b138caafe7..58a4a4a0b3f 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -505,7 +505,8 @@ TEST_F(TestEndToEnd, EndToEndSingleDataset) { // The following filter tests both predicate pushdown and post filtering // without partition information because `year` is a partition and `sales` is // not. 
- auto filter = ("year"_ == 2019 && "sales"_ > 100.0); + auto filter = and_(equal(field_ref("year"), literal(2019)), + greater(field_ref("sales"), literal(100.0))); ASSERT_OK(scanner_builder->Filter(filter)); ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish()); diff --git a/cpp/src/arrow/dataset/discovery.cc b/cpp/src/arrow/dataset/discovery.cc index b7af6dcbf0f..079cc6d4065 100644 --- a/cpp/src/arrow/dataset/discovery.cc +++ b/cpp/src/arrow/dataset/discovery.cc @@ -35,7 +35,7 @@ namespace arrow { namespace dataset { -DatasetFactory::DatasetFactory() : root_partition_(scalar(true)) {} +DatasetFactory::DatasetFactory() : root_partition_(literal(true)) {} Result> DatasetFactory::Inspect(InspectOptions options) { ARROW_ASSIGN_OR_RAISE(auto schemas, InspectSchemas(std::move(options))); diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 5176f4f53a7..0b1c94f5adc 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -90,9 +90,9 @@ class ARROW_DS_EXPORT DatasetFactory { virtual Result> Finish(FinishOptions options) = 0; /// \brief Optional root partition for the resulting Dataset. 
- const std::shared_ptr& root_partition() const { return root_partition_; } - Status SetRootPartition(std::shared_ptr partition) { - root_partition_ = partition; + const Expression2& root_partition() const { return root_partition_; } + Status SetRootPartition(Expression2 partition) { + root_partition_ = std::move(partition); return Status::OK(); } @@ -101,7 +101,7 @@ class ARROW_DS_EXPORT DatasetFactory { protected: DatasetFactory(); - std::shared_ptr root_partition_; + Expression2 root_partition_; }; /// \brief DatasetFactory provides a way to inspect/discover a Dataset's diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index bc330edb53b..cd0ba2bee4d 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -139,7 +139,8 @@ class FileSystemDatasetFactoryTest : public DatasetFactoryTest { } options_ = ScanOptions::Make(schema); ASSERT_OK_AND_ASSIGN(dataset_, factory_->Finish(schema)); - AssertFragmentsAreFromPath(dataset_->GetFragments(), paths); + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset_->GetFragments()); + AssertFragmentsAreFromPath(std::move(fragment_it), paths); } protected: @@ -368,10 +369,13 @@ TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) { // column. In such case, the filename should not be used. 
MakeFactory({fs::File("one/file.parquet")}); + ASSERT_OK_AND_ASSIGN(auto expected, equal(field_ref("first"), literal("one")).Bind(*s)); + ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish()); - for (const auto& maybe_fragment : dataset->GetFragments()) { + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); + for (const auto& maybe_fragment : fragment_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); - ASSERT_TRUE(fragment->partition_expression()->Equals(("first"_ == "one"))); + EXPECT_EQ(fragment->partition_expression(), expected.first); } } diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index e0a44a34ab8..c762cb89e0f 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -24,9 +24,15 @@ #include "arrow/compute/exec_internal.h" #include "arrow/compute/registry.h" #include "arrow/dataset/expression_internal.h" +#include "arrow/dataset/filter.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" #include "arrow/util/string.h" +#include "arrow/util/value_parsing.h" namespace arrow { @@ -45,6 +51,144 @@ const FieldRef* Expression2::field_ref() const { return util::get_if(impl_.get()); } +Expression2::operator std::shared_ptr() const { + if (auto lit = literal()) { + DCHECK(lit->is_scalar()); + return std::make_shared(lit->scalar()); + } + + if (auto ref = field_ref()) { + DCHECK(ref->name()); + return std::make_shared(*ref->name()); + } + + auto call = CallNotNull(*this); + if (call->function == "invert") { + return std::make_shared(call->arguments[0]); + } + + if (call->function == "cast") { + const auto& options = checked_cast(*call->options); + return std::make_shared(call->arguments[0], options.to_type, options); + } + + if (call->function == "and_kleene") { + return std::make_shared(call->arguments[0], 
call->arguments[1]); + } + + if (call->function == "or_kleene") { + return std::make_shared(call->arguments[0], call->arguments[1]); + } + + if (auto cmp = Comparison::Get(call->function)) { + compute::CompareOperator op = [&] { + switch (*cmp) { + case Comparison::EQUAL: + return compute::EQUAL; + case Comparison::LESS: + return compute::LESS; + case Comparison::GREATER: + return compute::GREATER; + case Comparison::NOT_EQUAL: + return compute::NOT_EQUAL; + case Comparison::LESS_EQUAL: + return compute::LESS_EQUAL; + case Comparison::GREATER_EQUAL: + return compute::GREATER_EQUAL; + default: + break; + } + return static_cast(-1); + }(); + + return std::make_shared(op, call->arguments[0], + call->arguments[1]); + } + + if (call->function == "is_valid") { + return std::make_shared(call->arguments[0]); + } + + if (call->function == "is_in") { + auto set = checked_cast(*call->options) + .value_set.make_array(); + return std::make_shared(call->arguments[0], std::move(set)); + } + + DCHECK(false) << "untranslatable Expression2: " << ToString(); + return nullptr; +} + +Expression2::Expression2(const Expression& expr) { + switch (expr.type()) { + case ExpressionType::FIELD: + *this = + ::arrow::dataset::field_ref(checked_cast(expr).name()); + return; + + case ExpressionType::SCALAR: + *this = + ::arrow::dataset::literal(checked_cast(expr).value()); + return; + + case ExpressionType::NOT: + *this = ::arrow::dataset::call( + "invert", {checked_cast(expr).operand()}); + return; + + case ExpressionType::CAST: { + const auto& cast_expr = checked_cast(expr); + auto options = cast_expr.options(); + options.to_type = cast_expr.to_type(); + *this = ::arrow::dataset::call("cast", {cast_expr.operand()}, std::move(options)); + return; + } + + case ExpressionType::AND: { + const auto& and_expr = checked_cast(expr); + *this = ::arrow::dataset::call("and_kleene", + {and_expr.left_operand(), and_expr.right_operand()}); + return; + } + + case ExpressionType::OR: { + const auto& or_expr = 
checked_cast(expr); + *this = ::arrow::dataset::call("or_kleene", + {or_expr.left_operand(), or_expr.right_operand()}); + return; + } + + case ExpressionType::COMPARISON: { + const auto& cmp_expr = checked_cast(expr); + static std::array ops = { + "equal", "not_equal", "greater", "greater_equal", "less", "less_equal", + }; + *this = ::arrow::dataset::call(ops[cmp_expr.op()], + {cmp_expr.left_operand(), cmp_expr.right_operand()}); + return; + } + + case ExpressionType::IS_VALID: { + const auto& is_valid_expr = checked_cast(expr); + *this = ::arrow::dataset::call("is_valid", {is_valid_expr.operand()}); + return; + } + + case ExpressionType::IN: { + const auto& in_expr = checked_cast(expr); + *this = ::arrow::dataset::call( + "is_in", {in_expr.operand()}, + compute::SetLookupOptions{in_expr.set(), /*skip_nulls=*/true}); + return; + } + + default: + break; + } + + DCHECK(false) << "untranslatable Expression: " << expr.ToString(); +} + std::string Expression2::ToString() const { if (auto lit = literal()) { if (lit->is_scalar()) { @@ -74,7 +218,12 @@ std::string Expression2::ToString() const { return out; } -void PrintTo(const Expression2& expr, std::ostream* os) { *os << expr.ToString(); } +void PrintTo(const Expression2& expr, std::ostream* os) { + *os << expr.ToString(); + if (expr.IsBound()) { + *os << "[bound]"; + } +} bool Expression2::Equals(const Expression2& other) const { if (Identical(*this, other)) return true; @@ -181,30 +330,41 @@ bool Expression2::IsNullLiteral() const { return true; } } + return false; } bool Expression2::IsSatisfiable() const { + if (descr_.type && descr_.type->id() == Type::NA) { + return false; + } + if (auto lit = literal()) { if (lit->null_count() == lit->length()) { return false; } + if (lit->is_scalar() && lit->type()->id() == Type::BOOL) { return lit->scalar_as().value; } } + + if (auto ref = field_ref()) { + return true; + } + return true; } -bool KernelStateIsImmutable(const std::string& function) { +inline bool 
KernelStateIsImmutable(const std::string& function) { // XXX maybe just add Kernel::state_is_immutable or so? // known functions with non-null but nevertheless immutable KernelState - static std::unordered_set immutable_state = { + static std::unordered_set names = { "is_in", "index_in", "cast", "struct", "strptime", }; - return immutable_state.find(function) != immutable_state.end(); + return names.find(function) != names.end(); } Result> InitKernelState( @@ -219,62 +379,45 @@ Result> InitKernelState( return std::move(kernel_state); } -Result> ExpressionState::Clone( - const std::shared_ptr& state, const Expression2& expr, - compute::ExecContext* exec_context) { - if (!expr.IsBound()) { - return Status::Invalid("Cannot clone State against an unbound expression."); - } - - if (state == nullptr) return nullptr; - - auto call = CallNotNull(expr); - auto call_state = checked_cast(state.get()); +Result> CloneExpressionState( + const ExpressionState& state, compute::ExecContext* exec_context) { + auto clone = std::make_shared(); + clone->kernel_states.reserve(state.kernel_states.size()); - CallState clone; - clone.argument_states.resize(call_state->argument_states.size()); + for (const auto& sub_state : state.kernel_states) { + auto call = CallNotNull(sub_state.first); - bool recursively_share = true; - for (size_t i = 0; i < call->arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE( - clone.argument_states[i], - Clone(call_state->argument_states[i], call->arguments[i], exec_context)); - - if (clone.argument_states[i] != call_state->argument_states[i]) { - recursively_share = false; + if (KernelStateIsImmutable(call->function)) { + // The kernel's state is immutable so it's safe to just + // share a pointer between threads + clone->kernel_states.insert(sub_state); + continue; } - } - if (call_state->kernel_state == nullptr || KernelStateIsImmutable(call->function)) { - // The kernel's state is immutable so it's safe to just - // share a pointer between threads - if 
(recursively_share) { - return state; - } - clone.kernel_state = call_state->kernel_state; - } else { // The kernel's state must be re-initialized. - ARROW_ASSIGN_OR_RAISE(clone.kernel_state, InitKernelState(*call, exec_context)); + ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(*call, exec_context)); + clone->kernel_states.emplace(sub_state.first, std::move(kernel_state)); } - return std::make_shared(std::move(clone)); + return clone; } -Result Expression2::Bind( +Result Expression2::Bind( ValueDescr in, compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; return Bind(std::move(in), &exec_context); } - if (literal()) return StateAndBound{*this, nullptr}; + BoundWithState ret{*this, std::make_shared()}; + + if (literal()) return ret; if (auto ref = field_ref()) { ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); - auto bound = *this; - bound.descr_ = + ret.first.descr_ = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return StateAndBound{std::move(bound), nullptr}; + return ret; } auto bound_call = *CallNotNull(*this); @@ -291,31 +434,41 @@ Result Expression2::Bind( } bound_call.function_kind = function->kind(); - std::vector> argument_states( - bound_call.arguments.size()); - for (size_t i = 0; i < argument_states.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_states[i]), + auto state = std::make_shared(); + for (size_t i = 0; i < bound_call.arguments.size(); ++i) { + std::shared_ptr argument_state; + ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_state), bound_call.arguments[i].Bind(in, exec_context)); + state->MoveFrom(argument_state.get()); + } + + if (RequriesDictionaryTransparency(bound_call)) { + RETURN_NOT_OK(EnsureNotDictionary(&bound_call)); } auto descrs = GetDescriptors(bound_call.arguments); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); + for (auto& descr : 
descrs) { + if (RequriesDictionaryTransparency(bound_call)) { + RETURN_NOT_OK(EnsureNotDictionary(&descr)); + } + } - auto call_state = std::make_shared(); - ARROW_ASSIGN_OR_RAISE(call_state->kernel_state, - InitKernelState(bound_call, exec_context)); - call_state->argument_states = std::move(argument_states); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); - kernel_context.SetState(call_state->kernel_state.get()); + ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(bound_call, exec_context)); + kernel_context.SetState(kernel_state.get()); + + ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( + &kernel_context, descrs)); - auto bound = Expression2(std::make_shared(std::move(bound_call))); - ARROW_ASSIGN_OR_RAISE(bound.descr_, bound_call.kernel->signature->out_type().Resolve( - &kernel_context, descrs)); - return StateAndBound{std::move(bound), std::move(call_state)}; + Expression2 bound(std::make_shared(std::move(bound_call)), std::move(descr)); + + state->kernel_states.emplace(bound, std::move(kernel_state)); + return BoundWithState{std::move(bound), std::move(state)}; } -Result Expression2::Bind( +Result Expression2::Bind( const Schema& in_schema, compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } @@ -354,25 +507,26 @@ Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* } auto call = CallNotNull(expr); - auto call_state = checked_cast(state); std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { - auto argument_state = call_state->argument_states[i].get(); - ARROW_ASSIGN_OR_RAISE( - arguments[i], - ExecuteScalarExpression(call->arguments[i], argument_state, input, exec_context)); + ARROW_ASSIGN_OR_RAISE(arguments[i], ExecuteScalarExpression(call->arguments[i], state, + input, exec_context)); + + if 
(RequriesDictionaryTransparency(*call)) { + RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); + } } auto executor = compute::detail::KernelExecutor::MakeScalar(); compute::KernelContext kernel_context(exec_context); - kernel_context.SetState(call_state->kernel_state.get()); + kernel_context.SetState(state->Get(expr)); auto kernel = call->kernel; - auto inputs = GetDescriptors(call->arguments); + auto descrs = GetDescriptors(arguments); auto options = call->options.get(); - RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, inputs, options})); + RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options})); auto listener = std::make_shared(); RETURN_NOT_OK(executor->Execute(arguments, listener.get())); @@ -388,6 +542,18 @@ ArgumentsAndFlippedArguments(const Expression2::Call& call) { call.arguments[0]}}; } +template ::value_type> +util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { + if (begin == end) return util::nullopt; + + Out folded = std::move(*begin++); + while (begin != end) { + folded = bin_op(std::move(folded), std::move(*begin++)); + } + return folded; +} + util::optional GetNullHandling( const Expression2::Call& call) { if (call.function_kind == compute::Function::SCALAR) { @@ -419,8 +585,23 @@ bool DefinitelyNotNull(const Expression2& expr) { return false; } +std::vector FieldsInExpression(const Expression2& expr) { + if (auto lit = expr.literal()) return {}; + + if (auto ref = expr.field_ref()) { + return {*ref}; + } + + std::vector fields; + for (const Expression2& arg : CallNotNull(expr)->arguments) { + auto argument_fields = FieldsInExpression(arg); + std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields)); + } + return fields; +} + template -Result Modify(Expression2 expr, const PreVisit& pre, +Result Modify(Expression2 expr, ExpressionState* state, const PreVisit& pre, const PostVisitCall& post_call) { ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); @@ -432,7 +613,7 @@ Result 
Modify(Expression2 expr, const PreVisit& pre, auto modified_argument = modified_call.arguments.begin(); for (const auto& argument : call->arguments) { - ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); + ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, state, pre, post_call)); if (!Identical(*modified_argument, argument)) { at_least_one_modified = true; @@ -444,44 +625,37 @@ Result Modify(Expression2 expr, const PreVisit& pre, // reconstruct the call expression with the modified arguments auto modified_expr = Expression2( std::make_shared(std::move(modified_call)), expr.descr()); + + // if expr had associated kernel state, associate it with modified_expr + state->Replace(expr, modified_expr); return post_call(std::move(modified_expr), &expr); } return post_call(std::move(expr), nullptr); } -Result FoldConstants(Expression2 expr, ExpressionState* state) { - DCHECK(expr.IsBound()); - - struct StateAndIndex { - CallState* state; - int index; - }; - std::vector stack; - - auto root_state = checked_cast(state); - return Modify( - std::move(expr), - [&stack, root_state](Expression2 expr) { - auto call = expr.call(); +template +Result Modify(Expression2::BoundWithState bound, + const PreVisit& pre, + const PostVisitCall& post_call) { + DCHECK(bound.first.IsBound()); - if (stack.empty()) { - stack = {{root_state, 0}}; - return expr; - } + auto expr = std::move(bound.first); + auto state = bound.second.get(); - int i = stack.back().index++; + ARROW_ASSIGN_OR_RAISE(expr, Modify(std::move(expr), state, pre, post_call)); - if (!call) return expr; - auto next_state = stack.back().state->argument_states[i].get(); - stack.push_back({checked_cast(next_state), 0}); + bound.first = std::move(expr); - return expr; - }, - [&stack](Expression2 expr, ...) 
-> Result { - auto state = stack.back().state; - stack.pop_back(); + return bound; +} +Result FoldConstants(Expression2::BoundWithState bound) { + bound.second = std::make_shared(*bound.second); + auto state = bound.second.get(); + return Modify( + std::move(bound), [](Expression2 expr) { return expr; }, + [state](Expression2 expr, ...) -> Result { auto call = CallNotNull(expr); if (std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression2& argument) { return argument.literal(); })) { @@ -489,16 +663,22 @@ Result FoldConstants(Expression2 expr, ExpressionState* state) { static const Datum ignored_input; ARROW_ASSIGN_OR_RAISE(Datum constant, ExecuteScalarExpression(expr, state, ignored_input)); + + state->Drop(expr); return literal(std::move(constant)); } - // XXX the following should probably be in a registry of passes instead of inline + // XXX the following should probably be in a registry of passes instead + // of inline if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { - // kernels which always produce intersected validity can be resolved to null - // *now* if any of their inputs is a null literal + // kernels which always produce intersected validity can be resolved + // to null *now* if any of their inputs is a null literal for (const auto& argument : call->arguments) { - if (argument.IsNullLiteral()) return argument; + if (argument.IsNullLiteral()) { + state->Drop(expr); + return argument; + } } } @@ -526,43 +706,8 @@ Result FoldConstants(Expression2 expr, ExpressionState* state) { return expr; }); - - return expr; } -struct FlattenedAssociativeChain { - bool was_left_folded = true; - std::vector exprs, fringe; - - explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} { - auto call = CallNotNull(exprs.back()); - fringe = call->arguments; - - auto it = fringe.begin(); - - while (it != fringe.end()) { - auto sub_call = it->call(); - if (!sub_call || sub_call->function != call->function) { - 
++it; - continue; - } - - if (it != fringe.begin()) { - was_left_folded = false; - } - - exprs.push_back(std::move(*it)); - it = fringe.erase(it); - it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end()); - // NB: no increment so we hit sub_call's first argument next iteration - } - - DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) { - return CallNotNull(expr)->options == nullptr; - })); - } -}; - inline std::vector GuaranteeConjunctionMembers( const Expression2& guaranteed_true_predicate) { auto guarantee = guaranteed_true_predicate.call(); @@ -577,8 +722,12 @@ inline std::vector GuaranteeConjunctionMembers( Status ExtractKnownFieldValuesImpl( std::vector* conjunction_members, std::unordered_map* known_values) { - for (auto&& member : *conjunction_members) { - ARROW_ASSIGN_OR_RAISE(member, Canonicalize(std::move(member))); + { + auto empty_state = std::make_shared(); + for (auto&& member : *conjunction_members) { + ARROW_ASSIGN_OR_RAISE(std::tie(member, std::ignore), + Canonicalize({std::move(member), empty_state})); + } } auto unconsumed_end = @@ -622,17 +771,21 @@ Status ExtractKnownFieldValuesImpl( Result> ExtractKnownFieldValues( const Expression2& guaranteed_true_predicate) { + if (!guaranteed_true_predicate.IsBound()) { + return Status::Invalid("guaranteed_true_predicate was not bound"); + } auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); return known_values; } -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2 expr) { + Expression2::BoundWithState bound) { + bound.second = std::make_shared(*bound.second); return Modify( - std::move(expr), + std::move(bound), [&known_values](Expression2 expr) { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); @@ -645,18 +798,6 @@ 
Result ReplaceFieldsWithKnownValues( [](Expression2 expr, ...) { return expr; }); } -template ::value_type> -util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { - if (begin == end) return util::nullopt; - - Out folded = std::move(*begin++); - while (begin != end) { - folded = bin_op(std::move(folded), std::move(*begin++)); - } - return folded; -} - inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { static std::unordered_set binary_associative_commutative{ "and", "or", "and_kleene", "or_kleene", "xor", @@ -666,97 +807,13 @@ inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { return it != binary_associative_commutative.end(); } -struct Comparison { - enum type { - NA = 0, - EQUAL = 1, - LESS = 2, - GREATER = 4, - NOT_EQUAL = LESS | GREATER, - LESS_EQUAL = LESS | EQUAL, - GREATER_EQUAL = GREATER | EQUAL, - }; +Result Canonicalize(Expression2::BoundWithState bound, + compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return Canonicalize(std::move(bound), &exec_context); + } - static const type* Get(const std::string& function) { - static std::unordered_map flipped_comparisons{ - {"equal", EQUAL}, {"not_equal", NOT_EQUAL}, - {"less", LESS}, {"less_equal", LESS_EQUAL}, - {"greater", GREATER}, {"greater_equal", GREATER_EQUAL}, - }; - - auto it = flipped_comparisons.find(function); - return it != flipped_comparisons.end() ? 
&it->second : nullptr; - } - - static const type* Get(const Expression2& expr) { - if (auto call = expr.call()) { - return Comparison::Get(call->function); - } - return nullptr; - } - - static Result Execute(Datum l, Datum r) { - if (!l.is_scalar() || !r.is_scalar()) { - return Status::Invalid("Cannot Execute Comparison on non-scalars"); - } - std::vector arguments{std::move(l), std::move(r)}; - - ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); - - if (!equal.scalar()->is_valid) return NA; - if (equal.scalar_as().value) return EQUAL; - - ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments)); - - if (!less.scalar()->is_valid) return NA; - return less.scalar_as().value ? LESS : GREATER; - } - - static type GetFlipped(type op) { - switch (op) { - case NA: - return NA; - case EQUAL: - return EQUAL; - case LESS: - return GREATER; - case GREATER: - return LESS; - case NOT_EQUAL: - return NOT_EQUAL; - case LESS_EQUAL: - return GREATER_EQUAL; - case GREATER_EQUAL: - return LESS_EQUAL; - }; - DCHECK(false); - return NA; - } - - static std::string GetName(type op) { - switch (op) { - case NA: - DCHECK(false) << "unreachable"; - break; - case EQUAL: - return "equal"; - case LESS: - return "less"; - case GREATER: - return "greater"; - case NOT_EQUAL: - return "not_equal"; - case LESS_EQUAL: - return "less_equal"; - case GREATER_EQUAL: - return "greater_equal"; - }; - DCHECK(false); - return "na"; - } -}; - -Result Canonicalize(Expression2 expr) { // If potentially reconstructing more deeply than a call's immediate arguments // (for example, when reorganizing an associative chain), add expressions to this set to // avoid unnecessary work @@ -773,8 +830,8 @@ Result Canonicalize(Expression2 expr) { } AlreadyCanonicalized; return Modify( - std::move(expr), - [&AlreadyCanonicalized](Expression2 expr) -> Result { + std::move(bound), + [&AlreadyCanonicalized, exec_context](Expression2 expr) -> Result { auto call = expr.call(); if 
(!call) return expr; @@ -822,13 +879,25 @@ Result Canonicalize(Expression2 expr) { if (call->arguments[0].literal() && !call->arguments[1].literal()) { // ensure that literals are on comparisons' RHS auto flipped_call = *call; - for (auto&& argument : flipped_call.arguments) { - ARROW_ASSIGN_OR_RAISE(argument, Canonicalize(std::move(argument))); - } flipped_call.function = Comparison::GetName(Comparison::GetFlipped(*cmp)); + // look up the flipped kernel + // TODO extract a helper for use here and in Bind + ARROW_ASSIGN_OR_RAISE( + auto function, + exec_context->func_registry()->GetFunction(flipped_call.function)); + + auto descrs = GetDescriptors(flipped_call.arguments); + for (auto& descr : descrs) { + if (RequriesDictionaryTransparency(flipped_call)) { + RETURN_NOT_OK(EnsureNotDictionary(&descr)); + } + } + ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); + std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); return Expression2( - std::make_shared(std::move(flipped_call))); + std::make_shared(std::move(flipped_call)), + expr.descr()); } } @@ -837,10 +906,10 @@ Result Canonicalize(Expression2 expr) { [](Expression2 expr, ...) { return expr; }); } -Result DirectComparisonSimplification(Expression2 expr, - const Expression2::Call& guarantee) { +Result DirectComparisonSimplification( + Expression2::BoundWithState bound, const Expression2::Call& guarantee) { return Modify( - std::move(expr), [](Expression2 expr) { return expr; }, + std::move(bound), [](Expression2 expr) { return expr; }, [&guarantee](Expression2 expr, ...) 
-> Result { auto call = expr.call(); if (!call) return expr; @@ -895,36 +964,251 @@ Result DirectComparisonSimplification(Expression2 expr, }); } -Result SimplifyWithGuarantee(Expression2 expr, ExpressionState* state, - const Expression2& guaranteed_true_predicate) { +Result SimplifyWithGuarantee( + Expression2::BoundWithState bound, const Expression2& guaranteed_true_predicate) { + if (!guaranteed_true_predicate.IsBound()) { + return Status::Invalid("guaranteed_true_predicate was not bound"); + } + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); - ARROW_ASSIGN_OR_RAISE(expr, - ReplaceFieldsWithKnownValues(known_values, std::move(expr))); + ARROW_ASSIGN_OR_RAISE(bound, + ReplaceFieldsWithKnownValues(known_values, std::move(bound))); - auto CanonicalizeAndFoldConstants = [&expr, state] { - ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); - ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr), state)); + auto CanonicalizeAndFoldConstants = [&bound] { + ARROW_ASSIGN_OR_RAISE(bound, Canonicalize(std::move(bound))); + ARROW_ASSIGN_OR_RAISE(bound, FoldConstants(std::move(bound))); return Status::OK(); }; RETURN_NOT_OK(CanonicalizeAndFoldConstants()); for (const auto& guarantee : conjunction_members) { if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { - ARROW_ASSIGN_OR_RAISE( - auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); + ARROW_ASSIGN_OR_RAISE(auto simplified, DirectComparisonSimplification( + bound, *CallNotNull(guarantee))); - if (Identical(simplified, expr)) continue; + if (Identical(simplified.first, bound.first)) continue; - expr = std::move(simplified); + bound = std::move(simplified); RETURN_NOT_OK(CanonicalizeAndFoldConstants()); } } - return expr; + return bound; +} + +Result> Serialize(const Expression2& expr) { + struct { + 
std::shared_ptr metadata_ = std::make_shared(); + ArrayVector columns_; + + Result AddScalar(const Scalar& scalar) { + auto ret = columns_.size(); + ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(scalar, 1)); + columns_.push_back(std::move(array)); + return std::to_string(ret); + } + + Status Visit(const Expression2& expr) { + if (auto lit = expr.literal()) { + if (!lit->is_scalar()) { + return Status::NotImplemented("Serialization of non-scalar literals"); + } + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*lit->scalar())); + metadata_->Append("literal", std::move(value)); + return Status::OK(); + } + + if (auto ref = expr.field_ref()) { + if (!ref->name()) { + return Status::NotImplemented("Serialization of non-name field_refs"); + } + metadata_->Append("field_ref", *ref->name()); + return Status::OK(); + } + + auto call = CallNotNull(expr); + metadata_->Append("call", call->function); + + for (const auto& argument : call->arguments) { + RETURN_NOT_OK(Visit(argument)); + } + + if (call->options) { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, FunctionOptionsToStructScalar(*call)); + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*options_scalar)); + metadata_->Append("options", std::move(value)); + } + + metadata_->Append("end", call->function); + return Status::OK(); + } + + Result> operator()(const Expression2& expr) { + RETURN_NOT_OK(Visit(expr)); + FieldVector fields(columns_.size()); + for (size_t i = 0; i < fields.size(); ++i) { + fields[i] = field("", columns_[i]->type()); + } + return RecordBatch::Make(schema(std::move(fields), std::move(metadata_)), 1, + std::move(columns_)); + } + } ToRecordBatch; + + ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr)); + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); + ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + RETURN_NOT_OK(writer->Close()); + return stream->Finish(); +} + +Result Deserialize(const 
Buffer& buffer) { + io::BufferReader stream(buffer); + ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); + ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); + if (batch->schema()->metadata() == nullptr) { + return Status::Invalid("serialized Expression2's batch repr had null metadata"); + } + if (batch->num_rows() != 1) { + return Status::Invalid( + "serialized Expression2's batch repr was not a single row - had ", + batch->num_rows()); + } + + struct FromRecordBatch { + const RecordBatch& batch_; + int index_; + + const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); } + + Result> GetScalar(const std::string& i) { + int32_t column_index; + if (!internal::ParseValue(i.data(), i.length(), &column_index)) { + return Status::Invalid("Couldn't parse column_index"); + } + if (column_index >= batch_.num_columns()) { + return Status::Invalid("column_index out of bounds"); + } + return batch_.column(column_index)->GetScalar(0); + } + + Result GetOne() { + if (index_ >= metadata().size()) { + return Status::Invalid("unterminated serialized Expression2"); + } + + const std::string& key = metadata().key(index_); + const std::string& value = metadata().value(index_); + ++index_; + + if (key == "literal") { + ARROW_ASSIGN_OR_RAISE(auto scalar, GetScalar(value)); + return literal(std::move(scalar)); + } + + if (key == "field_ref") { + return field_ref(value); + } + + if (key != "call") { + return Status::Invalid("Unrecognized serialized Expression2 key ", key); + } + + std::vector arguments; + while (metadata().key(index_) != "end") { + if (metadata().key(index_) == "options") { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_))); + auto expr = call(value, std::move(arguments)); + RETURN_NOT_OK(FunctionOptionsFromStructScalar( + checked_cast(options_scalar.get()), + const_cast(expr.call()))); + index_ += 2; + return expr; + } + + ARROW_ASSIGN_OR_RAISE(auto argument, GetOne()); + 
arguments.push_back(std::move(argument)); + } + + ++index_; + return call(value, std::move(arguments)); + } + }; + + return FromRecordBatch{*batch, 0}.GetOne(); +} + +Expression2 project(std::vector values, std::vector names) { + return call("struct", std::move(values), compute::StructOptions{std::move(names)}); +} + +Expression2 equal(Expression2 lhs, Expression2 rhs) { + return call("equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 not_equal(Expression2 lhs, Expression2 rhs) { + return call("not_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 less(Expression2 lhs, Expression2 rhs) { + return call("less", {std::move(lhs), std::move(rhs)}); +} + +Expression2 less_equal(Expression2 lhs, Expression2 rhs) { + return call("less_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 greater(Expression2 lhs, Expression2 rhs) { + return call("greater", {std::move(lhs), std::move(rhs)}); +} + +Expression2 greater_equal(Expression2 lhs, Expression2 rhs) { + return call("greater_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 and_(Expression2 lhs, Expression2 rhs) { + return call("and_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression2 and_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), + operands.end(), and_); + if (folded) { + return std::move(*folded); + } + return literal(true); +} + +Expression2 or_(Expression2 lhs, Expression2 rhs) { + return call("or_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression2 or_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), + operands.end(), or_); + if (folded) { + return std::move(*folded); + } + return literal(false); +} + +Expression2 not_(Expression2 operand) { return call("invert", {std::move(operand)}); } + +Expression2 operator&&(Expression2 lhs, Expression2 rhs) { + return and_(std::move(lhs), std::move(rhs)); +} + +Expression2 operator||(Expression2 lhs, Expression2 rhs) { + return or_(std::move(lhs), 
std::move(rhs)); +} + +Result InsertImplicitCasts(Expression2 expr, const Schema& s) { + std::shared_ptr e(expr); + return InsertImplicitCasts(*e, s); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 4da41b8b9c4..f452e54b7c2 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -45,9 +45,6 @@ namespace dataset { /// - A literal Datum. /// - A reference to a single (potentially nested) field of the input Datum. /// - A call to a compute function, with arguments specified by other Expressions. -/// - Optionally, a explicitly selected Kernel of the function. If provided, -/// execution will skip function lookup and kernel dispatch and use that Kernel -/// directly. class ARROW_DS_EXPORT Expression2 { public: struct Call { @@ -70,10 +67,10 @@ class ARROW_DS_EXPORT Expression2 { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - using StateAndBound = std::pair>; - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, - compute::ExecContext* = NULLPTR) const; + using BoundWithState = std::pair>; + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, + compute::ExecContext* = NULLPTR) const; /// Return true if all an expression's field references have explicit ValueDescr and all /// of its functions' kernels are looked up. 
@@ -96,33 +93,32 @@ class ARROW_DS_EXPORT Expression2 { const Datum* literal() const; const FieldRef* field_ref() const; + // FIXME remove these + operator std::shared_ptr() const; // NOLINT runtime/explicit + Expression2(const Expression& expr); // NOLINT runtime/explicit + Expression2(std::shared_ptr expr) // NOLINT runtime/explicit + : Expression2(*expr) {} + const ValueDescr& descr() const { return descr_; } - using Impl = util::variant; + using Impl = util::Variant; explicit Expression2(std::shared_ptr impl, ValueDescr descr = {}) : impl_(std::move(impl)), descr_(std::move(descr)) {} + Expression2() = default; + private: std::shared_ptr impl_; ValueDescr descr_; // XXX someday // NullGeneralization::type evaluates_to_null_; - friend bool Identical(const Expression2& l, const Expression2& r); + ARROW_EXPORT friend bool Identical(const Expression2& l, const Expression2& r); ARROW_EXPORT friend void PrintTo(const Expression2&, std::ostream*); }; -struct ExpressionState { - virtual ~ExpressionState() = default; - - // Produce another instance of ExpressionState which may be safely used from a - // different thread - static Result> Clone( - const std::shared_ptr&, const Expression2&, compute::ExecContext*); -}; - inline bool operator==(const Expression2& l, const Expression2& r) { return l.Equals(r); } inline bool operator!=(const Expression2& l, const Expression2& r) { return !l.Equals(r); @@ -147,11 +143,6 @@ Expression2 call(std::string function, std::vector arguments, std::make_shared(std::move(options))); } -inline Expression2 project(std::vector names, - std::vector values) { - return call("struct", std::move(values), compute::StructOptions{std::move(names)}); -} - template Expression2 field_ref(Args&&... 
args) { return Expression2( @@ -166,36 +157,47 @@ Expression2 literal(Arg&& arg) { std::move(descr)); } -// Simplification passes +ARROW_DS_EXPORT +std::vector FieldsInExpression(const Expression2&); + +ARROW_DS_EXPORT +Result> ExtractKnownFieldValues( + const Expression2& guaranteed_true_predicate); + +/// \defgroup expression-passes Functions for modification of Expression2s +/// +/// @{ +/// +/// These operate on a bound expression and its bound state simultaneously, +/// ensuring that Call Expression2s' KernelState can be utilized or reassociated. /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. /// TODO this could be a strong canonicalization ARROW_DS_EXPORT -Result Canonicalize(Expression2 expr); +Result Canonicalize(Expression2::BoundWithState, + compute::ExecContext* = NULLPTR); /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all /// calls whose arguments are entirely literal. ARROW_DS_EXPORT -Result FoldConstants(Expression2 expr, ExpressionState*); - -ARROW_DS_EXPORT -Result> ExtractKnownFieldValues( - const Expression2& guaranteed_true_predicate); +Result FoldConstants(Expression2::BoundWithState); ARROW_DS_EXPORT -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2 expr); + Expression2::BoundWithState); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is /// used to remove redundant function calls from a filter expression or to replace a /// reference to a constant-value field with a literal. 
ARROW_DS_EXPORT -Result SimplifyWithGuarantee(Expression2 expr, ExpressionState*, - const Expression2& guaranteed_true_predicate); +Result SimplifyWithGuarantee( + Expression2::BoundWithState, const Expression2& guaranteed_true_predicate); + +/// @} // Execution @@ -203,7 +205,7 @@ Result SimplifyWithGuarantee(Expression2 expr, ExpressionState*, /// expression must be bound. Result ExecuteScalarExpression(const Expression2&, ExpressionState*, const Datum& input, - compute::ExecContext* exec_context = NULLPTR); + compute::ExecContext* = NULLPTR); // Serialization @@ -213,5 +215,33 @@ Result> Serialize(const Expression2&); ARROW_DS_EXPORT Result Deserialize(const Buffer&); +// Convenience aliases for factories + +ARROW_DS_EXPORT Expression2 project(std::vector values, + std::vector names); + +ARROW_DS_EXPORT Expression2 equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 not_equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 less(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 less_equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 greater(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 greater_equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 and_(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression2 and_(const std::vector&); +ARROW_DS_EXPORT Expression2 or_(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression2 or_(const std::vector&); +ARROW_DS_EXPORT Expression2 not_(Expression2 operand); + +// FIXME remove these +ARROW_DS_EXPORT Expression2 operator&&(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression2 operator||(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Result InsertImplicitCasts(Expression2, const Schema&); + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 1998a565ec3..548d36711fb 100644 --- 
a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -21,6 +21,7 @@ #include #include +#include "arrow/compute/api_vector.h" #include "arrow/record_batch.h" #include "arrow/table.h" #include "arrow/util/logging.h" @@ -62,9 +63,44 @@ inline std::vector GetDescriptors(const std::vector& ex return descrs; } -struct CallState : ExpressionState { - std::vector> argument_states; - std::shared_ptr kernel_state; +inline std::vector GetDescriptors(const std::vector& values) { + std::vector descrs(values.size()); + for (size_t i = 0; i < values.size(); ++i) { + descrs[i] = values[i].descr(); + } + return descrs; +} + +struct ARROW_DS_EXPORT ExpressionState { + std::unordered_map, + Expression2::Hash> + kernel_states; + + compute::KernelState* Get(const Expression2& expr) const { + auto it = kernel_states.find(expr); + if (it == kernel_states.end()) return nullptr; + return it->second.get(); + } + + void Replace(const Expression2& expr, const Expression2& replacement) { + auto it = kernel_states.find(expr); + if (it == kernel_states.end()) return; + + auto kernel_state = std::move(it->second); + kernel_states.erase(it); + kernel_states.emplace(replacement, std::move(kernel_state)); + } + + void Drop(const Expression2& expr) { + auto it = kernel_states.find(expr); + if (it == kernel_states.end()) return; + kernel_states.erase(it); + } + + void MoveFrom(ExpressionState* other) { + std::move(other->kernel_states.begin(), other->kernel_states.end(), + std::inserter(kernel_states, kernel_states.end())); + } }; struct FieldPathGetDatumImpl { @@ -106,5 +142,281 @@ inline Result GetDatumField(const FieldRef& ref, const Datum& input) { return field; } +struct Comparison { + enum type { + NA = 0, + EQUAL = 1, + LESS = 2, + GREATER = 4, + NOT_EQUAL = LESS | GREATER, + LESS_EQUAL = LESS | EQUAL, + GREATER_EQUAL = GREATER | EQUAL, + }; + + static const type* Get(const std::string& function) { + static std::unordered_map flipped_comparisons{ + 
{"equal", EQUAL}, {"not_equal", NOT_EQUAL}, + {"less", LESS}, {"less_equal", LESS_EQUAL}, + {"greater", GREATER}, {"greater_equal", GREATER_EQUAL}, + }; + + auto it = flipped_comparisons.find(function); + return it != flipped_comparisons.end() ? &it->second : nullptr; + } + + static const type* Get(const Expression2& expr) { + if (auto call = expr.call()) { + return Comparison::Get(call->function); + } + return nullptr; + } + + static Result Execute(Datum l, Datum r) { + if (!l.is_scalar() || !r.is_scalar()) { + return Status::Invalid("Cannot Execute Comparison on non-scalars"); + } + std::vector arguments{std::move(l), std::move(r)}; + + ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); + + if (!equal.scalar()->is_valid) return NA; + if (equal.scalar_as().value) return EQUAL; + + ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments)); + + if (!less.scalar()->is_valid) return NA; + return less.scalar_as().value ? LESS : GREATER; + } + + static type GetFlipped(type op) { + switch (op) { + case NA: + return NA; + case EQUAL: + return EQUAL; + case LESS: + return GREATER; + case GREATER: + return LESS; + case NOT_EQUAL: + return NOT_EQUAL; + case LESS_EQUAL: + return GREATER_EQUAL; + case GREATER_EQUAL: + return LESS_EQUAL; + }; + DCHECK(false); + return NA; + } + + static std::string GetName(type op) { + switch (op) { + case NA: + DCHECK(false) << "unreachable"; + break; + case EQUAL: + return "equal"; + case LESS: + return "less"; + case GREATER: + return "greater"; + case NOT_EQUAL: + return "not_equal"; + case LESS_EQUAL: + return "less_equal"; + case GREATER_EQUAL: + return "greater_equal"; + }; + DCHECK(false); + return "na"; + } +}; + +inline bool IsSetLookup(const std::string& function) { + return function == "is_in" || function == "index_in"; +} + +inline const compute::SetLookupOptions* GetSetLookupOptions( + const Expression2::Call& call) { + if (!IsSetLookup(call.function)) return nullptr; + return 
checked_cast(call.options.get()); +} + +inline bool RequriesDictionaryTransparency(const Expression2::Call& call) { + // TODO move this functionality into compute:: + + // Functions which don't provide kernels for dictionary types. Dictionaries will be + // decoded for these functions. + if (Comparison::Get(call.function)) return true; + + if (IsSetLookup(call.function)) return true; + + return false; +} + +inline Status EnsureNotDictionary(ValueDescr* descr) { + const auto& type = descr->type; + if (type && type->id() == Type::DICTIONARY) { + descr->type = checked_cast(*type).value_type(); + } + return Status::OK(); +} + +inline Status EnsureNotDictionary(Datum* datum) { + if (datum->type()->id() != Type::DICTIONARY) { + return Status::OK(); + } + + if (datum->is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + *datum, + checked_cast(*datum->scalar()).GetEncodedValue()); + return Status::OK(); + } + + DCHECK_EQ(datum->kind(), Datum::ARRAY); + ArrayData indices = *datum->array(); + indices.type = checked_cast(*datum->type()).index_type(); + auto values = std::move(indices.dictionary); + + ARROW_ASSIGN_OR_RAISE( + *datum, compute::Take(values, indices, compute::TakeOptions::NoBoundsCheck())); + return Status::OK(); +} + +inline Status EnsureNotDictionary(Expression2::Call* call) { + if (auto options = GetSetLookupOptions(*call)) { + auto new_options = *options; + RETURN_NOT_OK(EnsureNotDictionary(&new_options.value_set)); + call->options.reset(new compute::SetLookupOptions(std::move(new_options))); + } + return Status::OK(); +} + +inline Result> FunctionOptionsToStructScalar( + const Expression2::Call& call) { + if (call.options == nullptr) { + return nullptr; + } + + auto Finish = [](ScalarVector values, std::vector names) { + FieldVector fields(names.size()); + for (size_t i = 0; i < fields.size(); ++i) { + fields[i] = field(std::move(names[i]), values[i]->type); + } + return std::make_shared(std::move(values), struct_(std::move(fields))); + }; + + if (auto options = 
GetSetLookupOptions(call)) { + if (!options->value_set.is_array()) { + return Status::NotImplemented("chunked value_set"); + } + return Finish( + { + std::make_shared(options->value_set.make_array()), + MakeScalar(options->skip_nulls), + }, + {"value_set", "skip_nulls"}); + } + + if (call.function == "cast") { + auto options = checked_cast(call.options.get()); + return Finish( + { + MakeNullScalar(options->to_type), + MakeScalar(options->allow_int_overflow), + MakeScalar(options->allow_time_truncate), + MakeScalar(options->allow_time_overflow), + MakeScalar(options->allow_decimal_truncate), + MakeScalar(options->allow_float_truncate), + MakeScalar(options->allow_invalid_utf8), + }, + { + "to_type_holder", + "allow_int_overflow", + "allow_time_truncate", + "allow_time_overflow", + "allow_decimal_truncate", + "allow_float_truncate", + "allow_invalid_utf8", + }); + } + + return Status::NotImplemented("conversion of options for ", call.function); +} + +inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, + Expression2::Call* call) { + if (repr == nullptr) { + call->options = nullptr; + return Status::OK(); + } + + if (IsSetLookup(call->function)) { + ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set")); + ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls")); + call->options = std::make_shared( + checked_cast(*value_set).value, + checked_cast(*skip_nulls).value); + return Status::OK(); + } + + if (call->function == "cast") { + auto options = std::make_shared(); + ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder")); + options->to_type = to_type_holder->type; + + int i = 1; + for (bool* opt : { + &options->allow_int_overflow, + &options->allow_time_truncate, + &options->allow_time_overflow, + &options->allow_decimal_truncate, + &options->allow_float_truncate, + &options->allow_invalid_utf8, + }) { + *opt = checked_cast(*repr->value[i++]).value; + } + + call->options = std::move(options); + return 
Status::OK(); + } + + return Status::NotImplemented("conversion of options for ", call->function); +} + +struct FlattenedAssociativeChain { + bool was_left_folded = true; + std::vector exprs, fringe; + + explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} { + auto call = CallNotNull(exprs.back()); + fringe = call->arguments; + + auto it = fringe.begin(); + + while (it != fringe.end()) { + auto sub_call = it->call(); + if (!sub_call || sub_call->function != call->function) { + ++it; + continue; + } + + if (it != fringe.begin()) { + was_left_folded = false; + } + + exprs.push_back(std::move(*it)); + it = fringe.erase(it); + it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end()); + // NB: no increment so we hit sub_call's first argument next iteration + } + + DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) { + return CallNotNull(expr)->options == nullptr; + })); + } +}; + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 7cc98a1e151..6274147d208 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -28,6 +28,7 @@ #include "arrow/compute/registry.h" #include "arrow/dataset/expression_internal.h" +#include "arrow/dataset/test_util.h" #include "arrow/testing/gtest_util.h" using testing::UnorderedElementsAreArray; @@ -39,11 +40,6 @@ using internal::checked_pointer_cast; namespace dataset { -const Schema kBoringSchema{{ - field("i32", int32()), - field("f32", float32()), -}}; - #define EXPECT_OK ARROW_EXPECT_OK TEST(Expression2, ToString) { @@ -91,6 +87,78 @@ TEST(Expression2, Hash) { EXPECT_EQ(set.size(), 6); } +TEST(Expression2, IsScalarExpression) { + EXPECT_TRUE(literal(true).IsScalarExpression()); + + auto arr = ArrayFromJSON(int8(), "[]"); + EXPECT_FALSE(literal(arr).IsScalarExpression()); + + EXPECT_TRUE(field_ref("a").IsScalarExpression()); + + 
EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsScalarExpression()); + + EXPECT_FALSE(equal(field_ref("a"), literal(arr)).IsScalarExpression()); + + EXPECT_TRUE(call("is_in", {field_ref("a")}, compute::SetLookupOptions{arr, true}) + .IsScalarExpression()); + + // non scalar function + EXPECT_FALSE(call("take", {field_ref("a"), literal(arr)}).IsScalarExpression()); +} + +TEST(Expression2, IsSatisfiable) { + EXPECT_TRUE(literal(true).IsSatisfiable()); + EXPECT_FALSE(literal(false).IsSatisfiable()); + + auto null = std::make_shared(); + EXPECT_FALSE(literal(null).IsSatisfiable()); + + EXPECT_TRUE(field_ref("a").IsSatisfiable()); + + EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsSatisfiable()); + + // NB: no constant folding here + EXPECT_TRUE(equal(literal(0), literal(1)).IsSatisfiable()); + + // When a top level conjunction contains an Expression2 which is certain to evaluate to + // null, it can only evaluate to null or false. + auto null_or_false = and_(literal(null), field_ref("a")); + // This may appear in satisfiable filters if coalesced + EXPECT_TRUE(call("is_null", {null_or_false}).IsSatisfiable()); + // ... but at the top level it is not satisfiable. + // This special case arises when (for example) an absent column has made + // one member of the conjunction always-null. This is fairly common and + // would be a worthwhile optimization to support. 
+ // EXPECT_FALSE(null_or_false.IsSatisfiable()); +} + +TEST(Expression2, FieldsInExpression) { + auto ExpectFieldsAre = [](Expression2 expr, std::vector expected) { + EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected)); + }; + + ExpectFieldsAre(literal(true), {}); + + ExpectFieldsAre(field_ref("a"), {"a"}); + + ExpectFieldsAre(equal(field_ref("a"), literal(1)), {"a"}); + + ExpectFieldsAre(equal(field_ref("a"), field_ref("b")), {"a", "b"}); + + ExpectFieldsAre( + or_(equal(field_ref("a"), literal(1)), equal(field_ref("a"), literal(2))), + {"a", "a"}); + + ExpectFieldsAre( + or_(equal(field_ref("a"), literal(1)), equal(field_ref("b"), literal(2))), + {"a", "b"}); + + ExpectFieldsAre(or_(and_(not_(equal(field_ref("a"), literal(1))), + equal(field_ref("b"), literal(2))), + not_(less(field_ref("c"), literal(3)))), + {"a", "b", "c"}); +} + TEST(Expression2, BindLiteral) { for (Datum dat : { Datum(3), @@ -161,6 +229,18 @@ TEST(Expression2, BindCall) { expr.Bind(Schema({field("a", int32()), field("b", int32())}))); } +TEST(Expression2, BindDictionaryTransparent) { + auto expr = call("equal", {field_ref("a"), field_ref("b")}); + EXPECT_FALSE(expr.IsBound()); + + ASSERT_OK_AND_ASSIGN( + std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", utf8()), field("b", dictionary(int32(), utf8()))}))); + + EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean())); + EXPECT_TRUE(expr.IsBound()); +} + TEST(Expression2, BindNestedCall) { auto expr = call("add", {field_ref("a"), @@ -208,35 +288,48 @@ TEST(Expression2, ExecuteFieldRef) { {"a": 0.0}, {"a": -1} ])"), - MakeNullScalar(float64())); + MakeNullScalar(null())); } -Result NaiveExecuteScalarExpression(const Expression2& expr, - ExpressionState* state, const Datum& input) { +Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& input) { auto call = expr.call(); if (call == nullptr) { // already tested execution of field_ref, execution of literal is trivial - return
ExecuteScalarExpression(expr, state, input); + return ExecuteScalarExpression(expr, /*state=*/nullptr, input); } - auto call_state = checked_cast(state); - std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { - auto argument_state = call_state->argument_states[i].get(); - ARROW_ASSIGN_OR_RAISE(arguments[i], NaiveExecuteScalarExpression( - call->arguments[i], argument_state, input)); + ARROW_ASSIGN_OR_RAISE(arguments[i], + NaiveExecuteScalarExpression(call->arguments[i], input)); - EXPECT_EQ(call->arguments[i].descr(), arguments[i].descr()); + if (RequriesDictionaryTransparency(*call)) { + RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); + } } ARROW_ASSIGN_OR_RAISE(auto function, compute::GetFunctionRegistry()->GetFunction(call->function)); - ARROW_ASSIGN_OR_RAISE(auto expected_kernel, - function->DispatchExact(GetDescriptors(call->arguments))); + + auto descrs = GetDescriptors(call->arguments); + for (size_t i = 0; i < arguments.size(); ++i) { + if (RequriesDictionaryTransparency(*call)) { + RETURN_NOT_OK(EnsureNotDictionary(&descrs[i])); + } + EXPECT_EQ(arguments[i].descr(), descrs[i]); + } + + ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs)); EXPECT_EQ(call->kernel, expected_kernel); + auto options = call->options; + if (RequriesDictionaryTransparency(*call)) { + auto non_dict_call = *call; + RETURN_NOT_OK(EnsureNotDictionary(&non_dict_call)); + options = non_dict_call.options; + } + compute::ExecContext exec_context; return function->Execute(arguments, call->options.get(), &exec_context); } @@ -247,8 +340,7 @@ void AssertExecute(Expression2 expr, Datum in) { ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, state.get(), in)); - ASSERT_OK_AND_ASSIGN(Datum expected, - NaiveExecuteScalarExpression(expr, state.get(), in)); + ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in)); AssertDatumsEqual(actual, expected, /*verbose=*/true); } @@ -277,7 +369,7 @@ 
TEST(Expression2, ExecuteCall) { {"a": "12/11/1900"} ])")); - AssertExecute(project({"a + 3.5"}, {call("add", {field_ref("a"), literal(3.5)})}), + AssertExecute(project({call("add", {field_ref("a"), literal(3.5)})}, {"a + 3.5"}), ArrayFromJSON(struct_({field("a", float64())}), R"([ {"a": 6.125}, {"a": 0.0}, @@ -285,18 +377,51 @@ TEST(Expression2, ExecuteCall) { ])")); } +TEST(Expression2, ExecuteDictionaryTransparent) { + AssertExecute( + equal(field_ref("a"), field_ref("b")), + ArrayFromJSON( + struct_({field("a", dictionary(int32(), utf8())), field("b", utf8())}), R"([ + {"a": "hi", "b": "hi"}, + {"a": "", "b": ""}, + {"a": "hi", "b": "hello"} + ])")); + + Datum dict_set = ArrayFromJSON(dictionary(int32(), utf8()), R"(["a"])"); + AssertExecute(call("is_in", {field_ref("a")}, + compute::SetLookupOptions{dict_set, + /*skip_nulls=*/false}), + ArrayFromJSON(struct_({field("a", utf8())}), R"([ + {"a": "a"}, + {"a": "good"}, + {"a": null} + ])")); +} + struct { void operator()(Expression2 expr, Expression2 expected) { - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(kBoringSchema)); - ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto actual, FoldConstants(expr, state.get())); - EXPECT_EQ(actual, expected); - if (actual == expr) { + this->operator()(expr, expected, + [](Expression2::BoundWithState, Expression2::BoundWithState, + Expression2::BoundWithState) {}); + } + + template + void operator()(Expression2 expr, Expression2 unbound_expected, + const ExtraExpectations& expect) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(bound)); + + EXPECT_EQ(folded.first, expected.first); + + if (folded.first == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual, expr)); + EXPECT_TRUE(Identical(folded.first, expr)); } + + 
expect(bound, folded, expected); } + } ExpectFoldsTo; TEST(Expression2, FoldConstants) { @@ -345,6 +470,21 @@ TEST(Expression2, FoldConstants) { }), literal(2), })); + + compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]"), true); + ExpectFoldsTo( + call("is_in", + {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})}, + in_123), + call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123), + [](Expression2::BoundWithState bound, Expression2::BoundWithState folded, + Expression2::BoundWithState) { + const compute::KernelState* state = bound.second->Get(bound.first); + const compute::KernelState* folded_state = folded.second->Get(folded.first); + EXPECT_EQ(folded_state, state) << "The kernel state associated with is_in (the " + "hash table for looking up membership) " + "must be associated with the folded is_in call"; + }); } TEST(Expression2, FoldConstantsBoolean) { @@ -356,113 +496,113 @@ TEST(Expression2, FoldConstantsBoolean) { auto true_ = literal(true); auto false_ = literal(false); - ExpectFoldsTo(call("and_kleene", {false_, whatever}), false_); - ExpectFoldsTo(call("and_kleene", {true_, whatever}), whatever); + ExpectFoldsTo(and_(false_, whatever), false_); + ExpectFoldsTo(and_(true_, whatever), whatever); - ExpectFoldsTo(call("or_kleene", {true_, whatever}), true_); - ExpectFoldsTo(call("or_kleene", {false_, whatever}), whatever); + ExpectFoldsTo(or_(true_, whatever), true_); + ExpectFoldsTo(or_(false_, whatever), whatever); } TEST(Expression2, ExtractKnownFieldValues) { struct { void operator()(Expression2 guarantee, std::unordered_map expected) { + ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), + guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); - EXPECT_THAT(actual, UnorderedElementsAreArray(expected)); + EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) + << " guarantee: " << guarantee.ToString(); } } ExpectKnown; - 
ExpectKnown(call("equal", {field_ref("a"), literal(3)}), {{"a", Datum(3)}}); + ExpectKnown(equal(field_ref("i32"), literal(3)), {{"i32", Datum(3)}}); - ExpectKnown(call("greater", {field_ref("a"), literal(3)}), {}); + ExpectKnown(greater(field_ref("i32"), literal(3)), {}); // FIXME known null should be expressed with is_null rather than equality auto null_int32 = std::make_shared(); - ExpectKnown(call("equal", {field_ref("a"), literal(null_int32)}), - {{"a", Datum(null_int32)}}); - - ExpectKnown(call("and_kleene", {call("equal", {field_ref("a"), literal(3)}), - call("equal", {literal(1), field_ref("b")})}), - {{"a", Datum(3)}, {"b", Datum(1)}}); + ExpectKnown(equal(field_ref("i32"), literal(null_int32)), {{"i32", Datum(null_int32)}}); - ExpectKnown(call("and_kleene", - {call("equal", {field_ref("a"), literal(3)}), - call("and_kleene", {call("equal", {field_ref("b"), literal(2)}), - call("equal", {literal(1), field_ref("c")})})}), - {{"a", Datum(3)}, {"b", Datum(2)}, {"c", Datum(1)}}); + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}), + {{"i32", Datum(3)}, {"f32", Datum(1.5F)}}); - ExpectKnown(call("and_kleene", - {call("or_kleene", {call("equal", {field_ref("a"), literal(3)}), - call("equal", {field_ref("a"), literal(4)})}), - call("equal", {literal(1), field_ref("b")})}), - {{"b", Datum(1)}}); + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(literal(2.F), field_ref("f32")), + equal(literal(1), field_ref("i32_req"))}), + {{"i32", Datum(3)}, {"f32", Datum(2.F)}, {"i32_req", Datum(1)}}); ExpectKnown( - call("and_kleene", - {call("equal", {field_ref("a"), literal(3)}), - call("and_kleene", {call("and_kleene", {field_ref("b"), field_ref("d")}), - call("equal", {literal(1), field_ref("c")})})}), - {{"a", Datum(3)}, {"c", Datum(1)}}); + and_(or_(equal(field_ref("i32"), literal(3)), equal(field_ref("i32"), literal(4))), + equal(literal(2.F), field_ref("f32"))), + {{"f32", Datum(2.F)}}); + + 
ExpectKnown(and_({equal(field_ref("i32"), literal(3)), + equal(field_ref("f32"), field_ref("f32_req")), + equal(literal(1), field_ref("i32_req"))}), + {{"i32", Datum(3)}, {"i32_req", Datum(1)}}); } TEST(Expression2, ReplaceFieldsWithKnownValues) { - auto ExpectSimplifiesTo = + auto ExpectReplacesTo = [](Expression2 expr, std::unordered_map known_values, - Expression2 expected) { - ASSERT_OK_AND_ASSIGN(auto actual, - ReplaceFieldsWithKnownValues(known_values, expr)); - EXPECT_EQ(actual, expected); + Expression2 unbound_expected) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto replaced, + ReplaceFieldsWithKnownValues(known_values, bound)); + + EXPECT_EQ(replaced.first, expected.first); - if (actual == expr) { + if (replaced.first == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual, expr)); + EXPECT_TRUE(Identical(replaced.first, expr)); } }; - std::unordered_map a_is_3{{"a", Datum(3)}}; - - ExpectSimplifiesTo(literal(1), a_is_3, literal(1)); - - ExpectSimplifiesTo(field_ref("a"), a_is_3, literal(3)); - - ExpectSimplifiesTo(field_ref("b"), a_is_3, field_ref("b")); - - ExpectSimplifiesTo(call("equal", {field_ref("a"), literal(1)}), a_is_3, - call("equal", {literal(3), literal(1)})); - - ExpectSimplifiesTo(call("add", - { - call("subtract", - { - field_ref("a"), - call("multiply", {literal(2), literal(3)}), - }), - literal(2), - }), - a_is_3, - call("add", { - call("subtract", - { - literal(3), - call("multiply", {literal(2), literal(3)}), - }), - literal(2), - })); + std::unordered_map i32_is_3{{"i32", Datum(3)}}; + + ExpectReplacesTo(literal(1), i32_is_3, literal(1)); + + ExpectReplacesTo(field_ref("i32"), i32_is_3, literal(3)); + + ExpectReplacesTo(field_ref("b"), i32_is_3, field_ref("b")); + + ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3, + equal(literal(3), literal(1))); + + ExpectReplacesTo(call("add", 
+ { + call("subtract", + { + field_ref("i32"), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + }), + i32_is_3, + call("add", { + call("subtract", + { + literal(3), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + })); } struct { - void operator()(Expression2 expr, Expression2 expected) const { - ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(expr)); - EXPECT_EQ(actual, expected); + void operator()(Expression2 expr, Expression2 unbound_expected) const { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); - if (expr.IsBound()) { - EXPECT_TRUE(actual.IsBound()); - } + EXPECT_EQ(actual.first, expected.first); - if (actual == expr) { + if (actual.first == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual, expr)); + EXPECT_TRUE(Identical(actual.first, expr)); } } } ExpectCanonicalizesTo; @@ -472,8 +612,8 @@ TEST(Expression2, CanonicalizeTrivial) { ExpectCanonicalizesTo(field_ref("b"), field_ref("b")); - ExpectCanonicalizesTo(call("equal", {field_ref("a"), field_ref("b")}), - call("equal", {field_ref("a"), field_ref("b")})); + ExpectCanonicalizesTo(equal(field_ref("i32"), field_ref("i32_req")), + equal(field_ref("i32"), field_ref("i32_req"))); } TEST(Expression2, CanonicalizeAnd) { @@ -481,95 +621,79 @@ TEST(Expression2, CanonicalizeAnd) { auto true_ = literal(true); auto null_ = literal(std::make_shared()); - auto a = field_ref("a"); + auto b = field_ref("bool"); auto c = call("equal", {literal(1), literal(2)}); - auto and_ = [](Expression2 l, Expression2 r) { return call("and_kleene", {l, r}); }; - // no change possible: - ExpectCanonicalizesTo(and_(a, c), and_(a, c)); + ExpectCanonicalizesTo(and_(b, c), and_(b, c)); // literals are placed innermost - ExpectCanonicalizesTo(and_(a, true_), and_(true_, a)); - ExpectCanonicalizesTo(and_(true_, a), and_(true_, a)); + 
ExpectCanonicalizesTo(and_(b, true_), and_(true_, b)); + ExpectCanonicalizesTo(and_(true_, b), and_(true_, b)); - ExpectCanonicalizesTo(and_(a, and_(true_, c)), and_(and_(true_, a), c)); - ExpectCanonicalizesTo(and_(a, and_(and_(true_, true_), c)), - and_(and_(and_(true_, true_), a), c)); - ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), c)), - and_(and_(and_(null_, true_), a), c)); - ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), and_(c, null_))), - and_(and_(and_(and_(null_, null_), true_), a), c)); + ExpectCanonicalizesTo(and_(b, and_(true_, c)), and_(and_(true_, b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, true_), c)), + and_(and_(and_(true_, true_), b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), c)), + and_(and_(and_(null_, true_), b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), and_(c, null_))), + and_(and_(and_(and_(null_, null_), true_), b), c)); // catches and_kleene even when it's a subexpression - ExpectCanonicalizesTo(call("is_valid", {and_(a, true_)}), - call("is_valid", {and_(true_, a)})); + ExpectCanonicalizesTo(call("is_valid", {and_(b, true_)}), + call("is_valid", {and_(true_, b)})); } TEST(Expression2, CanonicalizeComparison) { - // some aliases for brevity: - auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); }; - auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); }; - auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); }; + ExpectCanonicalizesTo(equal(literal(1), field_ref("i32")), + equal(field_ref("i32"), literal(1))); - ExpectCanonicalizesTo(equal(literal(1), field_ref("a")), - equal(field_ref("a"), literal(1))); + ExpectCanonicalizesTo(equal(field_ref("i32"), literal(1)), + equal(field_ref("i32"), literal(1))); - ExpectCanonicalizesTo(equal(field_ref("a"), literal(1)), - equal(field_ref("a"), literal(1))); + ExpectCanonicalizesTo(less(literal(1), field_ref("i32")), + greater(field_ref("i32"), 
literal(1))); - ExpectCanonicalizesTo(less(literal(1), field_ref("a")), - greater(field_ref("a"), literal(1))); - - ExpectCanonicalizesTo(less(field_ref("a"), literal(1)), - less(field_ref("a"), literal(1))); + ExpectCanonicalizesTo(less(field_ref("i32"), literal(1)), + less(field_ref("i32"), literal(1))); } struct Simplify { - Expression2 filter; + Expression2 expr; struct Expectable { - Expression2 filter, guarantee; + Expression2 expr, guarantee; + + void Expect(Expression2 unbound_expected) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), + guarantee.Bind(*kBoringSchema)); - void Expect(Expression2 expected) { - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); - ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto simplified, - SimplifyWithGuarantee(filter, state.get(), guarantee)); - EXPECT_EQ(simplified, expected) << " original: " << filter.ToString() << "\n" - << " guarantee: " << guarantee.ToString() << "\n" - << (simplified == filter ? " (no change)\n" : ""); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + EXPECT_EQ(simplified.first, expected.first) + << " original: " << expr.ToString() << "\n" + << " guarantee: " << guarantee.ToString() << "\n" + << (simplified.first == bound.first ? 
" (no change)\n" : ""); - if (simplified == filter) { - EXPECT_TRUE(Identical(simplified, filter)); + if (simplified.first == bound.first) { + EXPECT_TRUE(Identical(simplified.first, bound.first)); } } - void ExpectUnchanged() { Expect(filter); } + void ExpectUnchanged() { Expect(expr); } void Expect(bool constant) { Expect(literal(constant)); } }; - Expectable WithGuarantee(Expression2 guarantee) { return {filter, guarantee}; } + Expectable WithGuarantee(Expression2 guarantee) { return {expr, guarantee}; } }; TEST(Expression2, SingleComparisonGuarantees) { - // some aliases for brevity: - auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); }; - auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); }; - auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); }; - auto not_equal = [](Expression2 l, Expression2 r) { return call("not_equal", {l, r}); }; - auto less_equal = [](Expression2 l, Expression2 r) { - return call("less_equal", {l, r}); - }; - auto greater_equal = [](Expression2 l, Expression2 r) { - return call("greater_equal", {l, r}); - }; auto i32 = field_ref("i32"); // i32 is guaranteed equal to 3, so the projection can just materialize that constant // and need not incur IO - Simplify{project({"i32 + 1"}, {call("add", {i32, literal(1)})})} + Simplify{project({call("add", {i32, literal(1)})}, {"i32 + 1"})} .WithGuarantee(equal(i32, literal(3))) .Expect(literal( std::make_shared(ScalarVector{std::make_shared(4)}, @@ -658,7 +782,7 @@ TEST(Expression2, SingleComparisonGuarantees) { {"i32"})); std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(filter, state.get(), input)); @@ -681,48 +805,135 @@ TEST(Expression2, SingleComparisonGuarantees) { TEST(Expression2, SimplifyWithGuarantee) { // drop both 
members of a conjunctive filter - Simplify{call("and_kleene", - { - call("equal", {field_ref("i32"), literal(2)}), - call("equal", {field_ref("f32"), literal(3.5F)}), - })} - .WithGuarantee(call("and_kleene", - { - call("greater_equal", {field_ref("i32"), literal(0)}), - call("less_equal", {field_ref("i32"), literal(1)}), - })) + Simplify{ + and_(equal(field_ref("i32"), literal(2)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), + less_equal(field_ref("i32"), literal(1)))) .Expect(false); // drop one member of a conjunctive filter - Simplify{call("and_kleene", - { - call("equal", {field_ref("i32"), literal(0)}), - call("equal", {field_ref("f32"), literal(3.5F)}), - })} - .WithGuarantee(call("equal", {field_ref("i32"), literal(0)})) - .Expect(call("equal", {field_ref("f32"), literal(3.5F)})); + Simplify{ + and_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(equal(field_ref("i32"), literal(0))) + .Expect(equal(field_ref("f32"), literal(3.5F))); // drop both members of a disjunctive filter - Simplify{call("or_kleene", - { - call("equal", {field_ref("i32"), literal(0)}), - call("equal", {field_ref("f32"), literal(3.5F)}), - })} - .WithGuarantee(call("equal", {field_ref("i32"), literal(0)})) + Simplify{ + or_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(equal(field_ref("i32"), literal(0))) .Expect(true); // drop one member of a disjunctive filter - Simplify{call("or_kleene", - { - call("equal", {field_ref("i32"), literal(0)}), - call("equal", {field_ref("i32"), literal(3)}), - })} - .WithGuarantee(call("and_kleene", - { - call("greater_equal", {field_ref("i32"), literal(0)}), - call("less_equal", {field_ref("i32"), literal(1)}), - })) - .Expect(call("equal", {field_ref("i32"), literal(0)})); + Simplify{or_(equal(field_ref("i32"), literal(0)), equal(field_ref("i32"), literal(3)))} + 
.WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), + less_equal(field_ref("i32"), literal(1)))) + .Expect(equal(field_ref("i32"), literal(0))); +} + +TEST(Expression2, SimplifyThenExecute) { + auto filter = + and_({equal(field_ref("f32"), literal(0.F)), + call("is_in", {field_ref("i32")}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})}); + + ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto guarantee, + equal(field_ref("f32"), literal(0.F)).Bind(*kBoringSchema)); + + ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee.first)); + + auto input = RecordBatchFromJSON(kBoringSchema, R"([ + {"i32": 0, "f32": -0.1}, + {"i32": 0, "f32": 0.3}, + {"i32": 1, "f32": 0.2}, + {"i32": 2, "f32": -0.1}, + {"i32": 0, "f32": 0.1}, + {"i32": 0, "f32": null}, + {"i32": 0, "f32": 1.0} + ])"); + ASSERT_OK_AND_ASSIGN(Datum evaluated, + ExecuteScalarExpression(bound.first, bound.second.get(), input)); +} + +TEST(Expression2, Filter) { + auto ExpectFilter = [](Expression2 filter, std::string batch_json) { + ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean()))); + auto batch = RecordBatchFromJSON(s, batch_json); + auto expected_mask = batch->column(0); + + std::shared_ptr state; + ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, state.get(), batch)); + + AssertDatumsEqual(expected_mask, mask); + }; + + ExpectFilter(equal(field_ref("i32"), literal(0)), R"([ + {"i32": 0, "f32": -0.1, "in": 1}, + {"i32": 0, "f32": 0.3, "in": 1}, + {"i32": 1, "f32": 0.2, "in": 0}, + {"i32": 2, "f32": -0.1, "in": 0}, + {"i32": 0, "f32": 0.1, "in": 1}, + {"i32": 0, "f32": null, "in": 1}, + {"i32": 0, "f32": 1.0, "in": 1} + ])"); +} + +TEST(Expression2, SerializationRoundTrips) { + auto ExpectRoundTrips = [](const Expression2& expr) { + ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(expr)); + 
ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, Deserialize(*serialized)); + EXPECT_EQ(expr, roundtripped); + }; + + ExpectRoundTrips(literal(MakeNullScalar(null()))); + + ExpectRoundTrips(literal(MakeNullScalar(int32()))); + + ExpectRoundTrips( + literal(MakeNullScalar(struct_({field("i", int32()), field("s", utf8())})))); + + ExpectRoundTrips(literal(true)); + + ExpectRoundTrips(literal(false)); + + ExpectRoundTrips(literal(1)); + + ExpectRoundTrips(literal(1.125)); + + ExpectRoundTrips(literal("stringy strings")); + + ExpectRoundTrips(field_ref("field")); + + ExpectRoundTrips(greater(field_ref("a"), literal(0.25))); + + ExpectRoundTrips( + or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")), + equal(field_ref("b"), literal("foo bar"))})); + + ExpectRoundTrips(not_(field_ref("alpha"))); + + compute::CastOptions to_float64; + to_float64.to_type = float64(); + ExpectRoundTrips( + call("is_in", {call("cast", {field_ref("version")}, to_float64)}, + compute::SetLookupOptions{ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]"), true})); + + ExpectRoundTrips(call("is_valid", {field_ref("validity")})); + + ExpectRoundTrips(and_({and_(greater_equal(field_ref("x"), literal(-1.5)), + less(field_ref("x"), literal(0.0))), + and_(greater_equal(field_ref("y"), literal(0.0)), + less(field_ref("y"), literal(1.5))), + and_(greater(field_ref("z"), literal(1.5)), + less_equal(field_ref("z"), literal(3.0)))})); + + ExpectRoundTrips(and_({equal(field_ref("year"), literal(int16_t(1999))), + equal(field_ref("month"), literal(int8_t(12))), + equal(field_ref("day"), literal(int8_t(31))), + equal(field_ref("hour"), literal(int8_t(0))), + equal(field_ref("alpha"), literal(int32_t(0))), + equal(field_ref("beta"), literal(3.25f))})); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index ef044ea3a03..f8350a0292c 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -57,16 
+57,16 @@ Result> FileSource::Open() const { Result> FileFormat::MakeFragment( FileSource source, std::shared_ptr physical_schema) { - return MakeFragment(std::move(source), scalar(true), std::move(physical_schema)); + return MakeFragment(std::move(source), literal(true), std::move(physical_schema)); } Result> FileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression) { + FileSource source, Expression2 partition_expression) { return MakeFragment(std::move(source), std::move(partition_expression), nullptr); } Result> FileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr( new FileFragment(std::move(source), shared_from_this(), @@ -83,7 +83,7 @@ Result FileFragment::Scan(std::shared_ptr options } FileSystemDataset::FileSystemDataset(std::shared_ptr schema, - std::shared_ptr root_partition, + Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) @@ -93,7 +93,7 @@ FileSystemDataset::FileSystemDataset(std::shared_ptr schema, fragments_(std::move(fragments)) {} Result> FileSystemDataset::Make( - std::shared_ptr schema, std::shared_ptr root_partition, + std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) { return std::shared_ptr(new FileSystemDataset( @@ -129,20 +129,23 @@ std::string FileSystemDataset::ToString() const { repr += "\n" + fragment->source().path(); const auto& partition = fragment->partition_expression(); - if (!partition->Equals(true)) { - repr += ": " + partition->ToString(); + if (partition != literal(true)) { + repr += ": " + partition.ToString(); } } return repr; } -FragmentIterator FileSystemDataset::GetFragmentsImpl( - std::shared_ptr predicate) { +Result FileSystemDataset::GetFragmentsImpl( + Expression2::BoundWithState predicate) { FragmentVector fragments; 
for (const auto& fragment : fragments_) { - if (predicate->IsSatisfiableWith(fragment->partition_expression())) { + ARROW_ASSIGN_OR_RAISE( + auto simplified, + SimplifyWithGuarantee(predicate, fragment->partition_expression())); + if (simplified.first.IsSatisfiable()) { fragments.push_back(fragment); } } @@ -273,7 +276,8 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio // // NB: neither of these will have any impact whatsoever on the common case of writing // an in-memory table to disk. - ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, scanner->GetFragments().ToVector()); + ARROW_ASSIGN_OR_RAISE(auto fragment_it, scanner->GetFragments()); + ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, fragment_it.ToVector()); ScanTaskVector scan_tasks; std::vector fragment_for_task; @@ -313,12 +317,14 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio std::unordered_set need_flushed; for (size_t i = 0; i < groups.batches.size(); ++i) { - AndExpression partition_expression(std::move(groups.expressions[i]), - fragment->partition_expression()); + ARROW_ASSIGN_OR_RAISE( + auto partition_expression, + and_(std::move(groups.expressions[i]), fragment->partition_expression()) + .Bind(*scanner->schema())); auto batch = std::move(groups.batches[i]); - ARROW_ASSIGN_OR_RAISE(auto part, - write_options.partitioning->Format(partition_expression)); + ARROW_ASSIGN_OR_RAISE( + auto part, write_options.partitioning->Format(partition_expression.first)); WriteQueue* queue; { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 192921e7cf0..1d0059ec089 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -143,11 +143,11 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema); - 
Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression); + Result> MakeFragment(FileSource source, + Expression2 partition_expression); Result> MakeFragment( FileSource source, std::shared_ptr physical_schema = NULLPTR); @@ -173,8 +173,7 @@ class ARROW_DS_EXPORT FileFragment : public Fragment { protected: FileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, - std::shared_ptr physical_schema) + Expression2 partition_expression, std::shared_ptr physical_schema) : Fragment(std::move(partition_expression), std::move(physical_schema)), source_(std::move(source)), format_(std::move(format)) {} @@ -207,7 +206,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { /// /// \return A constructed dataset. static Result> Make( - std::shared_ptr schema, std::shared_ptr root_partition, + std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); @@ -234,10 +233,10 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { std::string ToString() const; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl( + Expression2::BoundWithState predicate) override; - FileSystemDataset(std::shared_ptr schema, - std::shared_ptr root_partition, + FileSystemDataset(std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index f889b12fddd..6f1ca9acbbb 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -34,6 +34,7 @@ #include "arrow/result.h" #include "arrow/type.h" #include "arrow/util/iterator.h" +#include "arrow/util/logging.h" namespace arrow { namespace dataset { @@ -90,13 +91,16 @@ static inline Result GetConvertOptions( } // FIXME(bkietz) also acquire types of fields materialized but not 
projected. - for (auto&& name : FieldsInExpression(scan_options->filter)) { - ARROW_ASSIGN_OR_RAISE(auto match, - FieldRef(name).FindOneOrNone(*scan_options->schema())); - if (match.indices().empty()) { - convert_options.include_columns.push_back(std::move(name)); + // This requires that scan_options include the full dataset schema (not just + // the projected schema). + for (const FieldRef& ref : FieldsInExpression(scan_options->filter2)) { + DCHECK(ref.name()); + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*scan_options->schema())); + if (!match) { + convert_options.include_columns.push_back(*ref.name()); } } + return convert_options; } diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index 25bbe559cf4..0fb76803778 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -244,7 +244,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + opts_->filter2 = equal(field_ref("i32"), literal(0)); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -280,7 +280,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + opts_->filter2 = equal(field_ref("i32"), literal(0)); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index e06cee649ef..f1e9ad48bb0 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ 
b/cpp/src/arrow/dataset/file_parquet.cc @@ -125,7 +125,7 @@ static Result> GetSchemaManifest( return manifest; } -static std::shared_ptr ColumnChunkStatisticsAsExpression( +static util::optional ColumnChunkStatisticsAsExpression( const SchemaField& schema_field, const parquet::RowGroupMetaData& metadata) { // For the remaining of this function, failure to extract/parse statistics // are ignored by returning nullptr. The goal is two fold. First @@ -134,13 +134,13 @@ static std::shared_ptr ColumnChunkStatisticsAsExpression( // For now, only leaf (primitive) types are supported. if (!schema_field.is_leaf()) { - return nullptr; + return util::nullopt; } auto column_metadata = metadata.ColumnChunk(schema_field.column_index); auto statistics = column_metadata->statistics(); if (statistics == nullptr) { - return nullptr; + return util::nullopt; } const auto& field = schema_field.field; @@ -148,12 +148,12 @@ static std::shared_ptr ColumnChunkStatisticsAsExpression( // Optimize for corner case where all values are nulls if (statistics->num_values() == statistics->null_count()) { - return equal(std::move(field_expr), scalar(MakeNullScalar(field->type()))); + return equal(std::move(field_expr), literal(MakeNullScalar(field->type()))); } std::shared_ptr min, max; if (!StatisticsAsScalars(*statistics, &min, &max).ok()) { - return nullptr; + return util::nullopt; } auto maybe_min = min->CastTo(field->type()); @@ -161,11 +161,11 @@ static std::shared_ptr ColumnChunkStatisticsAsExpression( if (maybe_min.ok() && maybe_max.ok()) { min = maybe_min.MoveValueUnsafe(); max = maybe_max.MoveValueUnsafe(); - return and_(greater_equal(field_expr, scalar(min)), - less_equal(field_expr, scalar(max))); + return and_(greater_equal(field_expr, literal(min)), + less_equal(field_expr, literal(max))); } - return nullptr; + return util::nullopt; } static void AddColumnIndices(const SchemaField& schema_field, @@ -291,18 +291,18 @@ Result ParquetFileFormat::ScanFile(std::shared_ptr row_groups; bool 
pre_filtered = false; - auto empty = [] { return MakeEmptyIterator>(); }; + auto MakeEmpty = [] { return MakeEmptyIterator>(); }; // If RowGroup metadata is cached completely we can pre-filter RowGroups before opening // a FileReader, potentially avoiding IO altogether if all RowGroups are excluded due to // prior statistics knowledge. In the case where a RowGroup doesn't have statistics // metdata, it will not be excluded. if (parquet_fragment->metadata() != nullptr) { - ARROW_ASSIGN_OR_RAISE(row_groups, - parquet_fragment->FilterRowGroups(*options->filter)); + ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups( + {options->filter2, context->expression_state})); pre_filtered = true; - if (row_groups.empty()) empty(); + if (row_groups.empty()) MakeEmpty(); } // Open the reader and pay the real IO cost. @@ -314,10 +314,10 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrFilterRowGroups(*options->filter)); + ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups( + {options->filter2, context->expression_state})); - if (row_groups.empty()) empty(); + if (row_groups.empty()) MakeEmpty(); } auto column_projection = InferColumnProjection(*reader, *options); @@ -332,7 +332,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptr> ParquetFileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema, std::vector row_groups) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -340,7 +340,7 @@ Result> ParquetFileFormat::MakeFragment( } Result> ParquetFileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -395,7 +395,7 @@ 
Status ParquetFileWriter::Finish() { return parquet_writer_->Close(); } ParquetFileFragment::ParquetFileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, + Expression2 partition_expression, std::shared_ptr physical_schema, util::optional> row_groups) : FileFragment(std::move(source), std::move(format), std::move(partition_expression), @@ -442,7 +442,7 @@ Status ParquetFileFragment::SetMetadata( metadata_ = std::move(metadata); manifest_ = std::move(manifest); - statistics_expressions_.resize(row_groups_->size(), scalar(true)); + statistics_expressions_.resize(row_groups_->size(), literal(true)); statistics_expressions_complete_.resize(physical_schema_->num_fields(), false); for (int row_group : *row_groups_) { @@ -458,9 +458,9 @@ Status ParquetFileFragment::SetMetadata( } Result ParquetFileFragment::SplitByRowGroup( - const std::shared_ptr& predicate) { + Expression2::BoundWithState predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); - ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(*predicate)); + ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); FragmentVector fragments(row_groups.size()); int i = 0; @@ -477,9 +477,9 @@ Result ParquetFileFragment::SplitByRowGroup( } Result> ParquetFileFragment::Subset( - const std::shared_ptr& predicate) { + Expression2::BoundWithState predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); - ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(*predicate)); + ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); return Subset(std::move(row_groups)); } @@ -494,8 +494,8 @@ Result> ParquetFileFragment::Subset( return new_fragment; } -inline void FoldingAnd(std::shared_ptr* l, std::shared_ptr r) { - if ((*l)->Equals(true)) { +inline void FoldingAnd(Expression2* l, Expression2 r) { + if (*l == literal(true)) { *l = std::move(r); } else { *l = and_(std::move(*l), std::move(r)); @@ -503,13 +503,18 @@ inline void FoldingAnd(std::shared_ptr* l, 
std::shared_ptr> ParquetFileFragment::FilterRowGroups( - const Expression& predicate) { + Expression2::BoundWithState predicate) { auto lock = physical_schema_mutex_.Lock(); DCHECK_NE(metadata_, nullptr); - RETURN_NOT_OK(predicate.Validate(*physical_schema_)); + ARROW_ASSIGN_OR_RAISE( + predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); + + if (!predicate.first.IsSatisfiable()) { + return std::vector{}; + } - for (FieldRef ref : FieldsInExpression(predicate)) { + for (const FieldRef& ref : FieldsInExpression(predicate.first)) { ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOneOrNone(*physical_schema_)); if (!path) continue; @@ -523,21 +528,20 @@ Result> ParquetFileFragment::FilterRowGroups( if (auto minmax = ColumnChunkStatisticsAsExpression(schema_field, *row_group_metadata)) { - FoldingAnd(&statistics_expressions_[i], std::move(minmax)); + FoldingAnd(&statistics_expressions_[i], std::move(*minmax)); + ARROW_ASSIGN_OR_RAISE(std::tie(statistics_expressions_[i], std::ignore), + statistics_expressions_[i].Bind(*physical_schema_)); } ++i; } } - auto simplified_predicate = predicate.Assume(partition_expression_); - if (!simplified_predicate->IsSatisfiable()) { - return std::vector{}; - } - std::vector row_groups; for (size_t i = 0; i < row_groups_->size(); ++i) { - if (simplified_predicate->IsSatisfiableWith(statistics_expressions_[i])) { + ARROW_ASSIGN_OR_RAISE(auto row_group_predicate, + SimplifyWithGuarantee(predicate, statistics_expressions_[i])); + if (row_group_predicate.first.IsSatisfiable()) { row_groups.push_back(row_groups_->at(i)); } } @@ -661,7 +665,7 @@ ParquetDatasetFactory::CollectParquetFragments(const Partitioning& partitioning) auto partition_expression = partitioning.Parse(StripPrefixAndFilename(path, options_.partition_base_dir)) - .ValueOr(scalar(true)); + .ValueOr(literal(true)); ARROW_ASSIGN_OR_RAISE( auto fragment, @@ -712,7 +716,7 @@ Result> ParquetDatasetFactory::Finish(FinishOptions opt } ARROW_ASSIGN_OR_RAISE(auto 
fragments, CollectParquetFragments(*partitioning)); - return FileSystemDataset::Make(std::move(schema), scalar(true), format_, filesystem_, + return FileSystemDataset::Make(std::move(schema), literal(true), format_, filesystem_, std::move(fragments)); } diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index b82241a574b..9de19186ad3 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -117,12 +117,12 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// \brief Create a Fragment targeting all RowGroups. Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema) override; /// \brief Create a Fragment, restricted to the specified row groups. Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema, std::vector row_groups); /// \brief Return a FileReader on the given source. @@ -150,7 +150,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// significant performance boost when scanning high latency file systems. class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { public: - Result SplitByRowGroup(const std::shared_ptr& predicate); + Result SplitByRowGroup(Expression2::BoundWithState predicate); /// \brief Return the RowGroups selected by this fragment. const std::vector& row_groups() const { @@ -166,12 +166,12 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR); /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups. 
- Result> Subset(const std::shared_ptr& predicate); + Result> Subset(Expression2::BoundWithState predicate); Result> Subset(std::vector row_group_ids); private: ParquetFileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, + Expression2 partition_expression, std::shared_ptr physical_schema, util::optional> row_groups); @@ -185,7 +185,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { } // Return a filtered subset of row group indices. - Result> FilterRowGroups(const Expression& predicate); + Result> FilterRowGroups(Expression2::BoundWithState predicate); ParquetFileFormat& parquet_format_; @@ -193,7 +193,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { // or util::nullopt if all row groups are selected. util::optional> row_groups_; - ExpressionVector statistics_expressions_; + std::vector statistics_expressions_; std::vector statistics_expressions_complete_; std::shared_ptr metadata_; std::shared_ptr manifest_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index c2e0d6b632d..5a6bbddaba3 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -160,6 +160,11 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { return Batches(std::move(scan_task_it)); } + void SetFilter(Expression2 filter) { + ASSERT_OK_AND_ASSIGN(std::tie(opts_->filter2, ctx_->expression_state), + filter.Bind(*schema_)); + } + std::shared_ptr SingleBatch(Fragment* fragment) { auto batches = IteratorToVector(Batches(fragment)); EXPECT_EQ(batches.size(), 1); @@ -188,9 +193,12 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { void CountRowGroupsInFragment(const std::shared_ptr& fragment, std::vector expected_row_groups, - const Expression& filter) { + Expression2 filter) { + schema_ = opts_->schema(); + ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); + ctx_->expression_state = 
bound.second; auto parquet_fragment = checked_pointer_cast(fragment); - ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(filter.Copy())) + ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(bound)) EXPECT_EQ(fragments.size(), expected_row_groups.size()); for (size_t i = 0; i < fragments.size(); i++) { @@ -214,6 +222,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReader) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + SetFilter(literal(true)); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); int64_t row_count = 0; @@ -233,6 +242,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderDictEncoded) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + SetFilter(literal(true)); format_->reader_options.dict_columns = {"utf8"}; ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); @@ -275,7 +285,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + SetFilter(equal(field_ref("i32"), scalar(0))); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -311,7 +321,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + SetFilter(equal(field_ref("i32"), scalar(0))); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { @@ -404,34 +414,35 @@ TEST_F(TestParquetFileFormat, PredicatePushdown) { auto source = GetFileSource(reader.get()); opts_ = 
ScanOptions::Make(reader->schema()); + schema_ = reader->schema(); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); - opts_->filter = scalar(true); + SetFilter(literal(true)); CountRowsAndBatchesInScan(fragment, kTotalNumRows, kNumRowGroups); for (int64_t i = 1; i <= kNumRowGroups; i++) { - opts_->filter = ("i64"_ == int64_t(i)).Copy(); + SetFilter(("i64"_ == int64_t(i))); CountRowsAndBatchesInScan(fragment, i, 1); } // Out of bound filters should skip all RowGroups. - opts_->filter = scalar(false); + SetFilter(literal(false)); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(kNumRowGroups + 1)).Copy(); + SetFilter(("i64"_ == int64_t(kNumRowGroups + 1))); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(-1)).Copy(); + SetFilter(("i64"_ == int64_t(-1))); CountRowsAndBatchesInScan(fragment, 0, 0); // No rows match 1 and 2. - opts_->filter = ("i64"_ == int64_t(1) and "u8"_ == uint8_t(2)).Copy(); + SetFilter(("i64"_ == int64_t(1) and "u8"_ == uint8_t(2))); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(2) or "i64"_ == int64_t(4)).Copy(); + SetFilter(("i64"_ == int64_t(2) or "i64"_ == int64_t(4))); CountRowsAndBatchesInScan(fragment, 2 + 4, 2); - opts_->filter = ("i64"_ < int64_t(6)).Copy(); + SetFilter(("i64"_ < int64_t(6))); CountRowsAndBatchesInScan(fragment, 5 * (5 + 1) / 2, 5); - opts_->filter = ("i64"_ >= int64_t(6)).Copy(); + SetFilter(("i64"_ >= int64_t(6))); CountRowsAndBatchesInScan(fragment, kTotalNumRows - (5 * (5 + 1) / 2), kNumRowGroups - 5); } @@ -446,15 +457,17 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragments) { ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); auto all_row_groups = internal::Iota(static_cast(kNumRowGroups)); - CountRowGroupsInFragment(fragment, all_row_groups, *scalar(true)); - CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); + CountRowGroupsInFragment(fragment, 
all_row_groups, literal(true)); + + // FIXME this is only meaningful if "not here" is a virtual column + // CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); for (int i = 0; i < kNumRowGroups; ++i) { CountRowGroupsInFragment(fragment, {i}, "i64"_ == int64_t(i + 1)); } // Out of bound filters should skip all RowGroups. - CountRowGroupsInFragment(fragment, {}, *scalar(false)); + CountRowGroupsInFragment(fragment, {}, literal(false)); CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(kNumRowGroups + 1)); CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(-1)); @@ -503,10 +516,12 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + schema_ = reader->schema(); + SetFilter(literal(true)); auto row_groups_fragment = [&](std::vector row_groups) { EXPECT_OK_AND_ASSIGN(auto fragment, - format_->MakeFragment(*source, scalar(true), + format_->MakeFragment(*source, literal(true), /*physical_schema=*/nullptr, row_groups)); return fragment; }; @@ -514,7 +529,7 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { // select all row groups EXPECT_OK_AND_ASSIGN( auto all_row_groups_fragment, - format_->MakeFragment(*source, scalar(true)) + format_->MakeFragment(*source, literal(true)) .Map([](std::shared_ptr f) { return internal::checked_pointer_cast(f); })); @@ -532,17 +547,17 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { for (int i = 0; i < kNumRowGroups; ++i) { // conflicting selection/filter - opts_->filter = ("i64"_ == int64_t(i)).Copy(); + SetFilter(("i64"_ == int64_t(i))); CountRowsAndBatchesInScan(row_groups_fragment({i}), 0, 0); } for (int i = 0; i < kNumRowGroups; ++i) { // identical selection/filter - opts_->filter = ("i64"_ == int64_t(i + 1)).Copy(); + SetFilter(("i64"_ == int64_t(i + 1))); CountRowsAndBatchesInScan(row_groups_fragment({i}), i + 1, 1); } - opts_->filter = ("i64"_ > int64_t(3)).Copy(); + 
SetFilter(("i64"_ > int64_t(3))); CountRowsAndBatchesInScan(row_groups_fragment({2, 3, 4, 5}), 4 + 5 + 6, 3); EXPECT_RAISES_WITH_MESSAGE_THAT( diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index 8ee2613576f..c390f777ad4 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -82,15 +82,15 @@ TEST(FileSource, BufferBased) { TEST_F(TestFileSystemDataset, Basic) { MakeDataset({}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {}); MakeDataset({fs::File("a"), fs::File("b"), fs::File("c")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"a", "b", "c"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b", "c"}); AssertFilesAre(dataset_, {"a", "b", "c"}); // Should not create fragment from directories. MakeDataset({fs::Dir("A"), fs::Dir("A/B"), fs::File("A/a"), fs::File("A/B/b")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"A/a", "A/B/b"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"A/a", "A/B/b"}); AssertFilesAre(dataset_, {"A/a", "A/B/b"}); } @@ -98,7 +98,7 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { auto schm = schema({field("i32", int32()), field("f64", float64())}); auto format = std::make_shared(schm); ASSERT_OK_AND_ASSIGN(auto dataset, - FileSystemDataset::Make(schm, scalar(true), format, nullptr, {})); + FileSystemDataset::Make(schm, literal(true), format, nullptr, {})); // drop field ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status()); @@ -119,45 +119,64 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { } TEST_F(TestFileSystemDataset, RootPartitionPruning) { - auto root_partition = ("a"_ == 5).Copy(); + auto root_partition = equal(field_ref("i32"), literal(5)); MakeDataset({fs::File("a"), fs::File("b")}, root_partition); + auto GetFragments = [&](Expression2 filter) { + return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); + 
}; + // Default filter should always return all data. - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"a", "b"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b"}); // filter == partition - AssertFragmentsAreFromPath(dataset_->GetFragments(root_partition), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(root_partition), {"a", "b"}); // Same partition key, but non matching filter - AssertFragmentsAreFromPath(dataset_->GetFragments(("a"_ == 6).Copy()), {}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("i32"), literal(6))), {}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("a"_ > 1).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(greater(field_ref("i32"), literal(1))), + {"a", "b"}); // different key shouldn't prune - AssertFragmentsAreFromPath(dataset_->GetFragments(("b"_ == 6).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), + {"a", "b"}); // No partition should match MakeDataset({fs::File("a"), fs::File("b")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("b"_ == 6).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), + {"a", "b"}); } TEST_F(TestFileSystemDataset, TreePartitionPruning) { - auto root_partition = ("country"_ == "US").Copy(); + auto root_partition = equal(field_ref("country"), literal("US")); + std::vector regions = { fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"), fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - ExpressionVector partitions = { - ("state"_ == "NY").Copy(), - ("state"_ == "NY" and "city"_ == "New York").Copy(), - ("state"_ == "NY" and "city"_ == "Franklin").Copy(), - ("state"_ == "CA").Copy(), - ("state"_ == "CA" and "city"_ == "San Francisco").Copy(), - ("state"_ == "CA" and "city"_ == "Franklin").Copy(), + std::vector partitions = { + equal(field_ref("state"), literal("NY")), + + equal(field_ref("state"), 
literal("NY")) and + equal(field_ref("city"), literal("New York")), + + equal(field_ref("state"), literal("NY")) and + equal(field_ref("city"), literal("Franklin")), + + equal(field_ref("state"), literal("CA")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("San Francisco")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("Franklin")), }; - MakeDataset(regions, root_partition, partitions); + MakeDataset( + regions, root_partition, partitions, + schema({field("country", utf8()), field("state", utf8()), field("city", utf8())})); std::vector all_cities = {"CA/San Francisco", "CA/Franklin", "NY/New York", "NY/Franklin"}; @@ -165,52 +184,67 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { std::vector franklins = {"CA/Franklin", "NY/Franklin"}; // Default filter should always return all data. - AssertFragmentsAreFromPath(dataset_->GetFragments(), all_cities); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), all_cities); + + auto GetFragments = [&](Expression2 filter) { + return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); + }; // Dataset's partitions are respected - AssertFragmentsAreFromPath(dataset_->GetFragments(("country"_ == "US").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("US"))), all_cities); - AssertFragmentsAreFromPath(dataset_->GetFragments(("country"_ == "FR").Copy()), {}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("FR"))), + {}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("state"_ == "CA").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("state"), literal("CA"))), ca_cities); // Filter where no decisions can be made on inner nodes when filter don't // apply to inner partitions. 
- AssertFragmentsAreFromPath(dataset_->GetFragments(("city"_ == "Franklin").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("city"), literal("Franklin"))), franklins); } TEST_F(TestFileSystemDataset, FragmentPartitions) { - auto root_partition = ("country"_ == "US").Copy(); + auto root_partition = equal(field_ref("country"), literal("US")); std::vector regions = { fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"), fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - ExpressionVector partitions = { - ("state"_ == "NY").Copy(), - ("state"_ == "NY" and "city"_ == "New York").Copy(), - ("state"_ == "NY" and "city"_ == "Franklin").Copy(), - ("state"_ == "CA").Copy(), - ("state"_ == "CA" and "city"_ == "San Francisco").Copy(), - ("state"_ == "CA" and "city"_ == "Franklin").Copy(), - }; + std::vector partitions = { + equal(field_ref("state"), literal("NY")), + + equal(field_ref("state"), literal("NY")) and + equal(field_ref("city"), literal("New York")), - MakeDataset(regions, root_partition, partitions); + equal(field_ref("state"), literal("NY")) and + equal(field_ref("city"), literal("Franklin")), - auto with_root = [&](const Expression& state, const Expression& city) { - return and_(state.Copy(), city.Copy()); + equal(field_ref("state"), literal("CA")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("San Francisco")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("Franklin")), }; + MakeDataset( + regions, root_partition, partitions, + schema({field("country", utf8()), field("state", utf8()), field("city", utf8())})); + AssertFragmentsHavePartitionExpressions( - dataset_->GetFragments(), - { - with_root("state"_ == "CA", "city"_ == "San Francisco"), - with_root("state"_ == "CA", "city"_ == "Franklin"), - with_root("state"_ == "NY", "city"_ == "New York"), - with_root("state"_ == "NY", "city"_ == "Franklin"), - }); + dataset_, { + 
and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), + }); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index caab28e3f7c..e231737bc5f 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -29,6 +29,7 @@ #include "arrow/buffer.h" #include "arrow/compute/api.h" #include "arrow/dataset/dataset.h" +#include "arrow/dataset/expression.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" @@ -808,38 +809,6 @@ std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } -std::shared_ptr and_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr and_(const ExpressionVector& subexpressions) { - auto acc = scalar(true); - for (const auto& next : subexpressions) { - if (next->Equals(false)) return next; - acc = acc->Equals(true) ? next : and_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr or_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr or_(const ExpressionVector& subexpressions) { - auto acc = scalar(false); - for (const auto& next : subexpressions) { - if (next->Equals(true)) return next; - acc = acc->Equals(false) ? 
next : or_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr not_(std::shared_ptr operand) { - return std::make_shared(std::move(operand)); -} - AndExpression operator&&(const Expression& lhs, const Expression& rhs) { return AndExpression(lhs.Copy(), rhs.Copy()); } @@ -848,8 +817,6 @@ OrExpression operator||(const Expression& lhs, const Expression& rhs) { return OrExpression(lhs.Copy(), rhs.Copy()); } -NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } - CastExpression Expression::CastTo(std::shared_ptr type, compute::CastOptions options) const { return CastExpression(Copy(), type, std::move(options)); @@ -890,8 +857,8 @@ Status EnsureNullOrBool(const std::string& msg_prefix, return Status::TypeError(msg_prefix, *type); } -Result> ValidateBoolean(const ExpressionVector& operands, - const Schema& schema) { +Result> ValidateBoolean( + const std::vector>& operands, const Schema& schema) { for (const auto& operand : operands) { ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); RETURN_NOT_OK( @@ -1476,7 +1443,7 @@ struct DeserializeImpl { switch (expression_type) { case ExpressionType::FIELD: { ARROW_ASSIGN_OR_RAISE(auto name, GetView(struct_array, 0)); - return field_ref(std::string(name)); + return std::make_shared(std::string(name)); } case ExpressionType::SCALAR: { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index c6b55b419ff..750c67c048d 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -464,24 +464,6 @@ class ARROW_DS_EXPORT CustomExpression : public Expression { CustomExpression() : Expression(ExpressionType::CUSTOM) {} }; -ARROW_DS_EXPORT std::shared_ptr and_(std::shared_ptr lhs, - std::shared_ptr rhs); - -ARROW_DS_EXPORT std::shared_ptr and_(const ExpressionVector& subexpressions); - -ARROW_DS_EXPORT AndExpression operator&&(const Expression& lhs, const Expression& rhs); - -ARROW_DS_EXPORT std::shared_ptr or_(std::shared_ptr lhs, - 
std::shared_ptr rhs); - -ARROW_DS_EXPORT std::shared_ptr or_(const ExpressionVector& subexpressions); - -ARROW_DS_EXPORT OrExpression operator||(const Expression& lhs, const Expression& rhs); - -ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); - -ARROW_DS_EXPORT NotExpression operator!(const Expression& rhs); - inline std::shared_ptr scalar(std::shared_ptr value) { return std::make_shared(std::move(value)); } @@ -491,43 +473,6 @@ auto scalar(T&& value) -> decltype(scalar(MakeScalar(std::forward(value)))) { return scalar(MakeScalar(std::forward(value))); } -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - inline std::shared_ptr FACTORY_NAME( \ - const std::shared_ptr& lhs, const std::shared_ptr& rhs) { \ - return std::make_shared(CompareOperator::NAME, lhs, rhs); \ - } \ - \ - template ::type>::value>::type> \ - ComparisonExpression operator OP(const Expression& lhs, T&& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), \ - scalar(std::forward(rhs))); \ - } \ - \ - inline ComparisonExpression operator OP(const Expression& lhs, \ - const Expression& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), rhs.Copy()); \ - } -COMPARISON_FACTORY(EQUAL, equal, ==) -COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) -COMPARISON_FACTORY(GREATER, greater, >) -COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) -COMPARISON_FACTORY(LESS, less, <) -COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) -#undef COMPARISON_FACTORY - -inline std::shared_ptr field_ref(std::string name) { - return std::make_shared(std::move(name)); -} - -inline namespace string_literals { -// clang-format off -inline FieldExpression operator"" _(const char* name, size_t name_length) { - // clang-format on - return FieldExpression({name, name_length}); -} -} // namespace string_literals - template bool Expression::Equals(T&& t) const { if (type_ != ExpressionType::SCALAR) { diff --git a/cpp/src/arrow/dataset/filter_test.cc 
b/cpp/src/arrow/dataset/filter_test.cc index 7723912eeab..fce2d9cfcd9 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -45,30 +45,55 @@ using string_literals::operator"" _; using internal::checked_cast; using internal::checked_pointer_cast; +// A frozen shared_ptr with behavior expected by GTest +struct TestExpression : util::EqualityComparable, + util::ToStringOstreamable { + // NOLINTNEXTLINE runtime/explicit + TestExpression(std::shared_ptr e) : expression(std::move(e)) {} + + // NOLINTNEXTLINE runtime/explicit + TestExpression(const Expression& e) : expression(e.Copy()) {} + + // NOLINTNEXTLINE runtime/explicit + TestExpression(const Expression2& e) : expression(e) {} + + std::shared_ptr expression; + + using util::EqualityComparable::operator==; + bool Equals(const TestExpression& other) const { + return expression->Equals(other.expression); + } + + std::string ToString() const { return expression->ToString(); } + + friend bool operator==(const std::shared_ptr& lhs, + const TestExpression& rhs) { + return TestExpression(lhs) == rhs; + } + + friend void PrintTo(const TestExpression& expr, std::ostream* os) { + *os << expr.ToString(); + } +}; + using E = TestExpression; class ExpressionsTest : public ::testing::Test { public: - void AssertSimplifiesTo(const Expression& expr, const Expression& given, - const Expression& expected) { - ASSERT_OK_AND_ASSIGN(auto expr_type, expr.Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto given_type, given.Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto expected_type, expected.Validate(*schema_)); + void AssertSimplifiesTo(E expr, E given, E expected) { + ASSERT_OK_AND_ASSIGN(auto expr_type, expr.expression->Validate(*schema_)); + ASSERT_OK_AND_ASSIGN(auto given_type, given.expression->Validate(*schema_)); + ASSERT_OK_AND_ASSIGN(auto expected_type, expected.expression->Validate(*schema_)); EXPECT_TRUE(expr_type->Equals(expected_type)); EXPECT_TRUE(given_type->Equals(boolean())); - 
auto simplified = expr.Assume(given); - ASSERT_EQ(E{simplified}, E{expected}) + auto simplified = expr.expression->Assume(given.expression); + ASSERT_EQ(simplified, expected) << " simplification of: " << expr.ToString() << std::endl << " given: " << given.ToString() << std::endl; } - void AssertSimplifiesTo(const Expression& expr, const Expression& given, - const std::shared_ptr& expected) { - AssertSimplifiesTo(expr, given, *expected); - } - std::shared_ptr ns = timestamp(TimeUnit::NANO); std::shared_ptr schema_ = schema({field("a", int32()), field("b", int32()), field("f", float64()), @@ -78,16 +103,6 @@ class ExpressionsTest : public ::testing::Test { std::shared_ptr never = scalar(false); }; -TEST_F(ExpressionsTest, StringRepresentation) { - ASSERT_EQ("a"_.ToString(), "a"); - ASSERT_EQ(("a"_ > int32_t(3)).ToString(), "(a > 3:int32)"); - ASSERT_EQ(("a"_ > int32_t(3) and "a"_ < int32_t(4)).ToString(), - "((a > 3:int32) and (a < 4:int32))"); - ASSERT_EQ(("f"_ > double(4)).ToString(), "(f > 4:double)"); - ASSERT_EQ("f"_.CastTo(float64()).ToString(), "(cast f to double)"); - ASSERT_EQ("f"_.CastLike("a"_).ToString(), "(cast f like a)"); -} - TEST_F(ExpressionsTest, Equality) { ASSERT_EQ(E{"a"_}, E{"a"_}); ASSERT_NE(E{"a"_}, E{"b"_}); @@ -114,7 +129,7 @@ TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); - AssertSimplifiesTo("f"_ > 0.5 and "f"_ < 1.5, not("f"_ < 0.0 or "f"_ > 1.0), + AssertSimplifiesTo("f"_ > 0.5 and "f"_ < 1.5, not_("f"_ < 0.0 or "f"_ > 1.0), "f"_ > 0.5); AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); @@ -125,7 +140,7 @@ TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 3, "a"_ == 3); AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 0, *never); - AssertSimplifiesTo("a"_ == 0 or not"b"_.IsValid(), "b"_ == 3, "a"_ == 0); + 
AssertSimplifiesTo("a"_ == 0 or not_("b"_.IsValid()), "b"_ == 3, "a"_ == 0); } TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { @@ -144,20 +159,6 @@ TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { *always); } -TEST_F(ExpressionsTest, SimplificationToNull) { - auto null = scalar(std::make_shared()); - auto null32 = scalar(std::make_shared()); - - AssertSimplifiesTo(*equal(field_ref("b"), null32), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32), "b"_ == 3, *null); - - // Kleene logic applies here - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 3, "b"_ == 3, *never); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 2, "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) or "b"_ > 3, "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) or "b"_ > 2, "b"_ == 3, *always); -} - class FilterTest : public ::testing::Test { public: FilterTest() { evaluator_ = std::make_shared(); } @@ -444,165 +445,6 @@ TEST_F(FilterTest, KleeneTruthTables) { ])"); } -class TakeExpression : public CustomExpression { - public: - TakeExpression(std::shared_ptr operand, std::shared_ptr dictionary) - : operand_(std::move(operand)), dictionary_(std::move(dictionary)) {} - - std::string ToString() const override { - return dictionary_->ToString() + "[" + operand_->ToString() + "]"; - } - - std::shared_ptr Copy() const override { - return std::make_shared(*this); - } - - bool Equals(const Expression& other) const override { - // in a real CustomExpression this would need to be more sophisticated - return other.type() == ExpressionType::CUSTOM && ToString() == other.ToString(); - } - - Result> Validate(const Schema& schema) const override { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - if (!is_integer(operand_type->id())) { - return Status::TypeError("Take indices must be integral, not ", *operand_type); - } - return 
dictionary_->type(); - } - - class Evaluator : public TreeEvaluator { - public: - using TreeEvaluator::TreeEvaluator; - - using TreeEvaluator::Evaluate; - - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override { - if (expr.type() == ExpressionType::CUSTOM) { - const auto& take_expr = checked_cast(expr); - return EvaluateTake(take_expr, batch, pool); - } - return TreeEvaluator::Evaluate(expr, batch, pool); - } - - Result EvaluateTake(const TakeExpression& take_expr, const RecordBatch& batch, - MemoryPool* pool) const { - ARROW_ASSIGN_OR_RAISE(auto indices, Evaluate(*take_expr.operand_, batch, pool)); - - if (indices.kind() == Datum::SCALAR) { - ARROW_ASSIGN_OR_RAISE(auto indices_array, - MakeArrayFromScalar(*indices.scalar(), batch.num_rows(), - default_memory_pool())); - indices = Datum(indices_array->data()); - } - - DCHECK_EQ(indices.kind(), Datum::ARRAY); - compute::ExecContext ctx(pool); - ARROW_ASSIGN_OR_RAISE(Datum out, - compute::Take(take_expr.dictionary_->data(), indices, - compute::TakeOptions(), &ctx)); - return std::move(out); - } - }; - - private: - std::shared_ptr operand_; - std::shared_ptr dictionary_; -}; - -TEST_F(ExpressionsTest, TakeAssumeYieldsNothing) { - auto dict = ArrayFromJSON(float64(), "[0.0, 0.25, 0.5, 0.75, 1.0]"); - auto take_b_is_half = (TakeExpression(field_ref("b"), dict) == 0.5); - - // no special Assume logic was provided for TakeExpression so we should just ignore it - // (logically the below *could* be simplified to false but we haven't implemented that) - AssertSimplifiesTo(take_b_is_half, "b"_ == 3, take_b_is_half); - - // custom expressions will not interfere with simplification of other subexpressions and - // can be dropped if other subexpressions simplify trivially - - // ("b"_ > 5).Assume("b"_ == 3) simplifies to false regardless of take, so the and will - // be false - AssertSimplifiesTo("b"_ > 5 and take_b_is_half, "b"_ == 3, *never); - - // ("b"_ > 5).Assume("b"_ == 6) 
simplifies to true regardless of take, so it can be - // dropped - AssertSimplifiesTo("b"_ > 5 and take_b_is_half, "b"_ == 6, take_b_is_half); - - // ("b"_ > 5).Assume("b"_ == 6) simplifies to true regardless of take, so the or will be - // true - AssertSimplifiesTo(take_b_is_half or "b"_ > 5, "b"_ == 6, *always); - - // ("b"_ > 5).Assume("b"_ == 3) simplifies to true regardless of take, so it can be - // dropped - AssertSimplifiesTo(take_b_is_half or "b"_ > 5, "b"_ == 3, take_b_is_half); -} - -TEST_F(FilterTest, EvaluateTakeExpression) { - evaluator_ = std::make_shared(); - - auto dict = ArrayFromJSON(float64(), "[0.0, 0.25, 0.5, 0.75, 1.0]"); - - AssertFilter(TakeExpression(field_ref("b"), dict) == 0.5, - {field("b", int32()), field("f", float64())}, R"([ - {"b": 3, "f": -0.1, "in": 0}, - {"b": 2, "f": 0.3, "in": 1}, - {"b": 1, "f": 0.2, "in": 0}, - {"b": 2, "f": -0.1, "in": 1}, - {"b": 4, "f": 0.1, "in": 0}, - {"b": null, "f": 0.0, "in": null}, - {"b": 0, "f": 1.0, "in": 0} - ])"); -} - -void AssertFieldsInExpression(std::shared_ptr expr, - std::vector expected) { - EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected)); -} - -TEST(FieldsInExpressionTest, Basic) { - AssertFieldsInExpression(scalar(true), {}); - - AssertFieldsInExpression(("a"_).Copy(), {"a"}); - AssertFieldsInExpression(("a"_ == 1).Copy(), {"a"}); - AssertFieldsInExpression(("a"_ == "b"_).Copy(), {"a", "b"}); - - AssertFieldsInExpression(("a"_ == 1 || "a"_ == 2).Copy(), {"a", "a"}); - AssertFieldsInExpression(("a"_ == 1 || "b"_ == 2).Copy(), {"a", "b"}); - AssertFieldsInExpression((not("a"_ == 1) && ("b"_ == 2 || not("c"_ < 3))).Copy(), - {"a", "b", "c"}); -} - -TEST(ExpressionSerializationTest, RoundTrips) { - std::vector exprs{ - scalar(MakeNullScalar(null())), - scalar(MakeNullScalar(int32())), - scalar(MakeNullScalar(struct_({field("i", int32()), field("s", utf8())}))), - scalar(true), - scalar(false), - scalar(1), - scalar(1.125), - scalar("stringy strings"), - "field"_, - 
"a"_ > 0.25, - "a"_ == 1 or "b"_ != "hello" or "b"_ == "foo bar", - not"alpha"_, - "valid"_ and "a"_.CastLike("b"_) >= "b"_, - "version"_.CastTo(float64()).In(ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]")), - "validity"_.IsValid(), - ("x"_ >= -1.5 and "x"_ < 0.0) and ("y"_ >= 0.0 and "y"_ < 1.5) and - ("z"_ > 1.5 and "z"_ <= 3.0), - "year"_ == int16_t(1999) and "month"_ == int8_t(12) and "day"_ == int8_t(31) and - "hour"_ == int8_t(0) and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, - }; - - for (const auto& expr : exprs) { - ASSERT_OK_AND_ASSIGN(auto serialized, expr.expression->Serialize()); - ASSERT_OK_AND_ASSIGN(E roundtripped, Expression::Deserialize(*serialized)); - ASSERT_EQ(expr, roundtripped); - } -} - void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, const std::string& expected_json) { FieldVector fields_with_ids = by_fields; diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index a39a8adaae0..b6178ddeea6 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -49,81 +49,47 @@ std::shared_ptr Partitioning::Default() { std::string type_name() const override { return "default"; } - Result> Parse(const std::string& path) const override { - return scalar(true); + Result Parse(const std::string& path) const override { + return literal(true); } - Result Format(const Expression& expr) const override { + Result Format(const Expression2& expr) const override { return Status::NotImplemented("formatting paths from ", type_name(), " Partitioning"); } Result Partition( const std::shared_ptr& batch) const override { - return PartitionedBatches{{batch}, {scalar(true)}}; + return PartitionedBatches{{batch}, {literal(true)}}; } }; return std::make_shared(); } -Status KeyValuePartitioning::VisitKeys( - const Expression& expr, - const std::function& value)>& visitor) { - return VisitConjunctionMembers(expr, [visitor](const Expression& expr) { - if (expr.type() != 
ExpressionType::COMPARISON) { - return Status::OK(); - } - - const auto& cmp = checked_cast(expr); - if (cmp.op() != compute::CompareOperator::EQUAL) { - return Status::OK(); - } - - auto lhs = cmp.left_operand().get(); - auto rhs = cmp.right_operand().get(); - if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs); - - if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) { - return Status::OK(); +Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr, + RecordBatchProjector* projector) { + ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); + for (const auto& ref_value : known_values) { + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); } - return visitor(checked_cast(lhs)->name(), - checked_cast(rhs)->value()); - }); -} + ARROW_ASSIGN_OR_RAISE(auto match, + ref_value.first.FindOneOrNone(*projector->schema())); -Result>> -KeyValuePartitioning::GetKeys(const Expression& expr) { - std::unordered_map> keys; - RETURN_NOT_OK( - VisitKeys(expr, [&](const std::string& name, const std::shared_ptr& value) { - keys.emplace(name, value); - return Status::OK(); - })); - return keys; -} - -Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr, - RecordBatchProjector* projector) { - return KeyValuePartitioning::VisitKeys( - expr, [projector](const std::string& name, const std::shared_ptr& value) { - ARROW_ASSIGN_OR_RAISE(auto match, - FieldRef(name).FindOneOrNone(*projector->schema())); - if (!match) { - return Status::OK(); - } - return projector->SetDefaultValue(match, value); - }); + if (!match) continue; + RETURN_NOT_OK(projector->SetDefaultValue(match, ref_value.second.scalar())); + } + return Status::OK(); } inline std::shared_ptr ConjunctionFromGroupingRow(Scalar* row) { ScalarVector* values = &checked_cast(row)->value; - ExpressionVector equality_expressions(values->size()); + std::vector 
equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { const std::string& name = row->type->field(static_cast(i))->name(); - equality_expressions[i] = equal(field_ref(name), scalar(std::move(values->at(i)))); + equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); } return and_(std::move(equality_expressions)); } @@ -147,7 +113,7 @@ Result KeyValuePartitioning::Partition( if (by_fields.empty()) { // no fields to group by; return the whole batch - return PartitionedBatches{{batch}, {scalar(true)}}; + return PartitionedBatches{{batch}, {literal(true)}}; } ARROW_ASSIGN_OR_RAISE(auto by, @@ -168,11 +134,10 @@ Result KeyValuePartitioning::Partition( return out; } -Result> KeyValuePartitioning::ConvertKey( - const Key& key) const { +Result KeyValuePartitioning::ConvertKey(const Key& key) const { ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_)); if (!match) { - return scalar(true); + return literal(true); } auto field_index = match[0]; @@ -209,41 +174,49 @@ Result> KeyValuePartitioning::ConvertKey( ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), key.value)); } - return equal(field_ref(field->name()), scalar(std::move(converted))); + return equal(field_ref(field->name()), literal(std::move(converted))); } -Result> KeyValuePartitioning::Parse( - const std::string& path) const { - ExpressionVector expressions; +Result KeyValuePartitioning::Parse(const std::string& path) const { + std::vector expressions; for (const Key& key : ParseKeys(path)) { ARROW_ASSIGN_OR_RAISE(auto expr, ConvertKey(key)); - - if (expr->Equals(true)) continue; - + if (expr == literal(true)) continue; expressions.push_back(std::move(expr)); } - return and_(std::move(expressions)); + auto parsed = and_(std::move(expressions)); + ARROW_ASSIGN_OR_RAISE(std::tie(parsed, std::ignore), parsed.Bind(*schema_)) + return parsed; } -Result KeyValuePartitioning::Format(const Expression& expr) const { +Result 
KeyValuePartitioning::Format(const Expression2& expr) const { + if (!expr.IsBound()) { + return Status::Invalid("formatted partition expressions must be bound"); + } + std::vector values{static_cast(schema_->num_fields()), nullptr}; - RETURN_NOT_OK(VisitKeys(expr, [&](const std::string& name, - const std::shared_ptr& value) { - ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(name).FindOneOrNone(*schema_)); - if (match) { - const auto& field = schema_->field(match[0]); - if (!value->type->Equals(field->type())) { - return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, - ") is invalid for ", field->ToString()); - } + ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); + for (const auto& ref_value : known_values) { + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); + } + + ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_)); + if (!match) continue; - values[match[0]] = value.get(); + const auto& value = ref_value.second.scalar(); + + const auto& field = schema_->field(match[0]); + if (!value->type->Equals(field->type())) { + return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, + ") is invalid for ", field->ToString()); } - return Status::OK(); - })); + + values[match[0]] = value.get(); + } return FormatValues(values); } diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index 165fcfb5248..1bad8956b65 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -26,7 +26,7 @@ #include #include -#include "arrow/dataset/filter.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/util/optional.h" @@ -63,15 +63,15 @@ class ARROW_DS_EXPORT Partitioning { /// produce sub-batches which satisfy mutually exclusive Expressions. 
struct PartitionedBatches { RecordBatchVector batches; - ExpressionVector expressions; + std::vector expressions; }; virtual Result Partition( const std::shared_ptr& batch) const = 0; /// \brief Parse a path into a partition expression - virtual Result> Parse(const std::string& path) const = 0; + virtual Result Parse(const std::string& path) const = 0; - virtual Result Format(const Expression& expr) const = 0; + virtual Result Format(const Expression2& expr) const = 0; /// \brief A default Partitioning which always yields scalar(true) static std::shared_ptr Default(); @@ -122,23 +122,15 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { std::string name, value; }; - static Status VisitKeys( - const Expression& expr, - const std::function& value)>& visitor); - - static Result>> GetKeys( - const Expression& expr); - - static Status SetDefaultValuesFromKeys(const Expression& expr, + static Status SetDefaultValuesFromKeys(const Expression2& expr, RecordBatchProjector* projector); Result Partition( const std::shared_ptr& batch) const override; - Result> Parse(const std::string& path) const override; + Result Parse(const std::string& path) const override; - Result Format(const Expression& expr) const override; + Result Format(const Expression2& expr) const override; protected: KeyValuePartitioning(std::shared_ptr schema, ArrayVector dictionaries) @@ -153,7 +145,7 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { virtual Result FormatValues(const std::vector& values) const = 0; /// Convert a Key to a full expression. 
- Result> ConvertKey(const Key& key) const; + Result ConvertKey(const Key& key) const; ArrayVector dictionaries_; }; @@ -215,10 +207,9 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { /// \brief Implementation provided by lambda or other callable class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { public: - using ParseImpl = - std::function>(const std::string&)>; + using ParseImpl = std::function(const std::string&)>; - using FormatImpl = std::function(const Expression&)>; + using FormatImpl = std::function(const Expression2&)>; FunctionPartitioning(std::shared_ptr schema, ParseImpl parse_impl, FormatImpl format_impl = NULLPTR, std::string name = "function") @@ -229,11 +220,11 @@ class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { std::string type_name() const override { return name_; } - Result> Parse(const std::string& path) const override { + Result Parse(const std::string& path) const override { return parse_impl_(path); } - Result Format(const Expression& expr) const override { + Result Format(const Expression2& expr) const override { if (format_impl_) { return format_impl_(expr); } diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index b47f5ba9926..37e65cacac4 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -37,32 +37,40 @@ using internal::checked_pointer_cast; namespace dataset { -using E = TestExpression; - class TestPartitioning : public ::testing::Test { public: void AssertParseError(const std::string& path) { ASSERT_RAISES(Invalid, partitioning_->Parse(path)); } - void AssertParse(const std::string& path, E expected) { + void AssertParse(const std::string& path, Expression2 expected) { ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path)); - ASSERT_EQ(E{parsed}, expected); + ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), + expected.Bind(*partitioning_->schema())); + ASSERT_EQ(parsed, 
expected); } template - void AssertFormatError(E expr) { - ASSERT_EQ(partitioning_->Format(*expr.expression).status().code(), code); + void AssertFormatError(Expression2 expr) { + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } - void AssertFormat(E expr, const std::string& expected) { - ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(*expr.expression)); + void AssertFormat(Expression2 expr, const std::string& expected) { + // formatted partition expressions are bound to the schema of the dataset being + // written + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + + ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr)); ASSERT_EQ(formatted, expected); // ensure the formatted path round trips the relevant components of the partition // expression: roundtripped should be a subset of expr - ASSERT_OK_AND_ASSIGN(auto roundtripped, partitioning_->Parse(formatted)); - ASSERT_EQ(E{roundtripped->Assume(*expr.expression)}, E{scalar(true)}); + ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, partitioning_->Parse(formatted)); + + ASSERT_OK_AND_ASSIGN(auto bound, roundtripped.Bind(*partitioning_->schema())); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, expr)); + ASSERT_EQ(simplified.first, literal(true)); } void AssertInspect(const std::vector& paths, @@ -95,6 +103,7 @@ class TestPartitioning : public ::testing::Test { std::shared_ptr partitioning_; std::shared_ptr factory_; + std::shared_ptr written_schema_; }; TEST_F(TestPartitioning, DirectoryPartitioning) { @@ -103,10 +112,10 @@ TEST_F(TestPartitioning, DirectoryPartitioning) { AssertParse("/0/hello", "alpha"_ == int32_t(0) and "beta"_ == "hello"); AssertParse("/3", "alpha"_ == int32_t(3)); - AssertParseError("/world/0"); // reversed order - AssertParseError("/0.0/foo"); // invalid alpha - AssertParseError("/3.25"); // invalid alpha with missing 
beta - AssertParse("", scalar(true)); // no segments to parse + AssertParseError("/world/0"); // reversed order + AssertParseError("/0.0/foo"); // invalid alpha + AssertParseError("/3.25"); // invalid alpha with missing beta + AssertParse("", literal(true)); // no segments to parse // gotcha someday: AssertParse("/0/dat.parquet", "alpha"_ == int32_t(0) and "beta"_ == "dat.parquet"); @@ -118,15 +127,22 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormat) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); + written_schema_ = partitioning_->schema(); + AssertFormat("alpha"_ == int32_t(0) and "beta"_ == "hello", "0/hello"); AssertFormat("beta"_ == "hello" and "alpha"_ == int32_t(0), "0/hello"); AssertFormat("alpha"_ == int32_t(0), "0"); AssertFormatError("beta"_ == "hello"); - AssertFormat(scalar(true), ""); + AssertFormat(literal(true), ""); - AssertFormatError("alpha"_ == 0.0 and "beta"_ == "hello"); + ASSERT_OK_AND_ASSIGN(written_schema_, + written_schema_->AddField(0, field("gamma", utf8()))); AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == "hello", "0/hello"); + + // written_schema_ is incompatible with partitioning_'s schema + written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); + AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); } TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { @@ -216,7 +232,7 @@ TEST_F(TestPartitioning, HivePartitioning) { AssertParse("/beta=3.25/alpha=0", "beta"_ == 3.25f and "alpha"_ == int32_t(0)); AssertParse("/alpha=0", "alpha"_ == int32_t(0)); AssertParse("/beta=3.25", "beta"_ == 3.25f); - AssertParse("", scalar(true)); + AssertParse("", literal(true)); AssertParse("/alpha=0/unexpected/beta=3.25", "alpha"_ == int32_t(0) and "beta"_ == 3.25f); @@ -224,7 +240,7 @@ TEST_F(TestPartitioning, HivePartitioning) { AssertParse("/alpha=0/beta=3.25/ignored=2341", "alpha"_ == int32_t(0) and "beta"_ == 3.25f); - 
AssertParse("/ignored=2341", scalar(true)); + AssertParse("/ignored=2341", literal(true)); AssertParseError("/alpha=0.0/beta=3.25"); // conversion of "0.0" to int32 fails } @@ -233,15 +249,22 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", float32())})); + written_schema_ = partitioning_->schema(); + AssertFormat("alpha"_ == int32_t(0) and "beta"_ == 3.25f, "alpha=0/beta=3.25"); AssertFormat("beta"_ == 3.25f and "alpha"_ == int32_t(0), "alpha=0/beta=3.25"); AssertFormat("alpha"_ == int32_t(0), "alpha=0"); AssertFormat("beta"_ == 3.25f, "alpha/beta=3.25"); - AssertFormat(scalar(true), ""); + AssertFormat(literal(true), ""); - AssertFormatError("alpha"_ == "yo" and "beta"_ == 3.25f); + ASSERT_OK_AND_ASSIGN(written_schema_, + written_schema_->AddField(0, field("gamma", utf8()))); AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, "alpha=0/beta=3.25"); + + // written_schema_ is incompatible with partitioning_'s schema + written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); + AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); } TEST_F(TestPartitioning, DiscoverHiveSchema) { @@ -322,10 +345,12 @@ TEST_F(TestPartitioning, EtlThenHive) { FieldVector alphabeta_fields{field("alpha", int32()), field("beta", float32())}; HivePartitioning alphabeta_part(schema(alphabeta_fields)); - partitioning_ = std::make_shared( + auto schm = schema({field("year", int16()), field("month", int8()), field("day", int8()), - field("hour", int8()), field("alpha", int32()), field("beta", float32())}), - [&](const std::string& path) -> Result> { + field("hour", int8()), field("alpha", int32()), field("beta", float32())}); + + partitioning_ = std::make_shared( + schm, [&](const std::string& path) -> Result { auto segments = fs::internal::SplitAbstractPath(path); if (segments.size() < etl_fields.size() + alphabeta_fields.size()) { return 
Status::Invalid("path ", path, " can't be parsed"); @@ -341,7 +366,9 @@ TEST_F(TestPartitioning, EtlThenHive) { fs::internal::JoinAbstractPath(etl_segments_end, alphabeta_segments_end); ARROW_ASSIGN_OR_RAISE(auto alphabeta_expr, alphabeta_part.Parse(alphabeta_path)); - return and_(etl_expr, alphabeta_expr); + auto expr = and_(etl_expr, alphabeta_expr); + ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), expr.Bind(*schm)); + return expr; }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", @@ -350,7 +377,7 @@ TEST_F(TestPartitioning, EtlThenHive) { ("alpha"_ == int32_t(0) and "beta"_ == 3.25f)); AssertParseError("/20X6/03/21/05/alpha=0/beta=3.25"); -} // namespace dataset +} TEST_F(TestPartitioning, Set) { auto ints = [](std::vector ints) { @@ -359,12 +386,13 @@ TEST_F(TestPartitioning, Set) { return out; }; + auto schm = schema({field("x", int32())}); + // An adhoc partitioning which parses segments like "/x in [1 4 5]" // into ("x"_ == 1 or "x"_ == 4 or "x"_ == 5) partitioning_ = std::make_shared( - schema({field("x", int32())}), - [&](const std::string& path) -> Result> { - ExpressionVector subexpressions; + schm, [&](const std::string& path) -> Result { + std::vector subexpressions; for (auto segment : fs::internal::SplitAbstractPath(path)) { std::smatch matches; @@ -380,7 +408,13 @@ TEST_F(TestPartitioning, Set) { set.push_back(checked_cast(*s).value); } - subexpressions.push_back(field_ref(matches[1])->In(ints(set)).Copy()); + auto is_in_expr = call("is_in", {field_ref(matches[1])}, + compute::SetLookupOptions{ints(set), true}); + + ARROW_ASSIGN_OR_RAISE(std::tie(is_in_expr, std::ignore), + is_in_expr.Bind(*schm)); + + subexpressions.push_back(is_in_expr); } return and_(std::move(subexpressions)); }); @@ -398,8 +432,8 @@ class RangePartitioning : public Partitioning { std::string type_name() const override { return "range"; } - Result> Parse(const std::string& path) const override { - ExpressionVector ranges; + Result Parse(const std::string& path) const 
override { + std::vector ranges; for (auto segment : fs::internal::SplitAbstractPath(path)) { auto key = HivePartitioning::ParseKey(segment); @@ -419,10 +453,13 @@ class RangePartitioning : public Partitioning { ARROW_ASSIGN_OR_RAISE(auto min, Scalar::Parse(type, min_repr)); ARROW_ASSIGN_OR_RAISE(auto max, Scalar::Parse(type, max_repr)); - ranges.push_back(and_(min_cmp(field_ref(key->name), scalar(min)), - max_cmp(field_ref(key->name), scalar(max)))); + ranges.push_back(and_(min_cmp(field_ref(key->name), literal(min)), + max_cmp(field_ref(key->name), literal(max)))); } - return and_(ranges); + + Expression2 expr; + ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), and_(ranges).Bind(*schema_)); + return expr; } static Status DoRegex(const std::string& segment, std::smatch* matches) { @@ -442,7 +479,7 @@ class RangePartitioning : public Partitioning { return Status::OK(); } - Result Format(const Expression&) const override { return ""; } + Result Format(const Expression2&) const override { return ""; } Result Partition( const std::shared_ptr&) const override { return Status::OK(); diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 019416041aa..67df665733b 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -27,6 +27,7 @@ #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" #include "arrow/util/iterator.h" +#include "arrow/util/logging.h" #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" @@ -34,14 +35,12 @@ namespace arrow { namespace dataset { ScanOptions::ScanOptions(std::shared_ptr schema) - : evaluator(ExpressionEvaluator::Null()), - projector(RecordBatchProjector(std::move(schema))) {} + : projector(RecordBatchProjector(std::move(schema))) {} std::shared_ptr ScanOptions::ReplaceSchema( std::shared_ptr schema) const { auto copy = ScanOptions::Make(std::move(schema)); - copy->filter = filter; - copy->evaluator = evaluator; + copy->filter2 = filter2; 
copy->batch_size = batch_size; return copy; } @@ -53,8 +52,9 @@ std::vector ScanOptions::MaterializedFields() const { fields.push_back(f->name()); } - for (auto&& name : FieldsInExpression(filter)) { - fields.push_back(std::move(name)); + for (const FieldRef& ref : FieldsInExpression(filter2)) { + DCHECK(ref.name()); + fields.push_back(*ref.name()); } return fields; @@ -64,7 +64,7 @@ Result InMemoryScanTask::Execute() { return MakeVectorIterator(record_batches_); } -FragmentIterator Scanner::GetFragments() { +Result Scanner::GetFragments() { if (fragment_ != nullptr) { return MakeVectorIterator(FragmentVector{fragment_}); } @@ -72,14 +72,16 @@ FragmentIterator Scanner::GetFragments() { // Transform Datasets in a flat Iterator. This // iterator is lazily constructed, i.e. Dataset::GetFragments is // not invoked until a Fragment is requested. - return GetFragmentsFromDatasets({dataset_}, scan_options_->filter); + return GetFragmentsFromDatasets( + {dataset_}, {scan_options_->filter2, scan_context_->expression_state}); } Result Scanner::Scan() { // Transforms Iterator into a unified // Iterator. The first Iterator::Next invocation is going to do // all the work of unwinding the chained iterators. 
- return GetScanTaskIterator(GetFragments(), scan_options_, scan_context_); + ARROW_ASSIGN_OR_RAISE(auto fragment_it, GetFragments()); + return GetScanTaskIterator(std::move(fragment_it), scan_options_, scan_context_); } Result ScanTaskIteratorFromRecordBatch( @@ -95,15 +97,24 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr dataset, : dataset_(std::move(dataset)), fragment_(nullptr), scan_options_(ScanOptions::Make(dataset_->schema())), - scan_context_(std::move(scan_context)) {} + scan_context_(std::move(scan_context)) { + DCHECK_OK(Filter(literal(true))); +} ScannerBuilder::ScannerBuilder(std::shared_ptr schema, std::shared_ptr fragment, std::shared_ptr scan_context) : dataset_(nullptr), fragment_(std::move(fragment)), - scan_options_(ScanOptions::Make(schema)), - scan_context_(std::move(scan_context)) {} + fragment_schema_(schema), + scan_options_(ScanOptions::Make(std::move(schema))), + scan_context_(std::move(scan_context)) { + DCHECK_OK(Filter(literal(true))); +} + +const std::shared_ptr& ScannerBuilder::schema() const { + return fragment_ ? 
fragment_schema_ : dataset_->schema(); +} Status ScannerBuilder::Project(std::vector columns) { RETURN_NOT_OK(schema()->CanReferenceFieldsByNames(columns)); @@ -112,15 +123,12 @@ Status ScannerBuilder::Project(std::vector columns) { return Status::OK(); } -Status ScannerBuilder::Filter(std::shared_ptr filter) { - RETURN_NOT_OK(schema()->CanReferenceFieldsByNames(FieldsInExpression(*filter))); - RETURN_NOT_OK(filter->Validate(*schema())); - scan_options_->filter = std::move(filter); +Status ScannerBuilder::Filter(const Expression2& filter) { + ARROW_ASSIGN_OR_RAISE(std::tie(scan_options_->filter2, scan_context_->expression_state), + filter.Bind(*schema())); return Status::OK(); } -Status ScannerBuilder::Filter(const Expression& filter) { return Filter(filter.Copy()); } - Status ScannerBuilder::UseThreads(bool use_threads) { scan_context_->use_threads = use_threads; return Status::OK(); @@ -143,10 +151,6 @@ Result> ScannerBuilder::Finish() const { scan_options = std::make_shared(*scan_options_); } - if (!scan_options->filter->Equals(true)) { - scan_options->evaluator = std::make_shared(); - } - if (dataset_ == nullptr) { return std::make_shared(fragment_, std::move(scan_options), scan_context_); } diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 950d416f615..27c2e3e9ed7 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -25,6 +25,7 @@ #include #include "arrow/dataset/dataset.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/projector.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" @@ -47,6 +48,8 @@ struct ARROW_DS_EXPORT ScanContext { /// Return a threaded or serial TaskGroup according to use_threads. 
std::shared_ptr TaskGroup() const; + + std::shared_ptr expression_state; }; class ARROW_DS_EXPORT ScanOptions { @@ -62,10 +65,7 @@ class ARROW_DS_EXPORT ScanOptions { std::shared_ptr ReplaceSchema(std::shared_ptr schema) const; // Filter - std::shared_ptr filter = scalar(true); - - // Evaluator for Filter - std::shared_ptr evaluator; + Expression2 filter2 = literal(true); // Schema to which record batches will be reconciled const std::shared_ptr& schema() const { return projector.schema(); } @@ -172,7 +172,7 @@ class ARROW_DS_EXPORT Scanner { Result> ToTable(); /// \brief GetFragments returns an iterator over all Fragments in this scan. - FragmentIterator GetFragments(); + Result GetFragments(); const std::shared_ptr& schema() const { return scan_options_->schema(); } @@ -222,8 +222,7 @@ class ARROW_DS_EXPORT ScannerBuilder { /// /// \return Failure if any referenced columns does not exist in the dataset's /// Schema. - Status Filter(std::shared_ptr filter); - Status Filter(const Expression& filter); + Status Filter(const Expression2& filter); /// \brief Indicate if the Scanner should make use of the available /// ThreadPool found in ScanContext; @@ -240,11 +239,12 @@ class ARROW_DS_EXPORT ScannerBuilder { /// \brief Return the constructed now-immutable Scanner object Result> Finish() const; - std::shared_ptr schema() const { return scan_options_->schema(); } + const std::shared_ptr& schema() const; private: std::shared_ptr dataset_; std::shared_ptr fragment_; + std::shared_ptr fragment_schema_; std::shared_ptr scan_options_; std::shared_ptr scan_context_; bool has_projection_ = false; diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index 94df94470fa..f01ced8396b 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -20,6 +20,9 @@ #include #include +#include "arrow/array/array_nested.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" #include 
"arrow/dataset/dataset_internal.h" #include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" @@ -29,13 +32,27 @@ namespace arrow { namespace dataset { inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, - const ExpressionEvaluator& evaluator, - const Expression& filter, MemoryPool* pool) { + Expression2::BoundWithState filter, + MemoryPool* pool) { return MakeMaybeMapIterator( - [&filter, &evaluator, pool](std::shared_ptr in) { - return evaluator.Evaluate(filter, *in, pool).Map([&](Datum selection) { - return evaluator.Filter(selection, in); - }); + [=](std::shared_ptr in) -> Result> { + compute::ExecContext exec_context{pool}; + ARROW_ASSIGN_OR_RAISE(Datum mask, + ExecuteScalarExpression(filter.first, filter.second.get(), + Datum(in), &exec_context)); + + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + return std::move(in); + } + return in->Slice(0, 0); + } + + ARROW_ASSIGN_OR_RAISE( + Datum filtered, + compute::Filter(in, mask, compute::FilterOptions::Defaults(), &exec_context)); + return filtered.record_batch(); }, std::move(it)); } @@ -56,39 +73,41 @@ inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, class FilterAndProjectScanTask : public ScanTask { public: - explicit FilterAndProjectScanTask(std::shared_ptr task, - std::shared_ptr partition) + explicit FilterAndProjectScanTask(std::shared_ptr task, Expression2 partition) : ScanTask(task->options(), task->context()), task_(std::move(task)), partition_(std::move(partition)), - filter_(options()->filter->Assume(partition_)), + filter_(options()->filter2), projector_(options()->projector) {} Result Execute() override { ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute()); - auto filter_it = - FilterRecordBatch(std::move(it), *options_->evaluator, *filter_, context_->pool); + ARROW_ASSIGN_OR_RAISE( + Expression2::BoundWithState simplified_filter, + SimplifyWithGuarantee({filter_, 
context_->expression_state}, partition_)); + + RecordBatchIterator filter_it = + FilterRecordBatch(std::move(it), simplified_filter, context_->pool); + + RETURN_NOT_OK( + KeyValuePartitioning::SetDefaultValuesFromKeys(partition_, &projector_)); - if (partition_) { - RETURN_NOT_OK( - KeyValuePartitioning::SetDefaultValuesFromKeys(*partition_, &projector_)); - } return ProjectRecordBatch(std::move(filter_it), &projector_, context_->pool); } private: std::shared_ptr task_; - std::shared_ptr partition_; - std::shared_ptr filter_; + Expression2 partition_; + Expression2 filter_; RecordBatchProjector projector_; }; /// \brief GetScanTaskIterator transforms an Iterator in a /// flattened Iterator. -inline ScanTaskIterator GetScanTaskIterator(FragmentIterator fragments, - std::shared_ptr options, - std::shared_ptr context) { +inline Result GetScanTaskIterator( + FragmentIterator fragments, std::shared_ptr options, + std::shared_ptr context) { // Fragment -> ScanTaskIterator auto fn = [options, context](std::shared_ptr fragment) -> Result { @@ -112,42 +131,5 @@ inline ScanTaskIterator GetScanTaskIterator(FragmentIterator fragments, return MakeFlattenIterator(std::move(maybe_scantask_it)); } -struct FragmentRecordBatchReader : RecordBatchReader { - public: - std::shared_ptr schema() const override { return options_->schema(); } - - Status ReadNext(std::shared_ptr* batch) override { - return iterator_.Next().Value(batch); - } - - static Result> Make( - std::shared_ptr fragment, std::shared_ptr schema, - std::shared_ptr context) { - // ensure schema is cached in fragment - auto options = ScanOptions::Make(std::move(schema)); - RETURN_NOT_OK(KeyValuePartitioning::SetDefaultValuesFromKeys( - *fragment->partition_expression(), &options->projector)); - - auto pool = context->pool; - ARROW_ASSIGN_OR_RAISE(auto scan_tasks, fragment->Scan(options, std::move(context))); - - auto reader = std::make_shared(); - reader->options_ = std::move(options); - reader->fragment_ = 
std::move(fragment); - reader->iterator_ = ProjectRecordBatch( - MakeFlattenIterator(MakeMaybeMapIterator( - [](std::shared_ptr task) { return task->Execute(); }, - std::move(scan_tasks))), - &reader->options_->projector, pool); - - return reader; - } - - private: - std::shared_ptr options_; - std::shared_ptr fragment_; - RecordBatchIterator iterator_; -}; - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 36966760346..d79cc49faac 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -55,7 +55,7 @@ class TestScanner : public DatasetFixtureMixin { // structures of the scanner, i.e. Scanner[Dataset[ScanTask[RecordBatch]]] AssertScannerEquals(expected.get(), &scanner); } -}; // namespace dataset +}; constexpr int64_t TestScanner::kNumberChildDatasets; constexpr int64_t TestScanner::kNumberBatches; @@ -88,8 +88,7 @@ TEST_F(TestScanner, FilteredScan) { value += 1.0; })); - options_->filter = ("f64"_ > 0.0).Copy(); - options_->evaluator = std::make_shared(); + SetFilter(greater(field_ref("f64"), literal(0.0))); auto batch = RecordBatch::Make(schema_, f64->length(), {f64}); @@ -148,7 +147,7 @@ TEST_F(TestScanner, ToTable) { } class TestScannerBuilder : public ::testing::Test { - void SetUp() { + void SetUp() override { DatasetVector sources; schema_ = schema({ @@ -163,7 +162,7 @@ class TestScannerBuilder : public ::testing::Test { } protected: - std::shared_ptr ctx_; + std::shared_ptr ctx_ = std::make_shared(); std::shared_ptr schema_; std::shared_ptr dataset_; }; @@ -184,14 +183,28 @@ TEST_F(TestScannerBuilder, TestProject) { TEST_F(TestScannerBuilder, TestFilter) { ScannerBuilder builder(dataset_, ctx_); - ASSERT_OK(builder.Filter(scalar(true))); - ASSERT_OK(builder.Filter("i64"_ == int64_t(10))); - ASSERT_OK(builder.Filter("i64"_ == int64_t(10) || "b"_ == true)); - - ASSERT_RAISES(TypeError, builder.Filter("i64"_ == int32_t(10))); - 
ASSERT_RAISES(Invalid, builder.Filter("not_a_column"_ == true)); - ASSERT_RAISES(Invalid, - builder.Filter("i64"_ == int64_t(10) || "not_a_column"_ == true)); + ASSERT_OK(builder.Filter(literal(true))); + ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); + ASSERT_OK(builder.Filter( + call("or_kleene", { + call("equal", {field_ref("i64"), literal(10)}), + call("equal", {field_ref("b"), literal(true)}), + }))); + + ASSERT_RAISES(NotImplemented, + builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); + + ASSERT_RAISES( + NotImplemented, + builder.Filter(call("equal", {field_ref("not_a_column"), literal(10)}))); + + ASSERT_RAISES( + NotImplemented, + builder.Filter( + call("or_kleene", { + call("equal", {field_ref("i64"), literal(10)}), + call("equal", {field_ref("not_a_column"), literal(true)}), + }))); } using testing::ElementsAre; @@ -204,7 +217,7 @@ TEST(ScanOptions, TestMaterializedFields) { auto opts = ScanOptions::Make(schema({})); EXPECT_THAT(opts->MaterializedFields(), IsEmpty()); - opts->filter = ("i32"_ == 10).Copy(); + opts->filter2 = call("equal", {field_ref("i32"), literal(10)}); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); opts = ScanOptions::Make(schema({i32, i64})); @@ -213,10 +226,10 @@ TEST(ScanOptions, TestMaterializedFields) { opts = opts->ReplaceSchema(schema({i32})); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); - opts->filter = ("i32"_ == 10).Copy(); + opts->filter2 = call("equal", {field_ref("i32"), literal(10)}); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32")); - opts->filter = ("i64"_ == 10).Copy(); + opts->filter2 = call("equal", {field_ref("i64"), literal(10)}); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64")); } diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 9cfba41a624..70c9b7548e8 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -50,6 +50,43 @@ namespace 
arrow { namespace dataset { +const std::shared_ptr kBoringSchema = schema({ + field("i32", int32()), + field("i32_req", int32(), /*nullable=*/false), + field("f32", float32()), + field("f32_req", float32(), /*nullable=*/false), + field("bool", boolean()), +}); + +inline namespace string_literals { +// clang-format off +inline FieldExpression operator"" _(const char* name, size_t name_length) { + // clang-format on + return FieldExpression({name, name_length}); +} +} // namespace string_literals + +#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ + template ::type>::value>::type> \ + ComparisonExpression operator OP(const Expression& lhs, T&& rhs) { \ + return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), \ + scalar(std::forward(rhs))); \ + } \ + \ + inline ComparisonExpression operator OP(const Expression& lhs, \ + const Expression& rhs) { \ + return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), rhs.Copy()); \ + } + +COMPARISON_FACTORY(EQUAL, equal, ==) +COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) +COMPARISON_FACTORY(GREATER, greater, >) +COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) +COMPARISON_FACTORY(LESS, less, <) +COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) +#undef COMPARISON_FACTORY + using fs::internal::GetAbstractPathExtension; using internal::checked_cast; using internal::checked_pointer_cast; @@ -66,7 +103,7 @@ template class GeneratedRecordBatch : public RecordBatchReader { public: GeneratedRecordBatch(std::shared_ptr schema, Gen gen) - : schema_(schema), gen_(gen) {} + : schema_(std::move(schema)), gen_(gen) {} std::shared_ptr schema() const override { return schema_; } @@ -101,8 +138,6 @@ void EnsureRecordBatchReaderDrained(RecordBatchReader* reader) { class DatasetFixtureMixin : public ::testing::Test { public: - DatasetFixtureMixin() : ctx_(std::make_shared()) {} - /// \brief Ensure that record batches found in reader are equals to the /// record batches yielded by the data fragment. 
void AssertScanTaskEquals(RecordBatchReader* expected, ScanTask* task, @@ -141,7 +176,8 @@ class DatasetFixtureMixin : public ::testing::Test { /// record batches yielded by the data fragments of a dataset. void AssertDatasetFragmentsEqual(RecordBatchReader* expected, Dataset* dataset, bool ensure_drained = true) { - auto it = dataset->GetFragments(options_->filter); + ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter2.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(auto it, dataset->GetFragments(predicate)); ARROW_EXPECT_OK(it.Visit([&](std::shared_ptr fragment) -> Status { AssertFragmentEquals(expected, fragment.get(), false); @@ -186,11 +222,17 @@ class DatasetFixtureMixin : public ::testing::Test { void SetSchema(std::vector> fields) { schema_ = schema(std::move(fields)); options_ = ScanOptions::Make(schema_); + SetFilter(literal(true)); + } + + void SetFilter(Expression2 filter) { + ASSERT_OK_AND_ASSIGN(std::tie(options_->filter2, ctx_->expression_state), + filter.Bind(*schema_)); } std::shared_ptr schema_; std::shared_ptr options_; - std::shared_ptr ctx_; + std::shared_ptr ctx_ = std::make_shared(); }; /// \brief A dummy FileFormat implementation @@ -321,15 +363,14 @@ struct MakeFileSystemDatasetMixin { } void MakeDataset(const std::vector& infos, - std::shared_ptr root_partition = scalar(true), - ExpressionVector partitions = {}) { + Expression2 root_partition = literal(true), + std::vector partitions = {}, + std::shared_ptr s = kBoringSchema) { auto n_fragments = infos.size(); if (partitions.empty()) { - partitions.resize(n_fragments, scalar(true)); + partitions.resize(n_fragments, literal(true)); } - auto s = schema({}); - MakeFileSystem(infos); auto format = std::make_shared(s); @@ -340,21 +381,17 @@ struct MakeFileSystemDatasetMixin { continue; } + ASSERT_OK_AND_ASSIGN(std::tie(partitions[i], std::ignore), partitions[i].Bind(*s)); ASSERT_OK_AND_ASSIGN(auto fragment, format->MakeFragment({info, fs_}, partitions[i])); 
fragments.push_back(std::move(fragment)); } + ASSERT_OK_AND_ASSIGN(std::tie(root_partition, std::ignore), root_partition.Bind(*s)); ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format, fs_, std::move(fragments))); } - void MakeDatasetFromPathlist(const std::string& pathlist, - std::shared_ptr root_partition = scalar(true), - ExpressionVector partitions = {}) { - MakeDataset(ParsePathList(pathlist), root_partition, partitions); - } - std::shared_ptr fs_; std::shared_ptr dataset_; std::shared_ptr options_ = ScanOptions::Make(schema({})); @@ -387,49 +424,24 @@ void AssertFragmentsAreFromPath(FragmentIterator it, std::vector ex testing::UnorderedElementsAreArray(expected)); } -// A frozen shared_ptr with behavior expected by GTest -struct TestExpression : util::EqualityComparable, - util::ToStringOstreamable { - // NOLINTNEXTLINE runtime/explicit - TestExpression(std::shared_ptr e) : expression(std::move(e)) {} - - // NOLINTNEXTLINE runtime/explicit - TestExpression(const Expression& e) : expression(e.Copy()) {} - - std::shared_ptr expression; - - using util::EqualityComparable::operator==; - bool Equals(const TestExpression& other) const { - return expression->Equals(other.expression); - } - - std::string ToString() const { return expression->ToString(); } - - friend bool operator==(const std::shared_ptr& lhs, - const TestExpression& rhs) { - return TestExpression(lhs) == rhs; - } - - friend void PrintTo(const TestExpression& expr, std::ostream* os) { - *os << expr.ToString(); - } -}; - -static std::vector PartitionExpressionsOf( - const FragmentVector& fragments) { - std::vector partition_expressions; +static std::vector PartitionExpressionsOf(const FragmentVector& fragments) { + std::vector partition_expressions; std::transform(fragments.begin(), fragments.end(), std::back_inserter(partition_expressions), [](const std::shared_ptr& fragment) { - return TestExpression(fragment->partition_expression()); + return 
fragment->partition_expression(); }); return partition_expressions; } -void AssertFragmentsHavePartitionExpressions(FragmentIterator it, - ExpressionVector expected) { +void AssertFragmentsHavePartitionExpressions(std::shared_ptr dataset, + std::vector expected) { + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); + for (auto& expr : expected) { + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*dataset->schema())); + } // Ordering is not guaranteed. - EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(it))), + EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(fragment_it))), testing::UnorderedElementsAreArray(expected)); } @@ -456,7 +468,7 @@ struct ArithmeticDatasetFixture { /// 2}, static std::string JSONRecordFor(int64_t n) { std::stringstream ss; - int32_t n_i32 = static_cast(n); + auto n_i32 = static_cast(n); ss << "{"; ss << "\"i64\": " << n << ", "; @@ -579,6 +591,9 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish()); scan_options_ = ScanOptions::Make(source_schema_); + ASSERT_OK_AND_ASSIGN( + std::tie(scan_options_->filter2, scan_context_->expression_state), + literal(true).Bind(*source_schema_)); } void SetWriteOptions(std::shared_ptr file_write_options) { @@ -741,7 +756,8 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { } EXPECT_THAT(actual_paths, testing::UnorderedElementsAreArray(expected_paths)); - for (auto maybe_fragment : written_->GetFragments()) { + ASSERT_OK_AND_ASSIGN(auto written_fragments_it, written_->GetFragments()); + for (auto maybe_fragment : written_fragments_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); ASSERT_OK_AND_ASSIGN(auto actual_physical_schema, fragment->ReadPhysicalSchema()); diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 522def49818..e4bd322ba10 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ 
b/cpp/src/arrow/dataset/type_fwd.h @@ -69,12 +69,6 @@ class ParquetFileWriter; class ParquetFileWriteOptions; class Expression; -using ExpressionVector = std::vector>; -class ExpressionEvaluator; - -/// forward declared to facilitate scalar(true) as a default for Expression parameters -ARROW_DS_EXPORT -const std::shared_ptr& scalar(bool); struct ExpressionState; class Expression2; diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index c1666adc4c5..0f2fa8ebe90 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -70,6 +70,7 @@ Datum::Datum(float value) : value(std::make_shared(value)) {} Datum::Datum(double value) : value(std::make_shared(value)) {} Datum::Datum(std::string value) : value(std::make_shared(std::move(value))) {} +Datum::Datum(const char* value) : value(std::make_shared(value)) {} Datum::Datum(const ChunkedArray& value) : value(std::make_shared(value.chunks(), value.type())) {} @@ -198,6 +199,8 @@ static std::string FormatValueDescr(const ValueDescr& descr) { std::string ValueDescr::ToString() const { return FormatValueDescr(*this); } +void PrintTo(const ValueDescr& descr, std::ostream* os) { *os << descr.ToString(); } + std::string Datum::ToString() const { switch (this->kind()) { case Datum::NONE: diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 5f80551f337..9c620279c9e 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -89,6 +89,8 @@ struct ARROW_EXPORT ValueDescr { bool operator!=(const ValueDescr& other) const { return !(*this == other); } std::string ToString() const; + + ARROW_EXPORT friend void PrintTo(const ValueDescr&, std::ostream*); }; /// \brief For use with scalar functions, returns the broadcasted Value::Shape @@ -161,6 +163,7 @@ struct ARROW_EXPORT Datum { explicit Datum(float value); explicit Datum(double value); explicit Datum(std::string value); + explicit Datum(const char* value); Datum::Kind kind() const { switch (this->value.index()) { diff --git a/cpp/src/arrow/type.cc 
b/cpp/src/arrow/type.cc index 5604c5e2717..20dc0f14ca0 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1501,7 +1501,7 @@ std::shared_ptr Schema::WithMetadata( return std::make_shared(impl_->fields_, metadata); } -std::shared_ptr Schema::metadata() const { +const std::shared_ptr& Schema::metadata() const { return impl_->metadata_; } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 3594c248e10..cadb819b9f2 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1690,7 +1690,7 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable, /// \brief The custom key-value metadata, if any /// /// \return metadata may be null - std::shared_ptr metadata() const; + const std::shared_ptr& metadata() const; /// \brief Render a string representation of the schema suitable for debugging /// \param[in] show_metadata when true, if KeyValueMetadata is non-empty, diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h index 4713976d485..89f39ab8917 100644 --- a/cpp/src/arrow/util/variant.h +++ b/cpp/src/arrow/util/variant.h @@ -389,6 +389,22 @@ U& get(Variant& v) { return *v.template get(); } +/// \brief Get a const pointer to the value held by the variant +/// +/// If the type given as template argument doesn't match, a nullptr is returned. +template +const U* get_if(const Variant* v) { + return v->template get(); +} + +/// \brief Get a pointer to the value held by the variant +/// +/// If the type given as template argument doesn't match, a nullptr is returned. 
+template +U* get_if(Variant* v) { + return v->template get(); +} + namespace detail { template From d9c6ae5750fb753a46b98a56d889990685dcdb15 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 9 Dec 2020 13:32:16 -0500 Subject: [PATCH 03/31] remove ExpressionState --- cpp/src/arrow/dataset/dataset.cc | 11 +- cpp/src/arrow/dataset/dataset.h | 9 +- cpp/src/arrow/dataset/dataset_internal.h | 4 +- cpp/src/arrow/dataset/discovery_test.cc | 2 +- cpp/src/arrow/dataset/expression.cc | 170 ++++++-------------- cpp/src/arrow/dataset/expression.h | 37 +++-- cpp/src/arrow/dataset/expression_internal.h | 43 ++--- cpp/src/arrow/dataset/expression_test.cc | 104 +++++------- cpp/src/arrow/dataset/file_base.cc | 9 +- cpp/src/arrow/dataset/file_base.h | 2 +- cpp/src/arrow/dataset/file_parquet.cc | 25 ++- cpp/src/arrow/dataset/file_parquet.h | 6 +- cpp/src/arrow/dataset/file_parquet_test.cc | 4 +- cpp/src/arrow/dataset/partition.cc | 4 +- cpp/src/arrow/dataset/partition_test.cc | 20 +-- cpp/src/arrow/dataset/scanner.cc | 6 +- cpp/src/arrow/dataset/scanner.h | 2 - cpp/src/arrow/dataset/scanner_internal.h | 11 +- cpp/src/arrow/dataset/test_util.h | 12 +- cpp/src/arrow/dataset/type_fwd.h | 2 - 20 files changed, 173 insertions(+), 310 deletions(-) diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 56c901f959d..5da90e48ebc 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -111,11 +111,11 @@ Result Dataset::GetFragments() { return GetFragments(std::move(predicate)); } -Result Dataset::GetFragments(Expression2::BoundWithState predicate) { +Result Dataset::GetFragments(Expression2 predicate) { ARROW_ASSIGN_OR_RAISE( predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); - return predicate.first.IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) - : MakeEmptyIterator>(); + return predicate.IsSatisfiable() ? 
GetFragmentsImpl(std::move(predicate)) + : MakeEmptyIterator>(); } struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator { @@ -155,7 +155,7 @@ Result> InMemoryDataset::ReplaceSchema( return std::make_shared(std::move(schema), get_batches_); } -Result InMemoryDataset::GetFragmentsImpl(Expression2::BoundWithState) { +Result InMemoryDataset::GetFragmentsImpl(Expression2) { auto schema = this->schema(); auto create_fragment = @@ -196,8 +196,7 @@ Result> UnionDataset::ReplaceSchema( new UnionDataset(std::move(schema), std::move(children))); } -Result UnionDataset::GetFragmentsImpl( - Expression2::BoundWithState predicate) { +Result UnionDataset::GetFragmentsImpl(Expression2 predicate) { return GetFragmentsFromDatasets(children_, predicate); } diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 35c91b79fed..88782831670 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -122,7 +122,7 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. 
- Result GetFragments(Expression2::BoundWithState predicate); + Result GetFragments(Expression2 predicate); Result GetFragments(); const std::shared_ptr& schema() const { return schema_; } @@ -148,8 +148,7 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Dataset(std::shared_ptr schema, Expression2 partition_expression); - virtual Result GetFragmentsImpl( - Expression2::BoundWithState predicate) = 0; + virtual Result GetFragmentsImpl(Expression2 predicate) = 0; std::shared_ptr schema_; Expression2 partition_expression_ = literal(true); @@ -182,7 +181,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2::BoundWithState predicate) override; + Expression2 predicate) override; std::shared_ptr get_batches_; }; @@ -207,7 +206,7 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2::BoundWithState predicate) override; + Expression2 predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index a6ee8a117bc..9e070192cd0 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -35,8 +35,8 @@ namespace dataset { /// \brief GetFragmentsFromDatasets transforms a vector into a /// flattened FragmentIterator. 
-inline Result GetFragmentsFromDatasets( - const DatasetVector& datasets, Expression2::BoundWithState predicate) { +inline Result GetFragmentsFromDatasets(const DatasetVector& datasets, + Expression2 predicate) { // Iterator auto datasets_it = MakeVectorIterator(datasets); diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index cd0ba2bee4d..219ee070fdc 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -375,7 +375,7 @@ TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) { ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); for (const auto& maybe_fragment : fragment_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); - EXPECT_EQ(fragment->partition_expression(), expected.first); + EXPECT_EQ(fragment->partition_expression(), expected); } } diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index c762cb89e0f..8f9a3ce5663 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -22,7 +22,6 @@ #include "arrow/chunked_array.h" #include "arrow/compute/exec_internal.h" -#include "arrow/compute/registry.h" #include "arrow/dataset/expression_internal.h" #include "arrow/dataset/filter.h" #include "arrow/io/memory.h" @@ -379,67 +378,29 @@ Result> InitKernelState( return std::move(kernel_state); } -Result> CloneExpressionState( - const ExpressionState& state, compute::ExecContext* exec_context) { - auto clone = std::make_shared(); - clone->kernel_states.reserve(state.kernel_states.size()); - - for (const auto& sub_state : state.kernel_states) { - auto call = CallNotNull(sub_state.first); - - if (KernelStateIsImmutable(call->function)) { - // The kernel's state is immutable so it's safe to just - // share a pointer between threads - clone->kernel_states.insert(sub_state); - continue; - } - - // The kernel's state must be re-initialized. 
- ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(*call, exec_context)); - clone->kernel_states.emplace(sub_state.first, std::move(kernel_state)); - } - - return clone; -} - -Result Expression2::Bind( - ValueDescr in, compute::ExecContext* exec_context) const { +Result Expression2::Bind(ValueDescr in, + compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; return Bind(std::move(in), &exec_context); } - BoundWithState ret{*this, std::make_shared()}; - - if (literal()) return ret; + if (literal()) return *this; if (auto ref = field_ref()) { ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); - ret.first.descr_ = - field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return ret; + auto out = *this; + out.descr_ = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); + return out; } auto bound_call = *CallNotNull(*this); - std::shared_ptr function; - if (bound_call.function == "cast") { - // XXX this special case is strange; why not make "cast" a ScalarFunction? 
- const auto& to_type = - checked_cast(*bound_call.options).to_type; - ARROW_ASSIGN_OR_RAISE(function, compute::GetCastFunction(to_type)); - } else { - ARROW_ASSIGN_OR_RAISE( - function, exec_context->func_registry()->GetFunction(bound_call.function)); - } + ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(bound_call, exec_context)); bound_call.function_kind = function->kind(); - auto state = std::make_shared(); - for (size_t i = 0; i < bound_call.arguments.size(); ++i) { - std::shared_ptr argument_state; - ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_state), - bound_call.arguments[i].Bind(in, exec_context)); - state->MoveFrom(argument_state.get()); + for (auto&& argument : bound_call.arguments) { + ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); } if (RequriesDictionaryTransparency(bound_call)) { @@ -447,7 +408,7 @@ Result Expression2::Bind( } auto descrs = GetDescriptors(bound_call.arguments); - for (auto& descr : descrs) { + for (auto&& descr : descrs) { if (RequriesDictionaryTransparency(bound_call)) { RETURN_NOT_OK(EnsureNotDictionary(&descr)); } @@ -456,29 +417,26 @@ Result Expression2::Bind( ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); - ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(bound_call, exec_context)); - kernel_context.SetState(kernel_state.get()); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, + InitKernelState(bound_call, exec_context)); + kernel_context.SetState(bound_call.kernel_state.get()); ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( &kernel_context, descrs)); - Expression2 bound(std::make_shared(std::move(bound_call)), std::move(descr)); - - state->kernel_states.emplace(bound, std::move(kernel_state)); - return BoundWithState{std::move(bound), std::move(state)}; + return Expression2(std::make_shared(std::move(bound_call)), std::move(descr)); } -Result Expression2::Bind( - 
const Schema& in_schema, compute::ExecContext* exec_context) const { +Result Expression2::Bind(const Schema& in_schema, + compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } -Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* state, - const Datum& input, +Result ExecuteScalarExpression(const Expression2& expr, const Datum& input, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; - return ExecuteScalarExpression(expr, state, input, &exec_context); + return ExecuteScalarExpression(expr, input, &exec_context); } if (!expr.IsBound()) { @@ -510,8 +468,8 @@ Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(arguments[i], ExecuteScalarExpression(call->arguments[i], state, - input, exec_context)); + ARROW_ASSIGN_OR_RAISE( + arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context)); if (RequriesDictionaryTransparency(*call)) { RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); @@ -521,7 +479,7 @@ Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* auto executor = compute::detail::KernelExecutor::MakeScalar(); compute::KernelContext kernel_context(exec_context); - kernel_context.SetState(state->Get(expr)); + kernel_context.SetState(call->kernel_state.get()); auto kernel = call->kernel; auto descrs = GetDescriptors(arguments); @@ -601,7 +559,7 @@ std::vector FieldsInExpression(const Expression2& expr) { } template -Result Modify(Expression2 expr, ExpressionState* state, const PreVisit& pre, +Result Modify(Expression2 expr, const PreVisit& pre, const PostVisitCall& post_call) { ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); @@ -613,7 +571,7 @@ Result Modify(Expression2 expr, ExpressionState* state, const PreVi auto modified_argument = 
modified_call.arguments.begin(); for (const auto& argument : call->arguments) { - ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, state, pre, post_call)); + ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); if (!Identical(*modified_argument, argument)) { at_least_one_modified = true; @@ -626,45 +584,24 @@ Result Modify(Expression2 expr, ExpressionState* state, const PreVi auto modified_expr = Expression2( std::make_shared(std::move(modified_call)), expr.descr()); - // if expr had associated kernel state, associate it with modified_expr - state->Replace(expr, modified_expr); return post_call(std::move(modified_expr), &expr); } return post_call(std::move(expr), nullptr); } -template -Result Modify(Expression2::BoundWithState bound, - const PreVisit& pre, - const PostVisitCall& post_call) { - DCHECK(bound.first.IsBound()); - - auto expr = std::move(bound.first); - auto state = bound.second.get(); - - ARROW_ASSIGN_OR_RAISE(expr, Modify(std::move(expr), state, pre, post_call)); - - bound.first = std::move(expr); - - return bound; -} - -Result FoldConstants(Expression2::BoundWithState bound) { - bound.second = std::make_shared(*bound.second); - auto state = bound.second.get(); +Result FoldConstants(Expression2 expr) { return Modify( - std::move(bound), [](Expression2 expr) { return expr; }, - [state](Expression2 expr, ...) -> Result { + std::move(expr), [](Expression2 expr) { return expr; }, + [](Expression2 expr, ...) 
-> Result { auto call = CallNotNull(expr); if (std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression2& argument) { return argument.literal(); })) { // all arguments are literal; we can evaluate this subexpression *now* static const Datum ignored_input; ARROW_ASSIGN_OR_RAISE(Datum constant, - ExecuteScalarExpression(expr, state, ignored_input)); + ExecuteScalarExpression(expr, ignored_input)); - state->Drop(expr); return literal(std::move(constant)); } @@ -676,7 +613,6 @@ Result FoldConstants(Expression2::BoundWithState bo // to null *now* if any of their inputs is a null literal for (const auto& argument : call->arguments) { if (argument.IsNullLiteral()) { - state->Drop(expr); return argument; } } @@ -722,12 +658,8 @@ inline std::vector GuaranteeConjunctionMembers( Status ExtractKnownFieldValuesImpl( std::vector* conjunction_members, std::unordered_map* known_values) { - { - auto empty_state = std::make_shared(); - for (auto&& member : *conjunction_members) { - ARROW_ASSIGN_OR_RAISE(std::tie(member, std::ignore), - Canonicalize({std::move(member), empty_state})); - } + for (auto&& member : *conjunction_members) { + ARROW_ASSIGN_OR_RAISE(member, Canonicalize({std::move(member)})); } auto unconsumed_end = @@ -780,12 +712,11 @@ Result> ExtractKnownFieldVal return known_values; } -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2::BoundWithState bound) { - bound.second = std::make_shared(*bound.second); + Expression2 expr) { return Modify( - std::move(bound), + std::move(expr), [&known_values](Expression2 expr) { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); @@ -807,11 +738,10 @@ inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { return it != binary_associative_commutative.end(); } -Result Canonicalize(Expression2::BoundWithState bound, - compute::ExecContext* exec_context) { +Result Canonicalize(Expression2 expr, 
compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; - return Canonicalize(std::move(bound), &exec_context); + return Canonicalize(std::move(expr), &exec_context); } // If potentially reconstructing more deeply than a call's immediate arguments @@ -830,7 +760,7 @@ Result Canonicalize(Expression2::BoundWithState bou } AlreadyCanonicalized; return Modify( - std::move(bound), + std::move(expr), [&AlreadyCanonicalized, exec_context](Expression2 expr) -> Result { auto call = expr.call(); if (!call) return expr; @@ -906,10 +836,10 @@ Result Canonicalize(Expression2::BoundWithState bou [](Expression2 expr, ...) { return expr; }); } -Result DirectComparisonSimplification( - Expression2::BoundWithState bound, const Expression2::Call& guarantee) { +Result DirectComparisonSimplification(Expression2 expr, + const Expression2::Call& guarantee) { return Modify( - std::move(bound), [](Expression2 expr) { return expr; }, + std::move(expr), [](Expression2 expr) { return expr; }, [&guarantee](Expression2 expr, ...) 
-> Result { auto call = expr.call(); if (!call) return expr; @@ -964,8 +894,8 @@ Result DirectComparisonSimplification( }); } -Result SimplifyWithGuarantee( - Expression2::BoundWithState bound, const Expression2& guaranteed_true_predicate) { +Result SimplifyWithGuarantee(Expression2 expr, + const Expression2& guaranteed_true_predicate) { if (!guaranteed_true_predicate.IsBound()) { return Status::Invalid("guaranteed_true_predicate was not bound"); } @@ -975,29 +905,29 @@ Result SimplifyWithGuarantee( std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); - ARROW_ASSIGN_OR_RAISE(bound, - ReplaceFieldsWithKnownValues(known_values, std::move(bound))); + ARROW_ASSIGN_OR_RAISE(expr, + ReplaceFieldsWithKnownValues(known_values, std::move(expr))); - auto CanonicalizeAndFoldConstants = [&bound] { - ARROW_ASSIGN_OR_RAISE(bound, Canonicalize(std::move(bound))); - ARROW_ASSIGN_OR_RAISE(bound, FoldConstants(std::move(bound))); + auto CanonicalizeAndFoldConstants = [&expr] { + ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); + ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr))); return Status::OK(); }; RETURN_NOT_OK(CanonicalizeAndFoldConstants()); for (const auto& guarantee : conjunction_members) { if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { - ARROW_ASSIGN_OR_RAISE(auto simplified, DirectComparisonSimplification( - bound, *CallNotNull(guarantee))); + ARROW_ASSIGN_OR_RAISE( + auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); - if (Identical(simplified.first, bound.first)) continue; + if (Identical(simplified, expr)) continue; - bound = std::move(simplified); + expr = std::move(simplified); RETURN_NOT_OK(CanonicalizeAndFoldConstants()); } } - return bound; + return expr; } Result> Serialize(const Expression2& expr) { diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index f452e54b7c2..16c9e8a96f6 100644 
--- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -54,7 +54,9 @@ class ARROW_DS_EXPORT Expression2 { // post-Bind properties: const compute::Kernel* kernel = NULLPTR; - compute::Function::Kind function_kind; + compute::Function::Kind + function_kind; // XXX give Kernel a non-owning pointer to its Function + std::shared_ptr kernel_state; }; std::string ToString() const; @@ -67,10 +69,18 @@ class ARROW_DS_EXPORT Expression2 { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - using BoundWithState = std::pair>; - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, - compute::ExecContext* = NULLPTR) const; + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, + compute::ExecContext* = NULLPTR) const; + + // XXX someday + // Clone all KernelState in this bound expression. If any function referenced by this + // expression has mutable KernelState, it is not safe to execute or apply simplification + // passes to it (or copies of it!) from multiple threads. Cloning state produces new + // KernelStates where necessary to ensure that Expression2s may be manipulated safely + // on multiple threads. + // Result CloneState() const; + // Status SetState(ExpressionState); /// Return true if all an expression's field references have explicit ValueDescr and all /// of its functions' kernels are looked up. @@ -175,27 +185,25 @@ Result> ExtractKnownFieldVal /// equivalent Expressions may result in different canonicalized expressions. 
/// TODO this could be a strong canonicalization ARROW_DS_EXPORT -Result Canonicalize(Expression2::BoundWithState, - compute::ExecContext* = NULLPTR); +Result Canonicalize(Expression2, compute::ExecContext* = NULLPTR); /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all /// calls whose arguments are entirely literal. ARROW_DS_EXPORT -Result FoldConstants(Expression2::BoundWithState); +Result FoldConstants(Expression2); ARROW_DS_EXPORT -Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, - Expression2::BoundWithState); +Result ReplaceFieldsWithKnownValues( + const std::unordered_map& known_values, Expression2); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is /// used to remove redundant function calls from a filter expression or to replace a /// reference to a constant-value field with a literal. ARROW_DS_EXPORT -Result SimplifyWithGuarantee( - Expression2::BoundWithState, const Expression2& guaranteed_true_predicate); +Result SimplifyWithGuarantee(Expression2, + const Expression2& guaranteed_true_predicate); /// @} @@ -203,8 +211,7 @@ Result SimplifyWithGuarantee( /// Execute a scalar expression against the provided state and input Datum. This /// expression must be bound. 
-Result ExecuteScalarExpression(const Expression2&, ExpressionState*, - const Datum& input, +Result ExecuteScalarExpression(const Expression2&, const Datum& input, compute::ExecContext* = NULLPTR); // Serialization diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 548d36711fb..84b802937e0 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -22,6 +22,7 @@ #include #include "arrow/compute/api_vector.h" +#include "arrow/compute/registry.h" #include "arrow/record_batch.h" #include "arrow/table.h" #include "arrow/util/logging.h" @@ -71,38 +72,6 @@ inline std::vector GetDescriptors(const std::vector& values) return descrs; } -struct ARROW_DS_EXPORT ExpressionState { - std::unordered_map, - Expression2::Hash> - kernel_states; - - compute::KernelState* Get(const Expression2& expr) const { - auto it = kernel_states.find(expr); - if (it == kernel_states.end()) return nullptr; - return it->second.get(); - } - - void Replace(const Expression2& expr, const Expression2& replacement) { - auto it = kernel_states.find(expr); - if (it == kernel_states.end()) return; - - auto kernel_state = std::move(it->second); - kernel_states.erase(it); - kernel_states.emplace(replacement, std::move(kernel_state)); - } - - void Drop(const Expression2& expr) { - auto it = kernel_states.find(expr); - if (it == kernel_states.end()) return; - kernel_states.erase(it); - } - - void MoveFrom(ExpressionState* other) { - std::move(other->kernel_states.begin(), other->kernel_states.end(), - std::inserter(kernel_states, kernel_states.end())); - } -}; - struct FieldPathGetDatumImpl { template ()))> Result operator()(const std::shared_ptr& ptr) { @@ -418,5 +387,15 @@ struct FlattenedAssociativeChain { } }; +inline Result> GetFunction( + const Expression2::Call& call, compute::ExecContext* exec_context) { + if (call.function != "cast") { + return 
exec_context->func_registry()->GetFunction(call.function); + } + // XXX this special case is strange; why not make "cast" a ScalarFunction? + const auto& to_type = checked_cast(*call.options).to_type; + return compute::GetCastFunction(to_type); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 6274147d208..8ea70adf01c 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -64,6 +64,9 @@ TEST(Expression2, Equality) { EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}), call("add", {literal(3), call("index_in", {field_ref("beta")})})); + + // FIXME options are not currently compared, so two cast exprs will compare equal + // regardless of differing target type } TEST(Expression2, Hash) { @@ -183,8 +186,7 @@ TEST(Expression2, BindFieldRef) { { auto expr = field_ref("alpha"); // binding a field_ref looks up that field's type in the input Schema - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), - expr.Bind(Schema({field("alpha", int32())}))); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("alpha", int32())}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); EXPECT_TRUE(expr.IsBound()); } @@ -192,7 +194,7 @@ TEST(Expression2, BindFieldRef) { { // if the field is not found, a null scalar will be emitted auto expr = field_ref("alpha"); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(Schema({}))); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({}))); EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); EXPECT_TRUE(expr.IsBound()); } @@ -208,7 +210,7 @@ TEST(Expression2, BindFieldRef) { { // referencing nested fields is supported auto expr = field_ref("a", "b"); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", struct_({field("b", int32())}))}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); 
EXPECT_TRUE(expr.IsBound()); @@ -219,7 +221,7 @@ TEST(Expression2, BindCall) { auto expr = call("add", {field_ref("a"), field_ref("b")}); EXPECT_FALSE(expr.IsBound()); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", int32()), field("b", int32())}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); EXPECT_TRUE(expr.IsBound()); @@ -234,7 +236,7 @@ TEST(Expression2, BindDictionaryTransparent) { EXPECT_FALSE(expr.IsBound()); ASSERT_OK_AND_ASSIGN( - std::tie(expr, std::ignore), + expr, expr.Bind(Schema({field("a", utf8()), field("b", dictionary(int32(), utf8()))}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean())); @@ -248,7 +250,7 @@ TEST(Expression2, BindNestedCall) { field_ref("d")})}); EXPECT_FALSE(expr.IsBound()); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", int32()), field("b", int32()), field("c", int32()), field("d", int32())}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); @@ -259,9 +261,8 @@ TEST(Expression2, ExecuteFieldRef) { auto AssertRefIs = [](FieldRef ref, Datum in, Datum expected) { auto expr = field_ref(ref); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(in.descr())); - ASSERT_OK_AND_ASSIGN(Datum actual, - ExecuteScalarExpression(expr, /*state=*/nullptr, in)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); AssertDatumsEqual(actual, expected, /*verbose=*/true); }; @@ -295,7 +296,7 @@ Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& auto call = expr.call(); if (call == nullptr) { // already tested execution of field_ref, execution of literal is trivial - return ExecuteScalarExpression(expr, /*state=*/nullptr, input); + return ExecuteScalarExpression(expr, input); } std::vector arguments(call->arguments.size()); @@ -335,10 +336,9 @@ Result NaiveExecuteScalarExpression(const Expression2& expr, 
const Datum& } void AssertExecute(Expression2 expr, Datum in) { - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); - ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, state.get(), in)); + ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in)); @@ -400,26 +400,17 @@ TEST(Expression2, ExecuteDictionaryTransparent) { struct { void operator()(Expression2 expr, Expression2 expected) { - this->operator()(expr, expected, - [](Expression2::BoundWithState, Expression2::BoundWithState, - Expression2::BoundWithState) {}); - } + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); - template - void operator()(Expression2 expr, Expression2 unbound_expected, - const ExtraExpectations& expect) { - ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(bound)); + ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(expr)); - EXPECT_EQ(folded.first, expected.first); + EXPECT_EQ(folded, expected); - if (folded.first == expr) { + if (folded == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(folded.first, expr)); + EXPECT_TRUE(Identical(folded, expr)); } - - expect(bound, folded, expected); } } ExpectFoldsTo; @@ -476,15 +467,7 @@ TEST(Expression2, FoldConstants) { call("is_in", {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})}, in_123), - call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123), - [](Expression2::BoundWithState bound, Expression2::BoundWithState folded, - Expression2::BoundWithState) { - const compute::KernelState* state = bound.second->Get(bound.first); - const compute::KernelState* folded_state = 
folded.second->Get(folded.first); - EXPECT_EQ(folded_state, state) << "The kernel state associated with is_in (the " - "hash table for looking up membership) " - "must be associated with the folded is_in call"; - }); + call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123)); } TEST(Expression2, FoldConstantsBoolean) { @@ -507,8 +490,7 @@ TEST(Expression2, ExtractKnownFieldValues) { struct { void operator()(Expression2 guarantee, std::unordered_map expected) { - ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), - guarantee.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) << " guarantee: " << guarantee.ToString(); @@ -553,11 +535,11 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { ASSERT_OK_AND_ASSIGN(auto replaced, ReplaceFieldsWithKnownValues(known_values, bound)); - EXPECT_EQ(replaced.first, expected.first); + EXPECT_EQ(replaced, expected); - if (replaced.first == expr) { + if (replaced == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(replaced.first, expr)); + EXPECT_TRUE(Identical(replaced, expr)); } }; @@ -598,11 +580,11 @@ struct { ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); - EXPECT_EQ(actual.first, expected.first); + EXPECT_EQ(actual, expected); - if (actual.first == expr) { + if (actual == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual.first, expr)); + EXPECT_TRUE(Identical(actual, expr)); } } } ExpectCanonicalizesTo; @@ -666,19 +648,17 @@ struct Simplify { void Expect(Expression2 unbound_expected) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), - guarantee.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); 
ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); - EXPECT_EQ(simplified.first, expected.first) - << " original: " << expr.ToString() << "\n" - << " guarantee: " << guarantee.ToString() << "\n" - << (simplified.first == bound.first ? " (no change)\n" : ""); + EXPECT_EQ(simplified, expected) << " original: " << expr.ToString() << "\n" + << " guarantee: " << guarantee.ToString() << "\n" + << (simplified == bound ? " (no change)\n" : ""); - if (simplified.first == bound.first) { - EXPECT_TRUE(Identical(simplified.first, bound.first)); + if (simplified == bound) { + EXPECT_TRUE(Identical(simplified, bound)); } } void ExpectUnchanged() { Expect(expr); } @@ -781,10 +761,8 @@ TEST(Expression2, SingleComparisonGuarantees) { StructArray::Make({ArrayFromJSON(int32(), satisfying_i32[guarantee_op])}, {"i32"})); - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum evaluated, - ExecuteScalarExpression(filter, state.get(), input)); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(filter, input)); // ensure that the simplified filter is as simplified as it could be // (this is always possible for single comparisons) @@ -840,7 +818,7 @@ TEST(Expression2, SimplifyThenExecute) { ASSERT_OK_AND_ASSIGN(auto guarantee, equal(field_ref("f32"), literal(0.F)).Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee.first)); + ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee)); auto input = RecordBatchFromJSON(kBoringSchema, R"([ {"i32": 0, "f32": -0.1}, @@ -851,8 +829,7 @@ TEST(Expression2, SimplifyThenExecute) { {"i32": 0, "f32": null}, {"i32": 0, "f32": 1.0} ])"); - ASSERT_OK_AND_ASSIGN(Datum evaluated, - ExecuteScalarExpression(bound.first, bound.second.get(), input)); + 
ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(bound, input)); } TEST(Expression2, Filter) { @@ -861,9 +838,8 @@ TEST(Expression2, Filter) { auto batch = RecordBatchFromJSON(s, batch_json); auto expected_mask = batch->column(0); - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, state.get(), batch)); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, batch)); AssertDatumsEqual(expected_mask, mask); }; diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index f8350a0292c..dfc17a25812 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -137,15 +137,14 @@ std::string FileSystemDataset::ToString() const { return repr; } -Result FileSystemDataset::GetFragmentsImpl( - Expression2::BoundWithState predicate) { +Result FileSystemDataset::GetFragmentsImpl(Expression2 predicate) { FragmentVector fragments; for (const auto& fragment : fragments_) { ARROW_ASSIGN_OR_RAISE( auto simplified, SimplifyWithGuarantee(predicate, fragment->partition_expression())); - if (simplified.first.IsSatisfiable()) { + if (simplified.IsSatisfiable()) { fragments.push_back(fragment); } } @@ -323,8 +322,8 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio .Bind(*scanner->schema())); auto batch = std::move(groups.batches[i]); - ARROW_ASSIGN_OR_RAISE( - auto part, write_options.partitioning->Format(partition_expression.first)); + ARROW_ASSIGN_OR_RAISE(auto part, + write_options.partitioning->Format(partition_expression)); WriteQueue* queue; { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 1d0059ec089..f6def56b7de 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -234,7 +234,7 @@ class ARROW_DS_EXPORT 
FileSystemDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2::BoundWithState predicate) override; + Expression2 predicate) override; FileSystemDataset(std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index f1e9ad48bb0..10c5b4a4bcf 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -298,8 +298,8 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrmetadata() != nullptr) { - ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups( - {options->filter2, context->expression_state})); + ARROW_ASSIGN_OR_RAISE(row_groups, + parquet_fragment->FilterRowGroups(options->filter2)); pre_filtered = true; if (row_groups.empty()) MakeEmpty(); @@ -314,8 +314,8 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrFilterRowGroups( - {options->filter2, context->expression_state})); + ARROW_ASSIGN_OR_RAISE(row_groups, + parquet_fragment->FilterRowGroups(options->filter2)); if (row_groups.empty()) MakeEmpty(); } @@ -457,8 +457,7 @@ Status ParquetFileFragment::SetMetadata( return Status::OK(); } -Result ParquetFileFragment::SplitByRowGroup( - Expression2::BoundWithState predicate) { +Result ParquetFileFragment::SplitByRowGroup(Expression2 predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); @@ -476,8 +475,7 @@ Result ParquetFileFragment::SplitByRowGroup( return fragments; } -Result> ParquetFileFragment::Subset( - Expression2::BoundWithState predicate) { +Result> ParquetFileFragment::Subset(Expression2 predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); return Subset(std::move(row_groups)); @@ -502,19 +500,18 @@ inline void FoldingAnd(Expression2* l, Expression2 r) { } } -Result> ParquetFileFragment::FilterRowGroups( - Expression2::BoundWithState 
predicate) { +Result> ParquetFileFragment::FilterRowGroups(Expression2 predicate) { auto lock = physical_schema_mutex_.Lock(); DCHECK_NE(metadata_, nullptr); ARROW_ASSIGN_OR_RAISE( predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); - if (!predicate.first.IsSatisfiable()) { + if (!predicate.IsSatisfiable()) { return std::vector{}; } - for (const FieldRef& ref : FieldsInExpression(predicate.first)) { + for (const FieldRef& ref : FieldsInExpression(predicate)) { ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOneOrNone(*physical_schema_)); if (!path) continue; @@ -529,7 +526,7 @@ Result> ParquetFileFragment::FilterRowGroups( if (auto minmax = ColumnChunkStatisticsAsExpression(schema_field, *row_group_metadata)) { FoldingAnd(&statistics_expressions_[i], std::move(*minmax)); - ARROW_ASSIGN_OR_RAISE(std::tie(statistics_expressions_[i], std::ignore), + ARROW_ASSIGN_OR_RAISE(statistics_expressions_[i], statistics_expressions_[i].Bind(*physical_schema_)); } @@ -541,7 +538,7 @@ Result> ParquetFileFragment::FilterRowGroups( for (size_t i = 0; i < row_groups_->size(); ++i) { ARROW_ASSIGN_OR_RAISE(auto row_group_predicate, SimplifyWithGuarantee(predicate, statistics_expressions_[i])); - if (row_group_predicate.first.IsSatisfiable()) { + if (row_group_predicate.IsSatisfiable()) { row_groups.push_back(row_groups_->at(i)); } } diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index 9de19186ad3..9eab3fe9449 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -150,7 +150,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// significant performance boost when scanning high latency file systems. class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { public: - Result SplitByRowGroup(Expression2::BoundWithState predicate); + Result SplitByRowGroup(Expression2 predicate); /// \brief Return the RowGroups selected by this fragment. 
const std::vector& row_groups() const { @@ -166,7 +166,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR); /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups. - Result> Subset(Expression2::BoundWithState predicate); + Result> Subset(Expression2 predicate); Result> Subset(std::vector row_group_ids); private: @@ -185,7 +185,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { } // Return a filtered subset of row group indices. - Result> FilterRowGroups(Expression2::BoundWithState predicate); + Result> FilterRowGroups(Expression2 predicate); ParquetFileFormat& parquet_format_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 5a6bbddaba3..436826c2971 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -161,8 +161,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { } void SetFilter(Expression2 filter) { - ASSERT_OK_AND_ASSIGN(std::tie(opts_->filter2, ctx_->expression_state), - filter.Bind(*schema_)); + ASSERT_OK_AND_ASSIGN(opts_->filter2, filter.Bind(*schema_)); } std::shared_ptr SingleBatch(Fragment* fragment) { @@ -196,7 +195,6 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { Expression2 filter) { schema_ = opts_->schema(); ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); - ctx_->expression_state = bound.second; auto parquet_fragment = checked_pointer_cast(fragment); ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(bound)) diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index b6178ddeea6..4e7cc062997 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -186,9 +186,7 @@ Result KeyValuePartitioning::Parse(const std::string& path) const { 
expressions.push_back(std::move(expr)); } - auto parsed = and_(std::move(expressions)); - ARROW_ASSIGN_OR_RAISE(std::tie(parsed, std::ignore), parsed.Bind(*schema_)) - return parsed; + return and_(std::move(expressions)).Bind(*schema_); } Result KeyValuePartitioning::Format(const Expression2& expr) const { diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 37e65cacac4..57a1693b953 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -45,21 +45,20 @@ class TestPartitioning : public ::testing::Test { void AssertParse(const std::string& path, Expression2 expected) { ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path)); - ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), - expected.Bind(*partitioning_->schema())); + ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*partitioning_->schema())); ASSERT_EQ(parsed, expected); } template void AssertFormatError(Expression2 expr) { - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_)); ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } void AssertFormat(Expression2 expr, const std::string& expected) { // formatted partition expressions are bound to the schema of the dataset being // written - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_)); ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr)); ASSERT_EQ(formatted, expected); @@ -70,7 +69,7 @@ class TestPartitioning : public ::testing::Test { ASSERT_OK_AND_ASSIGN(auto bound, roundtripped.Bind(*partitioning_->schema())); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, expr)); - ASSERT_EQ(simplified.first, literal(true)); + ASSERT_EQ(simplified, literal(true)); } void AssertInspect(const std::vector& paths, @@ -366,9 +365,7 @@ TEST_F(TestPartitioning, 
EtlThenHive) { fs::internal::JoinAbstractPath(etl_segments_end, alphabeta_segments_end); ARROW_ASSIGN_OR_RAISE(auto alphabeta_expr, alphabeta_part.Parse(alphabeta_path)); - auto expr = and_(etl_expr, alphabeta_expr); - ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), expr.Bind(*schm)); - return expr; + return and_(etl_expr, alphabeta_expr).Bind(*schm); }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", @@ -411,8 +408,7 @@ TEST_F(TestPartitioning, Set) { auto is_in_expr = call("is_in", {field_ref(matches[1])}, compute::SetLookupOptions{ints(set), true}); - ARROW_ASSIGN_OR_RAISE(std::tie(is_in_expr, std::ignore), - is_in_expr.Bind(*schm)); + ARROW_ASSIGN_OR_RAISE(is_in_expr, is_in_expr.Bind(*schm)); subexpressions.push_back(is_in_expr); } @@ -457,9 +453,7 @@ class RangePartitioning : public Partitioning { max_cmp(field_ref(key->name), literal(max)))); } - Expression2 expr; - ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), and_(ranges).Bind(*schema_)); - return expr; + return and_(ranges).Bind(*schema_); } static Status DoRegex(const std::string& segment, std::smatch* matches) { diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 67df665733b..e3629bfd8cc 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -72,8 +72,7 @@ Result Scanner::GetFragments() { // Transform Datasets in a flat Iterator. This // iterator is lazily constructed, i.e. Dataset::GetFragments is // not invoked until a Fragment is requested. 
- return GetFragmentsFromDatasets( - {dataset_}, {scan_options_->filter2, scan_context_->expression_state}); + return GetFragmentsFromDatasets({dataset_}, scan_options_->filter2); } Result Scanner::Scan() { @@ -124,8 +123,7 @@ Status ScannerBuilder::Project(std::vector columns) { } Status ScannerBuilder::Filter(const Expression2& filter) { - ARROW_ASSIGN_OR_RAISE(std::tie(scan_options_->filter2, scan_context_->expression_state), - filter.Bind(*schema())); + ARROW_ASSIGN_OR_RAISE(scan_options_->filter2, filter.Bind(*schema())); return Status::OK(); } diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 27c2e3e9ed7..e2f9892af38 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -48,8 +48,6 @@ struct ARROW_DS_EXPORT ScanContext { /// Return a threaded or serial TaskGroup according to use_threads. std::shared_ptr TaskGroup() const; - - std::shared_ptr expression_state; }; class ARROW_DS_EXPORT ScanOptions { diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index f01ced8396b..686be1e6dfb 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -31,15 +31,13 @@ namespace arrow { namespace dataset { -inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, - Expression2::BoundWithState filter, +inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression2 filter, MemoryPool* pool) { return MakeMaybeMapIterator( [=](std::shared_ptr in) -> Result> { compute::ExecContext exec_context{pool}; ARROW_ASSIGN_OR_RAISE(Datum mask, - ExecuteScalarExpression(filter.first, filter.second.get(), - Datum(in), &exec_context)); + ExecuteScalarExpression(filter, Datum(in), &exec_context)); if (mask.is_scalar()) { const auto& mask_scalar = mask.scalar_as(); @@ -83,9 +81,8 @@ class FilterAndProjectScanTask : public ScanTask { Result Execute() override { ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute()); - 
ARROW_ASSIGN_OR_RAISE( - Expression2::BoundWithState simplified_filter, - SimplifyWithGuarantee({filter_, context_->expression_state}, partition_)); + ARROW_ASSIGN_OR_RAISE(Expression2 simplified_filter, + SimplifyWithGuarantee(filter_, partition_)); RecordBatchIterator filter_it = FilterRecordBatch(std::move(it), simplified_filter, context_->pool); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 70c9b7548e8..ea39fef3eec 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -226,8 +226,7 @@ class DatasetFixtureMixin : public ::testing::Test { } void SetFilter(Expression2 filter) { - ASSERT_OK_AND_ASSIGN(std::tie(options_->filter2, ctx_->expression_state), - filter.Bind(*schema_)); + ASSERT_OK_AND_ASSIGN(options_->filter2, filter.Bind(*schema_)); } std::shared_ptr schema_; @@ -381,13 +380,13 @@ struct MakeFileSystemDatasetMixin { continue; } - ASSERT_OK_AND_ASSIGN(std::tie(partitions[i], std::ignore), partitions[i].Bind(*s)); + ASSERT_OK_AND_ASSIGN(partitions[i], partitions[i].Bind(*s)); ASSERT_OK_AND_ASSIGN(auto fragment, format->MakeFragment({info, fs_}, partitions[i])); fragments.push_back(std::move(fragment)); } - ASSERT_OK_AND_ASSIGN(std::tie(root_partition, std::ignore), root_partition.Bind(*s)); + ASSERT_OK_AND_ASSIGN(root_partition, root_partition.Bind(*s)); ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format, fs_, std::move(fragments))); } @@ -438,7 +437,7 @@ void AssertFragmentsHavePartitionExpressions(std::shared_ptr dataset, std::vector expected) { ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); for (auto& expr : expected) { - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*dataset->schema())); } // Ordering is not guaranteed. 
EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(fragment_it))), @@ -591,9 +590,6 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish()); scan_options_ = ScanOptions::Make(source_schema_); - ASSERT_OK_AND_ASSIGN( - std::tie(scan_options_->filter2, scan_context_->expression_state), - literal(true).Bind(*source_schema_)); } void SetWriteOptions(std::shared_ptr file_write_options) { diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index e4bd322ba10..81441a18166 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -69,8 +69,6 @@ class ParquetFileWriter; class ParquetFileWriteOptions; class Expression; - -struct ExpressionState; class Expression2; class Partitioning; From 36f1bbb6ef12276b35be25fad7424750fa852da6 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 10 Dec 2020 12:21:11 -0500 Subject: [PATCH 04/31] repair implicit casts --- cpp/src/arrow/compute/api_scalar.h | 4 +- cpp/src/arrow/compute/cast.h | 24 +- cpp/src/arrow/compute/exec.cc | 3 +- .../compute/kernels/scalar_cast_internal.cc | 31 ++- .../compute/kernels/scalar_cast_numeric.cc | 4 +- .../arrow/compute/kernels/scalar_cast_test.cc | 4 + cpp/src/arrow/dataset/dataset_test.cc | 3 +- cpp/src/arrow/dataset/discovery_test.cc | 2 +- cpp/src/arrow/dataset/expression.cc | 236 +++++++++++++----- cpp/src/arrow/dataset/expression_internal.h | 91 ++++--- cpp/src/arrow/dataset/expression_test.cc | 224 ++++++++++++----- cpp/src/arrow/dataset/filter.cc | 3 +- cpp/src/arrow/dataset/partition.cc | 6 +- cpp/src/arrow/dataset/partition_test.cc | 20 +- cpp/src/arrow/dataset/scanner.cc | 3 + cpp/src/arrow/dataset/scanner_test.cc | 8 +- cpp/src/arrow/dataset/test_util.h | 3 + 17 files changed, 457 insertions(+), 212 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 789ac909ccf..1c6c2ea95b0 100644 --- 
a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -70,7 +70,7 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { - explicit SetLookupOptions(Datum value_set, bool skip_nulls) + explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) : value_set(std::move(value_set)), skip_nulls(skip_nulls) {} /// The set of values to look up input values into. @@ -86,7 +86,7 @@ struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { struct ARROW_EXPORT StrptimeOptions : public FunctionOptions { explicit StrptimeOptions(std::string format, TimeUnit::type unit) - : format(format), unit(unit) {} + : format(std::move(format)), unit(unit) {} std::string format; TimeUnit::type unit; diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 6808d1d86f3..759b7c7665b 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -42,15 +42,7 @@ class ExecContext; /// @{ struct ARROW_EXPORT CastOptions : public FunctionOptions { - CastOptions() - : allow_int_overflow(false), - allow_time_truncate(false), - allow_time_overflow(false), - allow_decimal_truncate(false), - allow_float_truncate(false), - allow_invalid_utf8(false) {} - - explicit CastOptions(bool safe) + explicit CastOptions(bool safe = true) : allow_int_overflow(!safe), allow_time_truncate(!safe), allow_time_overflow(!safe), @@ -58,9 +50,17 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions { allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} - static CastOptions Safe() { return CastOptions(true); } - - static CastOptions Unsafe() { return CastOptions(false); } + static CastOptions Safe(std::shared_ptr to_type = NULLPTR) { + CastOptions safe(true); + safe.to_type = std::move(to_type); + return safe; + } + + static CastOptions Unsafe(std::shared_ptr to_type = NULLPTR) { + CastOptions unsafe(false); + 
unsafe.to_type = std::move(to_type); + return unsafe; + } // Type being casted to. May be passed separate to eager function // compute::Cast diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 790ae9b3997..fbd8229f0c8 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -648,7 +648,8 @@ class ScalarExecutor : public KernelExecutorImpl { // Decide if we need to preallocate memory for this kernel validity_preallocated_ = (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && - kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); + kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL && + output_descr_.type->id() != Type::NA); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index 67f0820402a..7abf1c618da 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -149,6 +149,18 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Dat // ---------------------------------------------------------------------- void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + auto Finish = [&](Result result) { + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + *out = *result; + }; + + if (out->is_scalar()) { + return Finish(batch[0].scalar_as().GetEncodedValue()); + } + DictionaryArray dict_arr(batch[0].array()); const CastOptions& options = checked_cast(*ctx->state()).options; @@ -160,16 +172,15 @@ void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return; } - Result result = Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), - /*options=*/TakeOptions::Defaults(), ctx->exec_context()); - if (!result.ok()) { - 
ctx->SetStatus(result.status()); - return; - } - *out = *result; + return Finish(Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), + /*options=*/TakeOptions::Defaults(), ctx->exec_context())); } void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (out->is_scalar()) { + out->scalar()->is_valid = false; + return; + } ArrayData* output = out->mutable_array(); output->buffers = {nullptr}; output->null_count = batch.length; @@ -261,9 +272,9 @@ void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* fun // // XXX: Uses Take and does its own memory allocation for the moment. We can // fix this later. - DCHECK_OK(func->AddKernel( - Type::DICTIONARY, {InputType::Array(Type::DICTIONARY)}, out_ty, UnpackDictionary, - NullHandling::COMPUTED_NO_PREALLOCATE, MemAllocation::NO_PREALLOCATE)); + DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, out_ty, + UnpackDictionary, NullHandling::COMPUTED_NO_PREALLOCATE, + MemAllocation::NO_PREALLOCATE)); } // From extension type to this type diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index f46cf7e7a75..62665d4ea44 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -630,8 +630,8 @@ std::vector> GetNumericCasts() { // Make a cast to null that does not do much. 
Not sure why we need to be able // to cast from dict -> null but there are unit tests for it auto cast_null = std::make_shared("cast_null", Type::NA); - DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType::Array(Type::DICTIONARY)}, - null(), OutputAllNull)); + DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(), + OutputAllNull)); functions.push_back(cast_null); functions.push_back(GetCastToInteger("cast_int8")); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 566a89f4b47..fda8073beb9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1801,6 +1801,10 @@ TYPED_TEST(TestDictionaryCast, Basic) { // TODO: Should casting dictionary scalars work? this->CheckPass(*dict_arr, *expected, expected->type(), CastOptions::Safe(), /*check_scalar=*/false); + + auto opts = CastOptions::Safe(); + opts.to_type = expected->type(); + CheckScalarUnary("cast", dict_arr, expected, &opts); } } diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index 58a4a4a0b3f..cf1e85ea055 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -786,8 +786,7 @@ TEST(TestDictPartitionColumn, SelectPartitionColumnFilterPhysicalColumn) { // Selects re-ordered virtual column with a filter on a physical column ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_1"_ == 111)); - + ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111)))); ASSERT_OK(scan_builder->Project({"part"})); ASSERT_OK_AND_ASSIGN(auto scanner, scan_builder->Finish()); diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index 219ee070fdc..93d7c3869dd 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -369,7 +369,7 @@ 
TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) { // column. In such case, the filename should not be used. MakeFactory({fs::File("one/file.parquet")}); - ASSERT_OK_AND_ASSIGN(auto expected, equal(field_ref("first"), literal("one")).Bind(*s)); + auto expected = equal(field_ref("first"), literal("one")); ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish()); ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 8f9a3ce5663..21c067e3d64 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -213,7 +213,47 @@ std::string Expression2::ToString() const { for (const auto& arg : call->arguments) { out += arg.ToString() + ","; } - out.back() = ')'; + + if (!call->options) { + out.back() = ')'; + return out; + } + + if (call->options) { + out += " {"; + + if (auto options = GetSetLookupOptions(*call)) { + DCHECK_EQ(options->value_set.kind(), Datum::ARRAY); + out += "value_set:" + options->value_set.make_array()->ToString(); + if (options->skip_nulls) { + out += ",skip_nulls"; + } + } + + if (auto options = GetCastOptions(*call)) { + out += "to_type:" + options->to_type->ToString(); + if (options->allow_int_overflow) out += ",allow_int_overflow"; + if (options->allow_time_truncate) out += ",allow_time_truncate"; + if (options->allow_time_overflow) out += ",allow_time_overflow"; + if (options->allow_decimal_truncate) out += ",allow_decimal_truncate"; + if (options->allow_float_truncate) out += ",allow_float_truncate"; + if (options->allow_invalid_utf8) out += ",allow_invalid_utf8"; + } + + if (auto options = GetStructOptions(*call)) { + for (const auto& field_name : options->field_names) { + out += field_name + ","; + } + out.resize(out.size() - 1); + } + + if (auto options = GetStrptimeOptions(*call)) { + out += "format:" + options->format; + out += ",unit:" + internal::ToString(options->unit); + } + + out += "})"; 
+ } return out; } @@ -246,14 +286,49 @@ bool Expression2::Equals(const Expression2& other) const { return false; } - // FIXME compare FunctionOptions for equality for (size_t i = 0; i < call->arguments.size(); ++i) { if (!call->arguments[i].Equals(other_call->arguments[i])) { return false; } } - return true; + if (call->options == other_call->options) return true; + + if (auto options = GetSetLookupOptions(*call)) { + auto other_options = GetSetLookupOptions(*other_call); + return options->value_set == other_options->value_set && + options->skip_nulls == other_options->skip_nulls; + } + + if (auto options = GetCastOptions(*call)) { + auto other_options = GetCastOptions(*other_call); + for (auto safety_opt : { + &compute::CastOptions::allow_int_overflow, + &compute::CastOptions::allow_time_truncate, + &compute::CastOptions::allow_time_overflow, + &compute::CastOptions::allow_decimal_truncate, + &compute::CastOptions::allow_float_truncate, + &compute::CastOptions::allow_invalid_utf8, + }) { + if (options->*safety_opt != other_options->*safety_opt) return false; + } + return options->to_type->Equals(other_options->to_type); + } + + if (auto options = GetStructOptions(*call)) { + auto other_options = GetStructOptions(*other_call); + return options->field_names == other_options->field_names; + } + + if (auto options = GetStrptimeOptions(*call)) { + auto other_options = GetStrptimeOptions(*other_call); + return options->format == other_options->format && + options->unit == other_options->unit; + } + + ARROW_LOG(WARNING) << "comparing unknown FunctionOptions for function " + << call->function; + return false; } size_t Expression2::hash() const { @@ -378,6 +453,81 @@ Result> InitKernelState( return std::move(kernel_state); } +Status MaybeInsertCast(std::shared_ptr to_type, Expression2* expr) { + if (expr->descr().type->Equals(to_type)) { + return Status::OK(); + } + + if (auto lit = expr->literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, to_type)); + 
*expr = literal(std::move(new_lit)); + return Status::OK(); + } + + // FIXME the resulting cast Call must be bound but this is a hack + auto with_cast = call("cast", {literal(MakeNullScalar(expr->descr().type))}, + compute::CastOptions::Safe(to_type)); + + static ValueDescr ignored_descr; + ARROW_ASSIGN_OR_RAISE(with_cast, with_cast.Bind(ignored_descr)); + + auto call_with_cast = *CallNotNull(with_cast); + call_with_cast.arguments[0] = std::move(*expr); + *expr = Expression2(std::make_shared(std::move(call_with_cast)), + ValueDescr{std::move(to_type), expr->descr().shape}); + + return Status::OK(); +} + +Status InsertImplicitCasts(Expression2::Call* call) { + DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression2& argument) { return argument.IsBound(); })); + + if (Comparison::Get(call->function)) { + for (auto&& argument : call->arguments) { + if (auto value_type = GetDictionaryValueType(argument.descr().type)) { + RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); + } + } + + if (call->arguments[0].descr().shape == ValueDescr::SCALAR) { + // argument 0 is scalar so casting is cheap + return MaybeInsertCast(call->arguments[1].descr().type, &call->arguments[0]); + } else { + return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); + } + } + + if (auto options = GetSetLookupOptions(*call)) { + if (auto value_type = GetDictionaryValueType(call->arguments[0].descr().type)) { + // DICTIONARY input is not supported; decode it. + RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &call->arguments[0])); + } + + if (options->value_set.type()->id() == Type::DICTIONARY) { + // DICTIONARY value_set is not supported; decode it. 
+ auto new_options = std::make_shared(*options); + RETURN_NOT_OK(EnsureNotDictionary(&new_options->value_set)); + options = new_options.get(); + call->options = std::move(new_options); + } + + if (!options->value_set.type()->Equals(call->arguments[0].descr().type)) { + // The value_set is assumed smaller than inputs, casting it should be cheaper. + auto new_options = std::make_shared(*options); + ARROW_ASSIGN_OR_RAISE(new_options->value_set, + compute::Cast(std::move(new_options->value_set), + call->arguments[0].descr().type)); + options = new_options.get(); + call->options = std::move(new_options); + } + + return Status::OK(); + } + + return Status::OK(); +} + Result Expression2::Bind(ValueDescr in, compute::ExecContext* exec_context) const { if (exec_context == nullptr) { @@ -402,18 +552,9 @@ Result Expression2::Bind(ValueDescr in, for (auto&& argument : bound_call.arguments) { ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); } - - if (RequriesDictionaryTransparency(bound_call)) { - RETURN_NOT_OK(EnsureNotDictionary(&bound_call)); - } + RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); auto descrs = GetDescriptors(bound_call.arguments); - for (auto&& descr : descrs) { - if (RequriesDictionaryTransparency(bound_call)) { - RETURN_NOT_OK(EnsureNotDictionary(&descr)); - } - } - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); @@ -470,10 +611,6 @@ Result ExecuteScalarExpression(const Expression2& expr, const Datum& inpu for (size_t i = 0; i < arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE( arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context)); - - if (RequriesDictionaryTransparency(*call)) { - RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); - } } auto executor = compute::detail::KernelExecutor::MakeScalar(); @@ -558,38 +695,6 @@ std::vector FieldsInExpression(const Expression2& expr) { return fields; } -template -Result Modify(Expression2 expr, 
const PreVisit& pre, - const PostVisitCall& post_call) { - ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); - - auto call = expr.call(); - if (!call) return expr; - - bool at_least_one_modified = false; - auto modified_call = *call; - auto modified_argument = modified_call.arguments.begin(); - - for (const auto& argument : call->arguments) { - ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); - - if (!Identical(*modified_argument, argument)) { - at_least_one_modified = true; - } - ++modified_argument; - } - - if (at_least_one_modified) { - // reconstruct the call expression with the modified arguments - auto modified_expr = Expression2( - std::make_shared(std::move(modified_call)), expr.descr()); - - return post_call(std::move(modified_expr), &expr); - } - - return post_call(std::move(expr), nullptr); -} - Result FoldConstants(Expression2 expr) { return Modify( std::move(expr), [](Expression2 expr) { return expr; }, @@ -625,6 +730,9 @@ Result FoldConstants(Expression2 expr) { // false and x == false if (args.first == literal(false)) return args.first; + + // x and x == x + if (args.first == args.second) return args.first; } return expr; } @@ -636,6 +744,9 @@ Result FoldConstants(Expression2 expr) { // true or x == true if (args.first == literal(true)) return args.first; + + // x or x == x + if (args.first == args.second) return args.first; } return expr; } @@ -658,10 +769,6 @@ inline std::vector GuaranteeConjunctionMembers( Status ExtractKnownFieldValuesImpl( std::vector* conjunction_members, std::unordered_map* known_values) { - for (auto&& member : *conjunction_members) { - ARROW_ASSIGN_OR_RAISE(member, Canonicalize({std::move(member)})); - } - auto unconsumed_end = std::partition(conjunction_members->begin(), conjunction_members->end(), [](const Expression2& expr) { @@ -703,9 +810,6 @@ Status ExtractKnownFieldValuesImpl( Result> ExtractKnownFieldValues( const Expression2& guaranteed_true_predicate) { - if 
(!guaranteed_true_predicate.IsBound()) { - return Status::Invalid("guaranteed_true_predicate was not bound"); - } auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); @@ -715,13 +819,20 @@ Result> ExtractKnownFieldVal Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, Expression2 expr) { + if (!expr.IsBound()) { + return Status::Invalid( + "ReplaceFieldsWithKnownValues called on an unbound Expression2"); + } + return Modify( std::move(expr), - [&known_values](Expression2 expr) { + [&known_values](Expression2 expr) -> Result { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { - return literal(it->second); + ARROW_ASSIGN_OR_RAISE(Datum lit, + compute::Cast(it->second, expr.descr().type)); + return literal(std::move(lit)); } } return expr; @@ -817,11 +928,6 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co exec_context->func_registry()->GetFunction(flipped_call.function)); auto descrs = GetDescriptors(flipped_call.arguments); - for (auto& descr : descrs) { - if (RequriesDictionaryTransparency(flipped_call)) { - RETURN_NOT_OK(EnsureNotDictionary(&descr)); - } - } ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); @@ -896,10 +1002,6 @@ Result DirectComparisonSimplification(Expression2 expr, Result SimplifyWithGuarantee(Expression2 expr, const Expression2& guaranteed_true_predicate) { - if (!guaranteed_true_predicate.IsBound()) { - return Status::Invalid("guaranteed_true_predicate was not bound"); - } - auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 
84b802937e0..0324eaedc77 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -140,10 +140,16 @@ struct Comparison { return nullptr; } + // Execute a simple Comparison between scalars, casting the RHS if types disagree static Result Execute(Datum l, Datum r) { if (!l.is_scalar() || !r.is_scalar()) { return Status::Invalid("Cannot Execute Comparison on non-scalars"); } + + if (!l.type()->Equals(r.type())) { + ARROW_ASSIGN_OR_RAISE(r, compute::Cast(r, l.type())); + } + std::vector arguments{std::move(l), std::move(r)}; ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); @@ -201,6 +207,11 @@ struct Comparison { } }; +inline const compute::CastOptions* GetCastOptions(const Expression2::Call& call) { + if (call.function != "cast") return nullptr; + return checked_cast(call.options.get()); +} + inline bool IsSetLookup(const std::string& function) { return function == "is_in" || function == "index_in"; } @@ -211,45 +222,37 @@ inline const compute::SetLookupOptions* GetSetLookupOptions( return checked_cast(call.options.get()); } -inline bool RequriesDictionaryTransparency(const Expression2::Call& call) { - // TODO move this functionality into compute:: - - // Functions which don't provide kernels for dictionary types. Dictionaries will be - // decoded for these functions. 
- if (Comparison::Get(call.function)) return true; +inline const compute::StructOptions* GetStructOptions(const Expression2::Call& call) { + if (call.function != "struct") return nullptr; + return checked_cast(call.options.get()); +} - if (IsSetLookup(call.function)) return true; +inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression2::Call& call) { + if (call.function != "strptime") return nullptr; + return checked_cast(call.options.get()); +} - return false; +inline const std::shared_ptr& GetDictionaryValueType( + const std::shared_ptr& type) { + if (type && type->id() == Type::DICTIONARY) { + return checked_cast(*type).value_type(); + } + static std::shared_ptr null; + return null; } inline Status EnsureNotDictionary(ValueDescr* descr) { - const auto& type = descr->type; - if (type && type->id() == Type::DICTIONARY) { - descr->type = checked_cast(*type).value_type(); + if (auto value_type = GetDictionaryValueType(descr->type)) { + descr->type = std::move(value_type); } return Status::OK(); } inline Status EnsureNotDictionary(Datum* datum) { - if (datum->type()->id() != Type::DICTIONARY) { - return Status::OK(); - } - - if (datum->is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - *datum, - checked_cast(*datum->scalar()).GetEncodedValue()); - return Status::OK(); + if (datum->type()->id() == Type::DICTIONARY) { + const auto& type = checked_cast(*datum->type()).value_type(); + ARROW_ASSIGN_OR_RAISE(*datum, compute::Cast(*datum, type)); } - - DCHECK_EQ(datum->kind(), Datum::ARRAY); - ArrayData indices = *datum->array(); - indices.type = checked_cast(*datum->type()).index_type(); - auto values = std::move(indices.dictionary); - - ARROW_ASSIGN_OR_RAISE( - *datum, compute::Take(values, indices, compute::TakeOptions::NoBoundsCheck())); return Status::OK(); } @@ -397,5 +400,37 @@ inline Result> GetFunction( return compute::GetCastFunction(to_type); } +template +Result Modify(Expression2 expr, const PreVisit& pre, + const PostVisitCall& post_call) { + 
ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); + + auto call = expr.call(); + if (!call) return expr; + + bool at_least_one_modified = false; + auto modified_call = *call; + auto modified_argument = modified_call.arguments.begin(); + + for (const auto& argument : call->arguments) { + ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); + + if (!Identical(*modified_argument, argument)) { + at_least_one_modified = true; + } + ++modified_argument; + } + + if (at_least_one_modified) { + // reconstruct the call expression with the modified arguments + auto modified_expr = Expression2( + std::make_shared(std::move(modified_call)), expr.descr()); + + return post_call(std::move(modified_expr), &expr); + } + + return post_call(std::move(expr), nullptr); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 8ea70adf01c..44dd322aa63 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -42,6 +42,11 @@ namespace dataset { #define EXPECT_OK ARROW_EXPECT_OK +Expression2 cast(Expression2 argument, std::shared_ptr to_type) { + return call("cast", {std::move(argument)}, + compute::CastOptions::Safe(std::move(to_type))); +} + TEST(Expression2, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "FieldRef(alpha)"); @@ -50,23 +55,59 @@ TEST(Expression2, ToString) { EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), "add(3,FieldRef(beta))"); - EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}).ToString(), - "add(3,index_in(FieldRef(beta)))"); + auto in_12 = call("index_in", {field_ref("beta")}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")}); + + EXPECT_EQ(in_12.ToString(), "index_in(FieldRef(beta), {value_set:[\n 1,\n 2\n]})"); + + EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), + "cast(FieldRef(a), {to_type:int32})"); + + EXPECT_EQ(project( + { 
+ field_ref("a"), + field_ref("a"), + literal(3), + in_12, + }, + { + "a", + "renamed_a", + "three", + "b", + }) + .ToString(), + "struct(FieldRef(a),FieldRef(a),3," + in_12.ToString() + + ", {a,renamed_a,three,b})"); } TEST(Expression2, Equality) { EXPECT_EQ(literal(1), literal(1)); + EXPECT_NE(literal(1), literal(2)); EXPECT_EQ(field_ref("a"), field_ref("a")); + EXPECT_NE(field_ref("a"), field_ref("b")); + EXPECT_NE(field_ref("a"), literal(2)); EXPECT_EQ(call("add", {literal(3), field_ref("a")}), call("add", {literal(3), field_ref("a")})); + EXPECT_NE(call("add", {literal(3), field_ref("a")}), + call("add", {literal(2), field_ref("a")})); + EXPECT_NE(call("add", {field_ref("a"), literal(3)}), + call("add", {literal(3), field_ref("a")})); - EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}), - call("add", {literal(3), call("index_in", {field_ref("beta")})})); + auto in_123 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]")}; + EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}), + call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)})); - // FIXME options are not currently compared, so two cast exprs will compare equal - // regardless of differing target type + auto in_12 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")}; + EXPECT_NE(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_12)}), + call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)})); + + EXPECT_EQ(cast(field_ref("a"), int32()), cast(field_ref("a"), int32())); + EXPECT_NE(cast(field_ref("a"), int32()), cast(field_ref("a"), int64())); + EXPECT_NE(cast(field_ref("a"), int32()), + call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32()))); } TEST(Expression2, Hash) { @@ -231,13 +272,53 @@ TEST(Expression2, BindCall) { expr.Bind(Schema({field("a", int32()), field("b", int32())}))); } -TEST(Expression2, BindDictionaryTransparent) { - auto expr = call("equal", 
{field_ref("a"), field_ref("b")}); - EXPECT_FALSE(expr.IsBound()); +TEST(Expression2, BindWithImplicitCasts) { + for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { + // cast arguments to same type + ASSERT_OK_AND_ASSIGN(auto expr, + cmp(field_ref("i32"), field_ref("i64")).Bind(*kBoringSchema)); + + // NB: RHS is cast unless LHS is scalar. + ASSERT_OK_AND_ASSIGN( + auto expected, + cmp(field_ref("i32"), cast(field_ref("i64"), int32())).Bind(*kBoringSchema)); + + EXPECT_EQ(expr, expected); + + // cast dictionary to value type + ASSERT_OK_AND_ASSIGN( + expr, cmp(field_ref("dict_str"), field_ref("str")).Bind(*kBoringSchema)); + + ASSERT_OK_AND_ASSIGN( + expected, + cmp(cast(field_ref("dict_str"), utf8()), field_ref("str")).Bind(*kBoringSchema)); + + EXPECT_EQ(expr, expected); + } + // cast value_set to argument type + auto Opts = [](std::shared_ptr type) { + return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")}; + }; + ASSERT_OK_AND_ASSIGN( + auto expr, call("is_in", {field_ref("str")}, Opts(binary())).Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN( + auto expected, + call("is_in", {field_ref("str")}, Opts(utf8())).Bind(*kBoringSchema)); + + EXPECT_EQ(expr, expected); + + // dictionary decode set then cast to argument type + ASSERT_OK_AND_ASSIGN( + expr, call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))) + .Bind(*kBoringSchema)); + + EXPECT_EQ(expr, expected); +} + +TEST(Expression2, BindDictionaryTransparent) { ASSERT_OK_AND_ASSIGN( - expr, - expr.Bind(Schema({field("a", utf8()), field("b", dictionary(int32(), utf8()))}))); + auto expr, equal(field_ref("str"), field_ref("dict_str")).Bind(*kBoringSchema)); EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean())); EXPECT_TRUE(expr.IsBound()); @@ -303,46 +384,34 @@ Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& for (size_t i = 0; i < arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(arguments[i], 
NaiveExecuteScalarExpression(call->arguments[i], input)); - - if (RequriesDictionaryTransparency(*call)) { - RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); - } } - ARROW_ASSIGN_OR_RAISE(auto function, - compute::GetFunctionRegistry()->GetFunction(call->function)); + compute::ExecContext exec_context; + ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(*call, &exec_context)); auto descrs = GetDescriptors(call->arguments); - for (size_t i = 0; i < arguments.size(); ++i) { - if (RequriesDictionaryTransparency(*call)) { - RETURN_NOT_OK(EnsureNotDictionary(&descrs[i])); - } - EXPECT_EQ(arguments[i].descr(), descrs[i]); - } - ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs)); EXPECT_EQ(call->kernel, expected_kernel); - - auto options = call->options; - if (RequriesDictionaryTransparency(*call)) { - auto non_dict_call = *call; - RETURN_NOT_OK(EnsureNotDictionary(&non_dict_call)); - options = non_dict_call.options; - } - - compute::ExecContext exec_context; return function->Execute(arguments, call->options.get(), &exec_context); } -void AssertExecute(Expression2 expr, Datum in) { - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); +void AssertExecute(Expression2 expr, Datum in, Datum* actual_out = NULLPTR) { + if (in.is_value()) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + } else { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.record_batch()->schema())); + } ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in)); AssertDatumsEqual(actual, expected, /*verbose=*/true); + + if (actual_out) { + *actual_out = actual; + } } TEST(Expression2, ExecuteCall) { @@ -481,16 +550,17 @@ TEST(Expression2, FoldConstantsBoolean) { ExpectFoldsTo(and_(false_, whatever), false_); ExpectFoldsTo(and_(true_, whatever), whatever); + ExpectFoldsTo(and_(whatever, whatever), whatever); ExpectFoldsTo(or_(true_, whatever), true_); ExpectFoldsTo(or_(false_, whatever), 
whatever); + ExpectFoldsTo(or_(whatever, whatever), whatever); } TEST(Expression2, ExtractKnownFieldValues) { struct { void operator()(Expression2 guarantee, std::unordered_map expected) { - ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) << " guarantee: " << guarantee.ToString(); @@ -506,22 +576,34 @@ TEST(Expression2, ExtractKnownFieldValues) { ExpectKnown(equal(field_ref("i32"), literal(null_int32)), {{"i32", Datum(null_int32)}}); ExpectKnown( - and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}), + and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(1.5F))}), {{"i32", Datum(3)}, {"f32", Datum(1.5F)}}); + // NB: guarantees are *not* automatically canonicalized ExpectKnown( - and_({equal(field_ref("i32"), literal(3)), equal(literal(2.F), field_ref("f32")), - equal(literal(1), field_ref("i32_req"))}), + and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}), + {{"i32", Datum(3)}}); + + // NB: guarantees are *not* automatically simplified + // (the below could be constant folded to a usable guarantee) + ExpectKnown(or_({equal(field_ref("i32"), literal(3)), literal(false)}), {}); + + // NB: guarantees are unbound; applying them may require casts + ExpectKnown(equal(field_ref("i32"), literal("1234324")), {{"i32", Datum("1234324")}}); + + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(2.F)), + equal(field_ref("i32_req"), literal(1))}), {{"i32", Datum(3)}, {"f32", Datum(2.F)}, {"i32_req", Datum(1)}}); ExpectKnown( and_(or_(equal(field_ref("i32"), literal(3)), equal(field_ref("i32"), literal(4))), - equal(literal(2.F), field_ref("f32"))), + equal(field_ref("f32"), literal(2.F))), {{"f32", Datum(2.F)}}); ExpectKnown(and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), field_ref("f32_req")), 
- equal(literal(1), field_ref("i32_req"))}), + equal(field_ref("i32_req"), literal(1))}), {{"i32", Datum(3)}, {"i32_req", Datum(1)}}); } @@ -530,10 +612,10 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { [](Expression2 expr, std::unordered_map known_values, Expression2 unbound_expected) { - ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto replaced, - ReplaceFieldsWithKnownValues(known_values, bound)); + ReplaceFieldsWithKnownValues(known_values, expr)); EXPECT_EQ(replaced, expected); @@ -549,6 +631,9 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { ExpectReplacesTo(field_ref("i32"), i32_is_3, literal(3)); + // NB: known_values will be cast + ExpectReplacesTo(field_ref("i32"), {{"i32", Datum("3")}}, literal(3)); + ExpectReplacesTo(field_ref("b"), i32_is_3, field_ref("b")); ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3, @@ -648,7 +733,6 @@ struct Simplify { void Expect(Expression2 unbound_expected) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); @@ -806,30 +890,47 @@ TEST(Expression2, SimplifyWithGuarantee) { .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), less_equal(field_ref("i32"), literal(1)))) .Expect(equal(field_ref("i32"), literal(0))); + Simplify{ + or_(equal(field_ref("f32"), literal("0")), equal(field_ref("i32"), literal(3)))} + .WithGuarantee(greater(field_ref("f32"), literal(0.0))) + .Expect(equal(field_ref("i32"), literal(3))); + + // simplification can see through implicit casts + Simplify{or_({equal(field_ref("f32"), literal("0")), + call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ + ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})})} + 
.WithGuarantee(greater(field_ref("f32"), literal(0.0))) + .Expect(call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); } TEST(Expression2, SimplifyThenExecute) { auto filter = - and_({equal(field_ref("f32"), literal(0.F)), - call("is_in", {field_ref("i32")}, - compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})}); + or_({equal(field_ref("f32"), literal("0")), + call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ + ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})}); - ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto guarantee, - equal(field_ref("f32"), literal(0.F)).Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + auto guarantee = greater(field_ref("f32"), literal(0.0)); - ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee)); auto input = RecordBatchFromJSON(kBoringSchema, R"([ - {"i32": 0, "f32": -0.1}, - {"i32": 0, "f32": 0.3}, - {"i32": 1, "f32": 0.2}, - {"i32": 2, "f32": -0.1}, - {"i32": 0, "f32": 0.1}, - {"i32": 0, "f32": null}, - {"i32": 0, "f32": 1.0} + {"i64": 0, "f32": 0.1}, + {"i64": 0, "f32": 0.3}, + {"i64": 1, "f32": 0.5}, + {"i64": 2, "f32": 0.1}, + {"i64": 0, "f32": 0.1}, + {"i64": 0, "f32": 0.4}, + {"i64": 0, "f32": 1.0} ])"); - ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(bound, input)); + + Datum evaluated, simplified_evaluated; + AssertExecute(filter, input, &evaluated); + AssertExecute(simplified, input, &simplified_evaluated); + AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); } TEST(Expression2, Filter) { @@ -889,10 +990,9 @@ TEST(Expression2, SerializationRoundTrips) { ExpectRoundTrips(not_(field_ref("alpha"))); - compute::CastOptions to_float64; - to_float64.to_type = float64(); ExpectRoundTrips( - call("is_in", {call("cast", 
{field_ref("version")}, to_float64)}, + call("is_in", + {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))}, compute::SetLookupOptions{ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]"), true})); ExpectRoundTrips(call("is_valid", {field_ref("validity")})); diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index e231737bc5f..71253cfde89 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -983,7 +983,8 @@ struct InsertImplicitCastsImpl { if (!op.type->Equals(set->type())) { // cast the set (which we assume to be small) to match op.type - ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode(*set, op.type, {})); + ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode( + *set, op.type, compute::CastOptions{})); set = encoded_set.make_array(); } diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 4e7cc062997..16a49b988a9 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -186,14 +186,10 @@ Result KeyValuePartitioning::Parse(const std::string& path) const { expressions.push_back(std::move(expr)); } - return and_(std::move(expressions)).Bind(*schema_); + return and_(std::move(expressions)); } Result KeyValuePartitioning::Format(const Expression2& expr) const { - if (!expr.IsBound()) { - return Status::Invalid("formatted partition expressions must be bound"); - } - std::vector values{static_cast(schema_->num_fields()), nullptr}; ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 57a1693b953..7bf7efa3761 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -45,21 +45,17 @@ class TestPartitioning : public ::testing::Test { void AssertParse(const std::string& path, Expression2 expected) { ASSERT_OK_AND_ASSIGN(auto parsed, 
partitioning_->Parse(path)); - ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*partitioning_->schema())); ASSERT_EQ(parsed, expected); } template void AssertFormatError(Expression2 expr) { - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_)); ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } void AssertFormat(Expression2 expr, const std::string& expected) { // formatted partition expressions are bound to the schema of the dataset being // written - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_)); - ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr)); ASSERT_EQ(formatted, expected); @@ -67,8 +63,8 @@ class TestPartitioning : public ::testing::Test { // expression: roundtripped should be a subset of expr ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, partitioning_->Parse(formatted)); - ASSERT_OK_AND_ASSIGN(auto bound, roundtripped.Bind(*partitioning_->schema())); - ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, expr)); + ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*written_schema_)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(roundtripped, expr)); ASSERT_EQ(simplified, literal(true)); } @@ -365,7 +361,7 @@ TEST_F(TestPartitioning, EtlThenHive) { fs::internal::JoinAbstractPath(etl_segments_end, alphabeta_segments_end); ARROW_ASSIGN_OR_RAISE(auto alphabeta_expr, alphabeta_part.Parse(alphabeta_path)); - return and_(etl_expr, alphabeta_expr).Bind(*schm); + return and_(etl_expr, alphabeta_expr); }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", @@ -405,12 +401,8 @@ TEST_F(TestPartitioning, Set) { set.push_back(checked_cast(*s).value); } - auto is_in_expr = call("is_in", {field_ref(matches[1])}, - compute::SetLookupOptions{ints(set), true}); - - ARROW_ASSIGN_OR_RAISE(is_in_expr, is_in_expr.Bind(*schm)); - - subexpressions.push_back(is_in_expr); + subexpressions.push_back(call("is_in", {field_ref(matches[1])}, + compute::SetLookupOptions{ints(set), true})); } return 
and_(std::move(subexpressions)); }); @@ -453,7 +445,7 @@ class RangePartitioning : public Partitioning { max_cmp(field_ref(key->name), literal(max)))); } - return and_(ranges).Bind(*schema_); + return and_(ranges); } static Status DoRegex(const std::string& segment, std::smatch* matches) { diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index e3629bfd8cc..043f6695222 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -123,6 +123,9 @@ Status ScannerBuilder::Project(std::vector columns) { } Status ScannerBuilder::Filter(const Expression2& filter) { + for (const auto& ref : FieldsInExpression(filter)) { + RETURN_NOT_OK(ref.FindOne(*schema())); + } ARROW_ASSIGN_OR_RAISE(scan_options_->filter2, filter.Bind(*schema())); return Status::OK(); } diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index d79cc49faac..7261bdb8037 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -191,15 +191,13 @@ TEST_F(TestScannerBuilder, TestFilter) { call("equal", {field_ref("b"), literal(true)}), }))); - ASSERT_RAISES(NotImplemented, - builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); + ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); ASSERT_RAISES( - NotImplemented, - builder.Filter(call("equal", {field_ref("not_a_column"), literal(10)}))); + Invalid, builder.Filter(call("equal", {field_ref("not_a_column"), literal(true)}))); ASSERT_RAISES( - NotImplemented, + Invalid, builder.Filter( call("or_kleene", { call("equal", {field_ref("i64"), literal(10)}), diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index ea39fef3eec..44d211260b6 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -53,9 +53,12 @@ namespace dataset { const std::shared_ptr kBoringSchema = schema({ field("i32", int32()), field("i32_req", int32(), 
/*nullable=*/false), + field("i64", int64()), field("f32", float32()), field("f32_req", float32(), /*nullable=*/false), field("bool", boolean()), + field("str", utf8()), + field("dict_str", dictionary(int32(), utf8())), }); inline namespace string_literals { From b954ef8118e330baaa2fd5a571a89933693b563d Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 10 Dec 2020 13:47:40 -0500 Subject: [PATCH 05/31] delete Expression DSL operators and old Expression class --- .../arrow/dataset-parquet-scan-example.cc | 2 +- cpp/src/arrow/dataset/CMakeLists.txt | 2 - cpp/src/arrow/dataset/api.h | 2 +- cpp/src/arrow/dataset/dataset.cc | 1 - cpp/src/arrow/dataset/dataset_test.cc | 10 +- cpp/src/arrow/dataset/discovery_test.cc | 1 - cpp/src/arrow/dataset/expression.cc | 147 +- cpp/src/arrow/dataset/expression.h | 11 - cpp/src/arrow/dataset/expression_test.cc | 115 +- cpp/src/arrow/dataset/file_base.cc | 1 - cpp/src/arrow/dataset/file_csv.cc | 1 - cpp/src/arrow/dataset/file_csv_test.cc | 3 +- cpp/src/arrow/dataset/file_ipc_test.cc | 1 - cpp/src/arrow/dataset/file_parquet.cc | 1 - cpp/src/arrow/dataset/file_parquet_test.cc | 58 +- cpp/src/arrow/dataset/file_test.cc | 32 +- cpp/src/arrow/dataset/filter.cc | 1724 +---------------- cpp/src/arrow/dataset/filter.h | 600 +----- cpp/src/arrow/dataset/filter_test.cc | 485 +---- cpp/src/arrow/dataset/partition.cc | 192 +- cpp/src/arrow/dataset/partition.h | 17 + cpp/src/arrow/dataset/partition_test.cc | 159 +- cpp/src/arrow/dataset/scanner.cc | 1 - cpp/src/arrow/dataset/scanner_internal.h | 1 - cpp/src/arrow/dataset/test_util.h | 31 +- cpp/src/arrow/dataset/type_fwd.h | 1 - 26 files changed, 440 insertions(+), 3159 deletions(-) diff --git a/cpp/examples/arrow/dataset-parquet-scan-example.cc b/cpp/examples/arrow/dataset-parquet-scan-example.cc index 5b933d3ca62..46778cebaa5 100644 --- a/cpp/examples/arrow/dataset-parquet-scan-example.cc +++ b/cpp/examples/arrow/dataset-parquet-scan-example.cc @@ -18,9 +18,9 @@ #include #include 
#include +#include #include #include -#include #include #include #include diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt index ae6bb3656dc..aa9cde2dbe9 100644 --- a/cpp/src/arrow/dataset/CMakeLists.txt +++ b/cpp/src/arrow/dataset/CMakeLists.txt @@ -25,7 +25,6 @@ set(ARROW_DATASET_SRCS expression.cc file_base.cc file_ipc.cc - filter.cc partition.cc projector.cc scanner.cc) @@ -110,7 +109,6 @@ add_arrow_dataset_test(discovery_test) add_arrow_dataset_test(expression_test) add_arrow_dataset_test(file_ipc_test) add_arrow_dataset_test(file_test) -add_arrow_dataset_test(filter_test) add_arrow_dataset_test(partition_test) add_arrow_dataset_test(scanner_test) diff --git a/cpp/src/arrow/dataset/api.h b/cpp/src/arrow/dataset/api.h index c4c90088e8c..da9f5ed371e 100644 --- a/cpp/src/arrow/dataset/api.h +++ b/cpp/src/arrow/dataset/api.h @@ -21,9 +21,9 @@ #include "arrow/dataset/dataset.h" #include "arrow/dataset/discovery.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_csv.h" #include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 5da90e48ebc..bedecf1e3a5 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -21,7 +21,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/table.h" #include "arrow/util/bit_util.h" diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index cf1e85ea055..82a5a63c2c2 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -702,8 +702,10 @@ TEST_F(TestSchemaUnification, SelectPhysicalColumnsFilterPartitionColumn) { // when some of the columns may not be materialized 
ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); ASSERT_OK(scan_builder->Project({"phy_2", "phy_3", "phy_4"})); - ASSERT_OK(scan_builder->Filter(("part_df"_ == 1 && "phy_2"_ == 211) || - ("part_ds"_ == 2 && "phy_4"_ != 422))); + ASSERT_OK(scan_builder->Filter(or_(and_(equal(field_ref("part_df"), literal(1)), + equal(field_ref("phy_2"), literal(211))), + and_(equal(field_ref("part_ds"), literal(2)), + not_equal(field_ref("phy_4"), literal(422)))))); using TupleType = std::tuple; std::vector rows = { @@ -734,7 +736,7 @@ TEST_F(TestSchemaUnification, SelectPartitionColumns) { TEST_F(TestSchemaUnification, SelectPartitionColumnsFilterPhysicalColumn) { // Selects re-ordered virtual columns with a filter on a physical columns ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_1"_ == 111)); + ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111)))); ASSERT_OK(scan_builder->Project({"part_df", "part_ds"})); using TupleType = std::tuple; @@ -748,7 +750,7 @@ TEST_F(TestSchemaUnification, SelectMixedColumnsAndFilter) { // Selects mix of physical/virtual with a different order and uses a filter on // a physical column not selected. 
ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_2"_ >= 212)); + ASSERT_OK(scan_builder->Filter(greater_equal(field_ref("phy_2"), literal(212)))); ASSERT_OK(scan_builder->Project({"part_df", "phy_3", "part_ds", "phy_1"})); using TupleType = std::tuple; diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index 93d7c3869dd..a951e827fa4 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -23,7 +23,6 @@ #include #include -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/filesystem/test_util.h" diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 21c067e3d64..f1b17752bf2 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -23,7 +23,6 @@ #include "arrow/chunked_array.h" #include "arrow/compute/exec_internal.h" #include "arrow/dataset/expression_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" @@ -50,144 +49,6 @@ const FieldRef* Expression2::field_ref() const { return util::get_if(impl_.get()); } -Expression2::operator std::shared_ptr() const { - if (auto lit = literal()) { - DCHECK(lit->is_scalar()); - return std::make_shared(lit->scalar()); - } - - if (auto ref = field_ref()) { - DCHECK(ref->name()); - return std::make_shared(*ref->name()); - } - - auto call = CallNotNull(*this); - if (call->function == "invert") { - return std::make_shared(call->arguments[0]); - } - - if (call->function == "cast") { - const auto& options = checked_cast(*call->options); - return std::make_shared(call->arguments[0], options.to_type, options); - } - - if (call->function == "and_kleene") { - return std::make_shared(call->arguments[0], call->arguments[1]); - } - - if (call->function == "or_kleene") { - return 
std::make_shared(call->arguments[0], call->arguments[1]); - } - - if (auto cmp = Comparison::Get(call->function)) { - compute::CompareOperator op = [&] { - switch (*cmp) { - case Comparison::EQUAL: - return compute::EQUAL; - case Comparison::LESS: - return compute::LESS; - case Comparison::GREATER: - return compute::GREATER; - case Comparison::NOT_EQUAL: - return compute::NOT_EQUAL; - case Comparison::LESS_EQUAL: - return compute::LESS_EQUAL; - case Comparison::GREATER_EQUAL: - return compute::GREATER_EQUAL; - default: - break; - } - return static_cast(-1); - }(); - - return std::make_shared(op, call->arguments[0], - call->arguments[1]); - } - - if (call->function == "is_valid") { - return std::make_shared(call->arguments[0]); - } - - if (call->function == "is_in") { - auto set = checked_cast(*call->options) - .value_set.make_array(); - return std::make_shared(call->arguments[0], std::move(set)); - } - - DCHECK(false) << "untranslatable Expression2: " << ToString(); - return nullptr; -} - -Expression2::Expression2(const Expression& expr) { - switch (expr.type()) { - case ExpressionType::FIELD: - *this = - ::arrow::dataset::field_ref(checked_cast(expr).name()); - return; - - case ExpressionType::SCALAR: - *this = - ::arrow::dataset::literal(checked_cast(expr).value()); - return; - - case ExpressionType::NOT: - *this = ::arrow::dataset::call( - "invert", {checked_cast(expr).operand()}); - return; - - case ExpressionType::CAST: { - const auto& cast_expr = checked_cast(expr); - auto options = cast_expr.options(); - options.to_type = cast_expr.to_type(); - *this = ::arrow::dataset::call("cast", {cast_expr.operand()}, std::move(options)); - return; - } - - case ExpressionType::AND: { - const auto& and_expr = checked_cast(expr); - *this = ::arrow::dataset::call("and_kleene", - {and_expr.left_operand(), and_expr.right_operand()}); - return; - } - - case ExpressionType::OR: { - const auto& or_expr = checked_cast(expr); - *this = ::arrow::dataset::call("or_kleene", - 
{or_expr.left_operand(), or_expr.right_operand()}); - return; - } - - case ExpressionType::COMPARISON: { - const auto& cmp_expr = checked_cast(expr); - static std::array ops = { - "equal", "not_equal", "greater", "greater_equal", "less", "less_equal", - }; - *this = ::arrow::dataset::call(ops[cmp_expr.op()], - {cmp_expr.left_operand(), cmp_expr.right_operand()}); - return; - } - - case ExpressionType::IS_VALID: { - const auto& is_valid_expr = checked_cast(expr); - *this = ::arrow::dataset::call("is_valid", {is_valid_expr.operand()}); - return; - } - - case ExpressionType::IN: { - const auto& in_expr = checked_cast(expr); - *this = ::arrow::dataset::call( - "is_in", {in_expr.operand()}, - compute::SetLookupOptions{in_expr.set(), /*skip_nulls=*/true}); - return; - } - - default: - break; - } - - DCHECK(false) << "untranslatable Expression: " << expr.ToString(); -} - std::string Expression2::ToString() const { if (auto lit = literal()) { if (lit->is_scalar()) { @@ -1032,6 +893,9 @@ Result SimplifyWithGuarantee(Expression2 expr, return expr; } +// Serialization is accomplished by converting expressions to KeyValueMetadata and storing +// this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its +// columns. Finally, the RecordBatch is written to an IPC file. 
Result> Serialize(const Expression2& expr) { struct { std::shared_ptr metadata_ = std::make_shared(); @@ -1238,10 +1102,5 @@ Expression2 operator||(Expression2 lhs, Expression2 rhs) { return or_(std::move(lhs), std::move(rhs)); } -Result InsertImplicitCasts(Expression2 expr, const Schema& s) { - std::shared_ptr e(expr); - return InsertImplicitCasts(*e, s); -} - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 16c9e8a96f6..e6495433d23 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -103,12 +103,6 @@ class ARROW_DS_EXPORT Expression2 { const Datum* literal() const; const FieldRef* field_ref() const; - // FIXME remove these - operator std::shared_ptr() const; // NOLINT runtime/explicit - Expression2(const Expression& expr); // NOLINT runtime/explicit - Expression2(std::shared_ptr expr) // NOLINT runtime/explicit - : Expression2(*expr) {} - const ValueDescr& descr() const { return descr_; } using Impl = util::Variant; @@ -245,10 +239,5 @@ ARROW_DS_EXPORT Expression2 or_(Expression2 lhs, Expression2 rhs); ARROW_DS_EXPORT Expression2 or_(const std::vector&); ARROW_DS_EXPORT Expression2 not_(Expression2 operand); -// FIXME remove these -ARROW_DS_EXPORT Expression2 operator&&(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Expression2 operator||(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Result InsertImplicitCasts(Expression2, const Schema&); - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 44dd322aa63..f2b6951a8a3 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -216,46 +216,42 @@ TEST(Expression2, BindLiteral) { } } +void ExpectBindsTo(Expression2 expr, Expression2 expected, + Expression2* bound_out = nullptr) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + 
EXPECT_TRUE(bound.IsBound()); + + ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); + EXPECT_EQ(bound, expected) << " unbound: " << expr.ToString(); + + if (bound_out) { + *bound_out = bound; + } +} + TEST(Expression2, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set - { - auto expr = field_ref("alpha"); - EXPECT_EQ(expr.descr(), ValueDescr{}); - EXPECT_FALSE(expr.IsBound()); - } + auto expr = field_ref("alpha"); + EXPECT_EQ(expr.descr(), ValueDescr{}); + EXPECT_FALSE(expr.IsBound()); - { - auto expr = field_ref("alpha"); - // binding a field_ref looks up that field's type in the input Schema - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("alpha", int32())}))); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - EXPECT_TRUE(expr.IsBound()); - } + ExpectBindsTo(field_ref("i32"), field_ref("i32"), &expr); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - { - // if the field is not found, a null scalar will be emitted - auto expr = field_ref("alpha"); - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({}))); - EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); - EXPECT_TRUE(expr.IsBound()); - } + // if the field is not found, a null scalar will be emitted + ExpectBindsTo(field_ref("no such field"), field_ref("no such field"), &expr); + EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); - { - // referencing a field by name is not supported if that name is not unique - // in the input schema - auto expr = field_ref("alpha"); - ASSERT_RAISES( - Invalid, expr.Bind(Schema({field("alpha", int32()), field("alpha", float32())}))); - } + // referencing a field by name is not supported if that name is not unique + // in the input schema + ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema( + {field("alpha", int32()), field("alpha", float32())}))); - { - // referencing nested fields is supported - auto expr = field_ref("a", "b"); - ASSERT_OK_AND_ASSIGN(expr, - expr.Bind(Schema({field("a", struct_({field("b", 
int32())}))}))); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - EXPECT_TRUE(expr.IsBound()); - } + // referencing nested fields is supported + ASSERT_OK_AND_ASSIGN(expr, field_ref("a", "b").Bind( + Schema({field("a", struct_({field("b", int32())}))}))); + EXPECT_TRUE(expr.IsBound()); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); } TEST(Expression2, BindCall) { @@ -275,53 +271,32 @@ TEST(Expression2, BindCall) { TEST(Expression2, BindWithImplicitCasts) { for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { // cast arguments to same type - ASSERT_OK_AND_ASSIGN(auto expr, - cmp(field_ref("i32"), field_ref("i64")).Bind(*kBoringSchema)); - + ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")), + cmp(field_ref("i32"), cast(field_ref("i64"), int32()))); // NB: RHS is cast unless LHS is scalar. - ASSERT_OK_AND_ASSIGN( - auto expected, - cmp(field_ref("i32"), cast(field_ref("i64"), int32())).Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); // cast dictionary to value type - ASSERT_OK_AND_ASSIGN( - expr, cmp(field_ref("dict_str"), field_ref("str")).Bind(*kBoringSchema)); - - ASSERT_OK_AND_ASSIGN( - expected, - cmp(cast(field_ref("dict_str"), utf8()), field_ref("str")).Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); + ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")), + cmp(cast(field_ref("dict_str"), utf8()), field_ref("str"))); } + // scalars are directly cast when possible: + ExpectBindsTo( + equal(field_ref("ts_ns"), literal("1990-10-23 10:23:33")), + equal(field_ref("ts_ns"), + literal( + *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO))))); + // cast value_set to argument type auto Opts = [](std::shared_ptr type) { return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")}; }; - ASSERT_OK_AND_ASSIGN( - auto expr, call("is_in", {field_ref("str")}, Opts(binary())).Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN( - auto expected, - call("is_in", {field_ref("str")}, 
Opts(utf8())).Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); + ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(binary())), + call("is_in", {field_ref("str")}, Opts(utf8()))); // dictionary decode set then cast to argument type - ASSERT_OK_AND_ASSIGN( - expr, call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))) - .Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); -} - -TEST(Expression2, BindDictionaryTransparent) { - ASSERT_OK_AND_ASSIGN( - auto expr, equal(field_ref("str"), field_ref("dict_str")).Bind(*kBoringSchema)); - - EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean())); - EXPECT_TRUE(expr.IsBound()); + ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))), + call("is_in", {field_ref("str")}, Opts(utf8()))); } TEST(Expression2, BindNestedCall) { diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index dfc17a25812..86a262a662b 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -24,7 +24,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/filesystem/filesystem.h" diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index 6f1ca9acbbb..bc1a69066f7 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -28,7 +28,6 @@ #include "arrow/csv/reader.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/result.h" diff --git a/cpp/src/arrow/dataset/file_csv_test.cc b/cpp/src/arrow/dataset/file_csv_test.cc index 8e73be0ee8f..eb0e1bf9395 100644 --- a/cpp/src/arrow/dataset/file_csv_test.cc +++ b/cpp/src/arrow/dataset/file_csv_test.cc @@ -23,7 +23,6 @@ #include 
"arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" @@ -148,7 +147,7 @@ N/A,bar schema_ = schema({field("f64", utf8()), field("str", utf8())}); ScannerBuilder builder(schema_, fragment, ctx_); // filter expression validated against declared schema - ASSERT_OK(builder.Filter("f64"_ == "str"_)); + ASSERT_OK(builder.Filter(equal(field_ref("f64"), field_ref("str")))); // project only "str" ASSERT_OK(builder.Project({"str"})); ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index 0fb76803778..f49337b362e 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -24,7 +24,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/discovery.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 10c5b4a4bcf..8ba5f67f6fc 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -24,7 +24,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/filesystem/path_util.h" #include "arrow/table.h" diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 436826c2971..47a563b911f 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -22,7 +22,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" #include "arrow/table.h" @@ 
-283,7 +282,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - SetFilter(equal(field_ref("i32"), scalar(0))); + SetFilter(equal(field_ref("i32"), literal(0))); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -319,7 +318,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - SetFilter(equal(field_ref("i32"), scalar(0))); + SetFilter(equal(field_ref("i32"), literal(0))); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { @@ -419,28 +418,30 @@ TEST_F(TestParquetFileFormat, PredicatePushdown) { CountRowsAndBatchesInScan(fragment, kTotalNumRows, kNumRowGroups); for (int64_t i = 1; i <= kNumRowGroups; i++) { - SetFilter(("i64"_ == int64_t(i))); + SetFilter(equal(field_ref("i64"), literal(i))); CountRowsAndBatchesInScan(fragment, i, 1); } // Out of bound filters should skip all RowGroups. SetFilter(literal(false)); CountRowsAndBatchesInScan(fragment, 0, 0); - SetFilter(("i64"_ == int64_t(kNumRowGroups + 1))); + SetFilter(equal(field_ref("i64"), literal(kNumRowGroups + 1))); CountRowsAndBatchesInScan(fragment, 0, 0); - SetFilter(("i64"_ == int64_t(-1))); + SetFilter(equal(field_ref("i64"), literal(-1))); CountRowsAndBatchesInScan(fragment, 0, 0); // No rows match 1 and 2. 
- SetFilter(("i64"_ == int64_t(1) and "u8"_ == uint8_t(2))); + SetFilter(and_(equal(field_ref("i64"), literal(1)), + equal(field_ref("u8"), literal(2)))); CountRowsAndBatchesInScan(fragment, 0, 0); - SetFilter(("i64"_ == int64_t(2) or "i64"_ == int64_t(4))); + SetFilter(or_(equal(field_ref("i64"), literal(2)), + equal(field_ref("i64"), literal(4)))); CountRowsAndBatchesInScan(fragment, 2 + 4, 2); - SetFilter(("i64"_ < int64_t(6))); + SetFilter(less(field_ref("i64"), literal(6))); CountRowsAndBatchesInScan(fragment, 5 * (5 + 1) / 2, 5); - SetFilter(("i64"_ >= int64_t(6))); + SetFilter(greater_equal(field_ref("i64"), literal(6))); CountRowsAndBatchesInScan(fragment, kTotalNumRows - (5 * (5 + 1) / 2), kNumRowGroups - 5); } @@ -461,32 +462,39 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragments) { // CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); for (int i = 0; i < kNumRowGroups; ++i) { - CountRowGroupsInFragment(fragment, {i}, "i64"_ == int64_t(i + 1)); + CountRowGroupsInFragment(fragment, {i}, equal(field_ref("i64"), literal(i + 1))); } // Out of bound filters should skip all RowGroups. CountRowGroupsInFragment(fragment, {}, literal(false)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(kNumRowGroups + 1)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(-1)); + CountRowGroupsInFragment(fragment, {}, + equal(field_ref("i64"), literal(kNumRowGroups + 1))); + CountRowGroupsInFragment(fragment, {}, equal(field_ref("i64"), literal(-1))); // No rows match 1 and 2. 
- CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(1) and "u8"_ == uint8_t(2)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(2) and "i64"_ == int64_t(4)); + CountRowGroupsInFragment( + fragment, {}, + and_(equal(field_ref("i64"), literal(1)), equal(field_ref("u8"), literal(2)))); + CountRowGroupsInFragment( + fragment, {}, + and_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4)))); - CountRowGroupsInFragment(fragment, {1, 3}, - "i64"_ == int64_t(2) or "i64"_ == int64_t(4)); + CountRowGroupsInFragment( + fragment, {1, 3}, + or_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4)))); // TODO(bkietz): better Assume support for InExpression // auto set = ArrayFromJSON(int64(), "[2, 4]"); - // CountRowGroupsInFragment(fragment, {1, 3}, "i64"_.In(set)); + // CountRowGroupsInFragment(fragment, {1, 3}, field_ref("i64").In(set)); - CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, "i64"_ < int64_t(6)); + CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, less(field_ref("i64"), literal(6))); CountRowGroupsInFragment(fragment, internal::Iota(5, static_cast(kNumRowGroups)), - "i64"_ >= int64_t(6)); + greater_equal(field_ref("i64"), literal(6))); CountRowGroupsInFragment(fragment, {5, 6}, - "i64"_ >= int64_t(6) and "i64"_ < int64_t(8)); + and_(greater_equal(field_ref("i64"), literal(6)), + less(field_ref("i64"), literal(8)))); } TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragmentsUsingStringColumn) { @@ -503,7 +511,7 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragmentsUsingStringColum opts_ = ScanOptions::Make(reader.schema()); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); - CountRowGroupsInFragment(fragment, {0, 3}, "x"_ == "a"); + CountRowGroupsInFragment(fragment, {0, 3}, equal(field_ref("x"), literal("a"))); } TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { @@ -545,17 +553,17 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { for (int i 
= 0; i < kNumRowGroups; ++i) { // conflicting selection/filter - SetFilter(("i64"_ == int64_t(i))); + SetFilter(equal(field_ref("i64"), literal(i))); CountRowsAndBatchesInScan(row_groups_fragment({i}), 0, 0); } for (int i = 0; i < kNumRowGroups; ++i) { // identical selection/filter - SetFilter(("i64"_ == int64_t(i + 1))); + SetFilter(equal(field_ref("i64"), literal(i + 1))); CountRowsAndBatchesInScan(row_groups_fragment({i}), i + 1, 1); } - SetFilter(("i64"_ > int64_t(3))); + SetFilter(greater(field_ref("i64"), literal(3))); CountRowsAndBatchesInScan(row_groups_fragment({2, 3, 4, 5}), 4 + 5 + 6, 3); EXPECT_RAISES_WITH_MESSAGE_THAT( diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index c390f777ad4..fefd6911ac0 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -159,19 +159,19 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { std::vector partitions = { equal(field_ref("state"), literal("NY")), - equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("New York")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), - equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), equal(field_ref("state"), literal("CA")), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("San Francisco")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), }; MakeDataset( @@ -215,19 +215,19 @@ TEST_F(TestFileSystemDataset, FragmentPartitions) { std::vector partitions = { equal(field_ref("state"), literal("NY")), - 
equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("New York")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), - equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), equal(field_ref("state"), literal("CA")), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("San Francisco")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), }; MakeDataset( diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 71253cfde89..9852f8c0808 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -15,1726 +15,4 @@ // specific language governing permissions and limitations // under the License. 
-#include "arrow/dataset/filter.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "arrow/array/builder_primitive.h" -#include "arrow/buffer.h" -#include "arrow/compute/api.h" -#include "arrow/dataset/dataset.h" -#include "arrow/dataset/expression.h" -#include "arrow/io/memory.h" -#include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/int_util_internal.h" -#include "arrow/util/iterator.h" -#include "arrow/util/logging.h" -#include "arrow/util/string.h" -#include "arrow/visitor_inline.h" - -namespace arrow { - -using compute::CompareOperator; -using compute::ExecContext; - -namespace dataset { - -using arrow::internal::checked_cast; -using arrow::internal::checked_pointer_cast; - -inline std::shared_ptr NullExpression() { - return std::make_shared(std::make_shared()); -} - -inline Datum NullDatum() { return Datum(std::make_shared()); } - -bool IsNullDatum(const Datum& datum) { - if (datum.is_scalar()) { - auto scalar = datum.scalar(); - return !scalar->is_valid; - } - - auto array_data = datum.array(); - return array_data->GetNullCount() == array_data->length; -} - -struct Comparison { - enum type { - LESS, - EQUAL, - GREATER, - NULL_, - }; -}; - -Result> EnsureNotDictionary( - const std::shared_ptr& scalar) { - if (scalar->type->id() == Type::DICTIONARY) { - return checked_cast(*scalar).GetEncodedValue(); - } - return scalar; -} - -Result Compare(const Scalar& lhs, const Scalar& rhs); - -struct CompareVisitor { - template - using ScalarType = typename TypeTraits::ScalarType; - - Status Visit(const NullType&) { - result_ = Comparison::NULL_; - return Status::OK(); - } - - Status Visit(const BooleanType&) { return CompareValues(); } - - template - enable_if_physical_floating_point Visit(const T&) { - return CompareValues(); - } - - template - 
enable_if_physical_signed_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_physical_unsigned_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_nested Visit(const T&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - template - enable_if_binary_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - template - enable_if_string_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - Status Visit(const Decimal128Type&) { return CompareValues(); } - Status Visit(const Decimal256Type&) { return CompareValues(); } - - // Explicit because it falls under `physical_unsigned_integer`. - // TODO(bkietz) whenever we vendor a float16, this can be implemented - Status Visit(const HalfFloatType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const ExtensionType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const DictionaryType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - // defer comparison to ScalarType::value - template - Status CompareValues() { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - return CompareValues(lhs, rhs); - } - - // defer comparison to explicit values - template - Status CompareValues(Value lhs, Value rhs) { - result_ = lhs < rhs ? Comparison::LESS - : lhs == rhs ? 
Comparison::EQUAL : Comparison::GREATER; - return Status::OK(); - } - - Comparison::type result_; - const Scalar& lhs_; - const Scalar& rhs_; -}; - -// Compare two scalars -// if either is null, return is null -// TODO(bkietz) extract this to the scalar comparison kernels -Result Compare(const Scalar& lhs, const Scalar& rhs) { - if (!lhs.type->Equals(*rhs.type)) { - return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, - " vs ", *rhs.type); - } - if (!lhs.is_valid || !rhs.is_valid) { - return Comparison::NULL_; - } - CompareVisitor vis{Comparison::NULL_, lhs, rhs}; - RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); - return vis.result_; -} - -CompareOperator InvertCompareOperator(CompareOperator op) { - switch (op) { - case CompareOperator::EQUAL: - return CompareOperator::NOT_EQUAL; - - case CompareOperator::NOT_EQUAL: - return CompareOperator::EQUAL; - - case CompareOperator::GREATER: - return CompareOperator::LESS_EQUAL; - - case CompareOperator::GREATER_EQUAL: - return CompareOperator::LESS; - - case CompareOperator::LESS: - return CompareOperator::GREATER_EQUAL; - - case CompareOperator::LESS_EQUAL: - return CompareOperator::GREATER; - - default: - break; - } - - DCHECK(false); - return CompareOperator::EQUAL; -} - -template -std::shared_ptr InvertBoolean(const Boolean& expr) { - auto lhs = Invert(*expr.left_operand()); - auto rhs = Invert(*expr.right_operand()); - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - return nullptr; -} - -std::shared_ptr Invert(const Expression& expr) { - switch (expr.type()) { - case ExpressionType::NOT: - return checked_cast(expr).operand(); - - case ExpressionType::AND: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::OR: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::COMPARISON: { - const auto& comparison = 
checked_cast(expr); - auto inverted_op = InvertCompareOperator(comparison.op()); - return std::make_shared( - inverted_op, comparison.left_operand(), comparison.right_operand()); - } - - default: - break; - } - return nullptr; -} - -std::shared_ptr Expression::Assume(const Expression& given) const { - std::shared_ptr out; - - DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) { - if (out != nullptr) { - return Status::OK(); - } - - if (given.type() != ExpressionType::COMPARISON) { - return Status::OK(); - } - - const auto& given_cmp = checked_cast(given); - if (given_cmp.op() != CompareOperator::EQUAL) { - return Status::OK(); - } - - if (this->Equals(given_cmp.left_operand())) { - out = given_cmp.right_operand(); - return Status::OK(); - } - - if (this->Equals(given_cmp.right_operand())) { - out = given_cmp.left_operand(); - return Status::OK(); - } - - return Status::OK(); - })); - - return out ? out : Copy(); -} - -std::shared_ptr ComparisonExpression::Assume(const Expression& given) const { - switch (given.type()) { - case ExpressionType::COMPARISON: { - return AssumeGivenComparison(checked_cast(given)); - } - - case ExpressionType::NOT: { - const auto& given_not = checked_cast(given); - if (auto inverted = Invert(*given_not.operand())) { - return Assume(*inverted); - } - return Copy(); - } - - case ExpressionType::OR: { - const auto& given_or = checked_cast(given); - - auto left_simplified = Assume(*given_or.left_operand()); - auto right_simplified = Assume(*given_or.right_operand()); - - // The result of simplification against the operands of an OrExpression - // cannot be used unless they are identical - if (left_simplified->Equals(right_simplified)) { - return left_simplified; - } - - return Copy(); - } - - case ExpressionType::AND: { - const auto& given_and = checked_cast(given); - - auto simplified = Copy(); - simplified = simplified->Assume(*given_and.left_operand()); - simplified = simplified->Assume(*given_and.right_operand()); - 
return simplified; - } - - // TODO(bkietz) we should be able to use ExpressionType::IN here - - default: - break; - } - - return Copy(); -} - -// Try to simplify one comparison against another comparison. -// For example, -// ("x"_ > 3) is a subset of ("x"_ > 2), so ("x"_ > 2).Assume("x"_ > 3) == (true) -// ("x"_ < 0) is disjoint with ("x"_ > 2), so ("x"_ > 2).Assume("x"_ < 0) == (false) -// If simplification to (true) or (false) is not possible, pass e through unchanged. -std::shared_ptr ComparisonExpression::AssumeGivenComparison( - const ComparisonExpression& given) const { - if (!left_operand_->Equals(given.left_operand_)) { - return Copy(); - } - - for (auto rhs : {right_operand_, given.right_operand_}) { - if (rhs->type() != ExpressionType::SCALAR) { - return Copy(); - } - } - - auto this_rhs = - EnsureNotDictionary(checked_cast(*right_operand_).value()) - .ValueOr(nullptr); - auto given_rhs = - EnsureNotDictionary( - checked_cast(*given.right_operand_).value()) - .ValueOr(nullptr); - - if (!this_rhs || !given_rhs) { - return Copy(); - } - - auto cmp = Compare(*this_rhs, *given_rhs).ValueOrDie(); - - if (cmp == Comparison::NULL_) { - // the RHS of e or given was null - return NullExpression(); - } - - static auto always = scalar(true); - static auto never = scalar(false); - - if (cmp == Comparison::GREATER) { - // the rhs of e is greater than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - if (cmp == 
Comparison::LESS) { - // the rhs of e is less than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - DCHECK_EQ(cmp, Comparison::EQUAL); - - // the rhs of the comparisons are equal - switch (op_) { - case CompareOperator::EQUAL: - switch (given.op()) { - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - return never; - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS_EQUAL: - case CompareOperator::LESS: - return never; - case CompareOperator::GREATER: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::LESS: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - case CompareOperator::LESS: - 
return always; - default: - return Copy(); - } - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::GREATER: - return never; - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - return Copy(); -} - -std::shared_ptr AndExpression::Assume(const Expression& given) const { - auto left_operand = left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially false then so is this AND - if (left_operand->Equals(false) || right_operand->Equals(false)) { - return scalar(false); - } - - // if either operand is trivially null then so is this AND - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially true then drop it - if (left_operand->Equals(true)) { - return right_operand; - } - if (right_operand->Equals(true)) { - return left_operand; - } - - // if neither of the operands is trivial, simply construct a new AND - return and_(std::move(left_operand), std::move(right_operand)); -} - -std::shared_ptr OrExpression::Assume(const Expression& given) const { - auto left_operand = left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially true then so is this OR - if (left_operand->Equals(true) || right_operand->Equals(true)) { - return scalar(true); - } - - // if either operand is trivially null then so is this OR - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially false then drop it - if (left_operand->Equals(false)) { - return right_operand; - } - if (right_operand->Equals(false)) { - return left_operand; - } - - // if neither of the operands is trivial, simply construct a new OR - return or_(std::move(left_operand), 
std::move(right_operand)); -} - -std::shared_ptr NotExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - - if (operand->IsNull()) { - return NullExpression(); - } - if (operand->Equals(true)) { - return scalar(false); - } - if (operand->Equals(false)) { - return scalar(true); - } - - return Copy(); -} - -std::shared_ptr InExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if (operand->type() != ExpressionType::SCALAR) { - return std::make_shared(std::move(operand), set_); - } - - if (operand->IsNull()) { - return scalar(set_->null_count() > 0); - } - - Datum set, value; - if (set_->type_id() == Type::DICTIONARY) { - const auto& dict_set = checked_cast(*set_); - auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices()); - auto maybe_value = checked_cast( - *checked_cast(*operand).value()) - .GetEncodedValue(); - if (!maybe_decoded.ok() || !maybe_value.ok()) { - return std::make_shared(std::move(operand), set_); - } - set = *maybe_decoded; - value = *maybe_value; - } else { - set = set_; - value = checked_cast(*operand).value(); - } - - compute::CompareOptions eq(CompareOperator::EQUAL); - Result maybe_out = compute::Compare(set, value, eq); - if (!maybe_out.ok()) { - return std::make_shared(std::move(operand), set_); - } - - Datum out = maybe_out.ValueOrDie(); - - DCHECK(out.is_array()); - DCHECK_EQ(out.type()->id(), Type::BOOL); - auto out_array = checked_pointer_cast(out.make_array()); - - for (int64_t i = 0; i < out_array->length(); ++i) { - if (out_array->IsValid(i) && out_array->Value(i)) { - return scalar(true); - } - } - return scalar(false); -} - -std::shared_ptr IsValidExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if (operand->type() == ExpressionType::SCALAR) { - return scalar(!operand->IsNull()); - } - - return std::make_shared(std::move(operand)); -} - -std::shared_ptr CastExpression::Assume(const 
Expression& given) const { - auto operand = operand_->Assume(given); - if (arrow::util::holds_alternative>(to_)) { - auto to_type = arrow::util::get>(to_); - return std::make_shared(std::move(operand), std::move(to_type), - options_); - } - auto like = arrow::util::get>(to_)->Assume(given); - return std::make_shared(std::move(operand), std::move(like), options_); -} - -const std::shared_ptr& CastExpression::to_type() const { - if (arrow::util::holds_alternative>(to_)) { - return arrow::util::get>(to_); - } - static std::shared_ptr null; - return null; -} - -const std::shared_ptr& CastExpression::like_expr() const { - if (arrow::util::holds_alternative>(to_)) { - return arrow::util::get>(to_); - } - static std::shared_ptr null; - return null; -} - -std::string FieldExpression::ToString() const { return name_; } - -std::string OperatorName(compute::CompareOperator op) { - switch (op) { - case CompareOperator::EQUAL: - return "=="; - case CompareOperator::NOT_EQUAL: - return "!="; - case CompareOperator::LESS: - return "<"; - case CompareOperator::LESS_EQUAL: - return "<="; - case CompareOperator::GREATER: - return ">"; - case CompareOperator::GREATER_EQUAL: - return ">="; - default: - DCHECK(false); - } - return ""; -} - -std::string ScalarExpression::ToString() const { - auto type_repr = value_->type->ToString(); - if (!value_->is_valid) { - return "null:" + type_repr; - } - - return value_->ToString() + ":" + type_repr; -} - -using arrow::internal::JoinStrings; - -std::string AndExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " and ", right_operand_->ToString(), ")"}, ""); -} - -std::string OrExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " or ", right_operand_->ToString(), ")"}, ""); -} - -std::string NotExpression::ToString() const { - if (operand_->type() == ExpressionType::IS_VALID) { - const auto& is_valid = checked_cast(*operand_); - return JoinStrings({"(", 
is_valid.operand()->ToString(), " is null)"}, ""); - } - return JoinStrings({"(not ", operand_->ToString(), ")"}, ""); -} - -std::string IsValidExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is not null)"}, ""); -} - -std::string InExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is in ", set_->ToString(), ")"}, ""); -} - -std::string CastExpression::ToString() const { - std::string to; - if (arrow::util::holds_alternative>(to_)) { - auto to_type = arrow::util::get>(to_); - to = " to " + to_type->ToString(); - } else { - auto like = arrow::util::get>(to_); - to = " like " + like->ToString(); - } - return JoinStrings({"(cast ", operand_->ToString(), std::move(to), ")"}, ""); -} - -std::string ComparisonExpression::ToString() const { - return JoinStrings({"(", left_operand_->ToString(), " ", OperatorName(op()), " ", - right_operand_->ToString(), ")"}, - ""); -} - -bool UnaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - operand_->Equals(checked_cast(other).operand_); -} - -bool BinaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - left_operand_->Equals( - checked_cast(other).left_operand_) && - right_operand_->Equals( - checked_cast(other).right_operand_); -} - -bool ComparisonExpression::Equals(const Expression& other) const { - return BinaryExpression::Equals(other) && - op_ == checked_cast(other).op_; -} - -bool ScalarExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::SCALAR && - value_->Equals(*checked_cast(other).value_); -} - -bool FieldExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::FIELD && - name_ == checked_cast(other).name_; -} - -bool Expression::Equals(const std::shared_ptr& other) const { - if (other == nullptr) { - return false; - } - return Equals(*other); -} - -bool Expression::IsNull() const { - if (type_ != 
ExpressionType::SCALAR) { - return false; - } - - const auto& scalar = checked_cast(*this).value(); - if (!scalar->is_valid) { - return true; - } - - return false; -} - -InExpression Expression::In(std::shared_ptr set) const { - return InExpression(Copy(), std::move(set)); -} - -IsValidExpression Expression::IsValid() const { return IsValidExpression(Copy()); } - -std::shared_ptr FieldExpression::Copy() const { - return std::make_shared(*this); -} - -std::shared_ptr ScalarExpression::Copy() const { - return std::make_shared(*this); -} - -AndExpression operator&&(const Expression& lhs, const Expression& rhs) { - return AndExpression(lhs.Copy(), rhs.Copy()); -} - -OrExpression operator||(const Expression& lhs, const Expression& rhs) { - return OrExpression(lhs.Copy(), rhs.Copy()); -} - -CastExpression Expression::CastTo(std::shared_ptr type, - compute::CastOptions options) const { - return CastExpression(Copy(), type, std::move(options)); -} - -CastExpression Expression::CastLike(std::shared_ptr expr, - compute::CastOptions options) const { - return CastExpression(Copy(), std::move(expr), std::move(options)); -} - -CastExpression Expression::CastLike(const Expression& expr, - compute::CastOptions options) const { - return CastLike(expr.Copy(), std::move(options)); -} - -Result> ComparisonExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto lhs_type, left_operand_->Validate(schema)); - ARROW_ASSIGN_OR_RAISE(auto rhs_type, right_operand_->Validate(schema)); - - if (lhs_type->id() == Type::NA || rhs_type->id() == Type::NA) { - return boolean(); - } - - if (!lhs_type->Equals(rhs_type)) { - return Status::TypeError("cannot compare expressions of differing type, ", *lhs_type, - " vs ", *rhs_type); - } - - return boolean(); -} - -Status EnsureNullOrBool(const std::string& msg_prefix, - const std::shared_ptr& type) { - if (type->id() == Type::BOOL || type->id() == Type::NA) { - return Status::OK(); - } - return Status::TypeError(msg_prefix, 
*type); -} - -Result> ValidateBoolean( - const std::vector>& operands, const Schema& schema) { - for (const auto& operand : operands) { - ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); - RETURN_NOT_OK( - EnsureNullOrBool("cannot combine expressions including one of type ", type)); - } - return boolean(); -} - -Result> AndExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> OrExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> NotExpression::Validate(const Schema& schema) const { - return ValidateBoolean({operand_}, schema); -} - -Result> InExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - if (operand_type->id() == Type::NA || set_->type()->id() == Type::NA) { - return boolean(); - } - - if (!operand_type->Equals(set_->type())) { - return Status::TypeError("mismatch: set type ", *set_->type(), " vs operand type ", - *operand_type); - } - // TODO(bkietz) check if IsIn supports operand_type - return boolean(); -} - -Result> IsValidExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(std::ignore, operand_->Validate(schema)); - return boolean(); -} - -Result> CastExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - std::shared_ptr to_type; - if (arrow::util::holds_alternative>(to_)) { - to_type = arrow::util::get>(to_); - } else { - auto like = arrow::util::get>(to_); - ARROW_ASSIGN_OR_RAISE(to_type, like->Validate(schema)); - } - - // Until expressions carry a shape, detect scalar and try to cast it. Works - // if the operand is a scalar leaf. 
- if (operand_->type() == ExpressionType::SCALAR) { - auto scalar_expr = checked_pointer_cast(operand_); - ARROW_ASSIGN_OR_RAISE(std::ignore, scalar_expr->value()->CastTo(to_type)); - return to_type; - } - - if (!compute::CanCast(*operand_type, *to_type)) { - return Status::Invalid("Cannot cast to ", to_type->ToString()); - } - - return to_type; -} - -Result> ScalarExpression::Validate(const Schema& schema) const { - return value_->type; -} - -Result> FieldExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(name_).GetOneOrNone(schema)); - if (field != nullptr) { - return field->type(); - } - return null(); -} - -Result CastOrDictionaryEncode(const Datum& arr, - const std::shared_ptr& type, - const compute::CastOptions opts) { - if (type->id() == Type::DICTIONARY) { - const auto& dict_type = checked_cast(*type); - if (dict_type.index_type()->id() != Type::INT32) { - return Status::TypeError("cannot DictionaryEncode to index type ", - *dict_type.index_type()); - } - ARROW_ASSIGN_OR_RAISE(auto dense, compute::Cast(arr, dict_type.value_type(), opts)); - return compute::DictionaryEncode(dense); - } - - return compute::Cast(arr, type, opts); -} - -struct InsertImplicitCastsImpl { - struct ValidatedAndCast { - std::shared_ptr expr; - std::shared_ptr type; - }; - - Result InsertCastsAndValidate(const Expression& expr) { - ValidatedAndCast out; - ARROW_ASSIGN_OR_RAISE(out.expr, InsertImplicitCasts(expr, schema_)); - ARROW_ASSIGN_OR_RAISE(out.type, out.expr->Validate(schema_)); - return std::move(out); - } - - Result> Cast(std::shared_ptr type, - const Expression& expr) { - if (expr.type() != ExpressionType::SCALAR) { - return expr.CastTo(type).Copy(); - } - - // cast the scalar directly - const auto& value = checked_cast(expr).value(); - ARROW_ASSIGN_OR_RAISE(auto cast_value, value->CastTo(std::move(type))); - return scalar(cast_value); - } - - Result> operator()(const InExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, 
InsertCastsAndValidate(*expr.operand())); - auto set = expr.set(); - - if (!op.type->Equals(set->type())) { - // cast the set (which we assume to be small) to match op.type - ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode( - *set, op.type, compute::CastOptions{})); - set = encoded_set.make_array(); - } - - return std::make_shared(std::move(op.expr), std::move(set)); - } - - Result> operator()(const NotExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); - - if (op.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(op.expr, Cast(boolean(), *op.expr)); - } - return not_(std::move(op.expr)); - } - - Result> operator()(const AndExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); - } - return and_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const OrExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); - } - return or_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const ComparisonExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->Equals(rhs.type)) { - return expr.Copy(); - } - - if (lhs.expr->type() == ExpressionType::SCALAR) { - 
ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(rhs.type, *lhs.expr)); - } else { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(lhs.type, *rhs.expr)); - } - return std::make_shared(expr.op(), std::move(lhs.expr), - std::move(rhs.expr)); - } - - Result> operator()(const Expression& expr) const { - return expr.Copy(); - } - - const Schema& schema_; -}; - -Result> InsertImplicitCasts(const Expression& expr, - const Schema& schema) { - RETURN_NOT_OK(schema.CanReferenceFieldsByNames(FieldsInExpression(expr))); - return VisitExpression(expr, InsertImplicitCastsImpl{schema}); -} - -Status VisitConjunctionMembers(const Expression& expr, - const std::function& visitor) { - if (expr.type() == ExpressionType::AND) { - const auto& and_ = checked_cast(expr); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor)); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor)); - return Status::OK(); - } - - return visitor(expr); -} - -std::vector FieldsInExpression(const Expression& expr) { - struct { - void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); } - - void operator()(const UnaryExpression& expr) { - VisitExpression(*expr.operand(), *this); - } - - void operator()(const BinaryExpression& expr) { - VisitExpression(*expr.left_operand(), *this); - VisitExpression(*expr.right_operand(), *this); - } - - void operator()(const Expression&) const {} - - std::vector fields; - } visitor; - - VisitExpression(expr, visitor); - return std::move(visitor.fields); -} - -std::vector FieldsInExpression(const std::shared_ptr& expr) { - DCHECK_NE(expr, nullptr); - if (expr == nullptr) { - return {}; - } - - return FieldsInExpression(*expr); -} - -RecordBatchIterator ExpressionEvaluator::FilterBatches(RecordBatchIterator unfiltered, - std::shared_ptr filter, - MemoryPool* pool) { - auto filter_batches = [filter, pool, this](std::shared_ptr unfiltered) { - auto filtered = Evaluate(*filter, *unfiltered, pool).Map([&](Datum selection) { - return 
Filter(selection, unfiltered, pool); - }); - - if (filtered.ok() && (*filtered)->num_rows() == 0) { - // drop empty batches - return FilterIterator::Reject>(); - } - - return FilterIterator::MaybeAccept(std::move(filtered)); - }; - - return MakeFilterIterator(std::move(filter_batches), std::move(unfiltered)); -} - -std::shared_ptr ExpressionEvaluator::Null() { - struct Impl : ExpressionEvaluator { - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override { - ARROW_ASSIGN_OR_RAISE(auto type, expr.Validate(*batch.schema())); - return Datum(MakeNullScalar(type)); - } - - Result> Filter(const Datum& selection, - const std::shared_ptr& batch, - MemoryPool* pool) const override { - return batch; - } - }; - - return std::make_shared(); -} - -struct TreeEvaluator::Impl { - Result operator()(const ScalarExpression& expr) const { - return Datum(expr.value()); - } - - Result operator()(const FieldExpression& expr) const { - if (auto column = batch_.GetColumnByName(expr.name())) { - return std::move(column); - } - return NullDatum(); - } - - Result operator()(const AndExpression& expr) const { - return EvaluateBoolean(expr, compute::KleeneAnd); - } - - Result operator()(const OrExpression& expr) const { - return EvaluateBoolean(expr, compute::KleeneOr); - } - - Result EvaluateBoolean(const BinaryExpression& expr, - Result kernel(const Datum& left, - const Datum& right, - ExecContext* ctx)) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (lhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto lhs_array, - MakeArrayFromScalar(*lhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); - lhs = Datum(std::move(lhs_array)); - } - - if (rhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto rhs_array, - MakeArrayFromScalar(*rhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); - rhs = Datum(std::move(rhs_array)); - } - - return kernel(lhs, rhs, 
&ctx_); - } - - Result operator()(const NotExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum to_invert, Evaluate(*expr.operand())); - if (IsNullDatum(to_invert)) { - return NullDatum(); - } - - if (to_invert.is_scalar()) { - bool trivial_condition = - checked_cast(*to_invert.scalar()).value; - return Datum(std::make_shared(!trivial_condition)); - } - return compute::Invert(to_invert, &ctx_); - } - - Result operator()(const InExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(expr.set()->null_count() != 0); - } - - DCHECK(operand_values.is_array()); - return compute::IsIn(operand_values, expr.set(), &ctx_); - } - - Result operator()(const IsValidExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(false); - } - - if (operand_values.is_scalar()) { - return Datum(true); - } - - DCHECK(operand_values.is_array()); - if (operand_values.array()->GetNullCount() == 0) { - return Datum(true); - } - - return Datum(std::make_shared(operand_values.array()->length, - operand_values.array()->buffers[0])); - } - - Result operator()(const CastExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto to_type, expr.Validate(*batch_.schema())); - - ARROW_ASSIGN_OR_RAISE(auto to_cast, Evaluate(*expr.operand())); - if (to_cast.is_scalar()) { - return to_cast.scalar()->CastTo(to_type); - } - - DCHECK(to_cast.is_array()); - return CastOrDictionaryEncode(to_cast, to_type, expr.options()); - } - - Result operator()(const ComparisonExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (IsNullDatum(lhs) || IsNullDatum(rhs)) { - return Datum(std::make_shared()); - } - - if (lhs.type()->id() == Type::DICTIONARY && rhs.type()->id() == Type::DICTIONARY) { - if (lhs.is_array() && 
rhs.is_array()) { - // decode dictionary arrays - for (Datum* arg : {&lhs, &rhs}) { - auto dict = checked_pointer_cast(arg->make_array()); - ARROW_ASSIGN_OR_RAISE(*arg, compute::Take(dict->dictionary(), dict->indices(), - compute::TakeOptions::Defaults())); - } - } else if (lhs.is_array() || rhs.is_array()) { - auto dict = checked_pointer_cast( - (lhs.is_array() ? lhs : rhs).make_array()); - - ARROW_ASSIGN_OR_RAISE(auto scalar, checked_cast( - *(lhs.is_scalar() ? lhs : rhs).scalar()) - .GetEncodedValue()); - if (lhs.is_array()) { - lhs = dict->dictionary(); - rhs = std::move(scalar); - } else { - lhs = std::move(scalar); - rhs = dict->dictionary(); - } - ARROW_ASSIGN_OR_RAISE( - Datum out_dict, - compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_)); - - return compute::Take(out_dict, dict->indices(), compute::TakeOptions::Defaults()); - } - } - - return compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_); - } - - Result operator()(const Expression& expr) const { - return Status::NotImplemented("evaluation of ", expr.ToString()); - } - - Result Evaluate(const Expression& expr) const { - return this_->Evaluate(expr, batch_, ctx_.memory_pool()); - } - - const TreeEvaluator* this_; - const RecordBatch& batch_; - mutable compute::ExecContext ctx_; -}; - -Result TreeEvaluator::Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const { - return VisitExpression(expr, Impl{this, batch, compute::ExecContext{pool}}); -} - -Result> TreeEvaluator::Filter( - const Datum& selection, const std::shared_ptr& batch, - MemoryPool* pool) const { - if (selection.is_array()) { - auto selection_array = selection.make_array(); - compute::ExecContext ctx(pool); - ARROW_ASSIGN_OR_RAISE(Datum filtered, - compute::Filter(batch, selection_array, - compute::FilterOptions::Defaults(), &ctx)); - return filtered.record_batch(); - } - - if (!selection.is_scalar() || selection.type()->id() != Type::BOOL) { - return 
Status::NotImplemented("Filtering batches against DatumKind::", - selection.kind(), " of type ", *selection.type()); - } - - if (BooleanScalar(true).Equals(*selection.scalar())) { - return batch; - } - - return batch->Slice(0, 0); -} - -const std::shared_ptr& scalar(bool value) { - static auto true_ = scalar(MakeScalar(true)); - static auto false_ = scalar(MakeScalar(false)); - return value ? true_ : false_; -} - -// Serialization is accomplished by converting expressions to single element StructArrays -// then writing that to an IPC file. The last field is always an int32 column containing -// ExpressionType, the rest store the Expression's members. -struct SerializeImpl { - Result> ToArray(const Expression& expr) const { - return VisitExpression(expr, *this); - } - - Result> TaggedWithChildren(const Expression& expr, - ArrayVector children) const { - children.emplace_back(); - ARROW_ASSIGN_OR_RAISE(children.back(), - MakeArrayFromScalar(Int32Scalar(expr.type()), 1)); - - return StructArray::Make(children, std::vector(children.size(), "")); - } - - Result> operator()(const FieldExpression& expr) const { - // store the field's name in a StringArray - ARROW_ASSIGN_OR_RAISE(auto name, MakeArrayFromScalar(StringScalar(expr.name()), 1)); - return TaggedWithChildren(expr, {name}); - } - - Result> operator()(const ScalarExpression& expr) const { - // store the scalar's value in a single element Array - ARROW_ASSIGN_OR_RAISE(auto value, MakeArrayFromScalar(*expr.value(), 1)); - return TaggedWithChildren(expr, {value}); - } - - Result> operator()(const UnaryExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - return TaggedWithChildren(expr, {operand}); - } - - Result> operator()(const CastExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the cast 
target and a discriminant - std::shared_ptr is_like_expr, to; - if (const auto& to_type = expr.to_type()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(false), 1)); - ARROW_ASSIGN_OR_RAISE(to, MakeArrayOfNull(to_type, 1)); - } - if (const auto& like_expr = expr.like_expr()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(true), 1)); - ARROW_ASSIGN_OR_RAISE(to, ToArray(*like_expr)); - } - - return TaggedWithChildren(expr, {operand, is_like_expr, to}); - } - - Result> operator()(const BinaryExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - return TaggedWithChildren(expr, {left_operand, right_operand}); - } - - Result> operator()( - const ComparisonExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - // store the CompareOperator in a single element Int32Array - ARROW_ASSIGN_OR_RAISE(auto op, MakeArrayFromScalar(Int32Scalar(expr.op()), 1)); - return TaggedWithChildren(expr, {left_operand, right_operand, op}); - } - - Result> operator()(const InExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the set as a single element ListArray - auto set_type = list(expr.set()->type()); - - ARROW_ASSIGN_OR_RAISE(auto set_offsets, AllocateBuffer(sizeof(int32_t) * 2)); - reinterpret_cast(set_offsets->mutable_data())[0] = 0; - reinterpret_cast(set_offsets->mutable_data())[1] = - static_cast(expr.set()->length()); - - auto set_values = expr.set(); - - auto set = std::make_shared(std::move(set_type), 1, std::move(set_offsets), - 
std::move(set_values)); - return TaggedWithChildren(expr, {operand, set}); - } - - Result> operator()(const Expression& expr) const { - return Status::NotImplemented("serialization of ", expr.ToString()); - } - - Result> ToBuffer(const Expression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto array, SerializeImpl{}.ToArray(expr)); - ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array)); - ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); - ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - RETURN_NOT_OK(writer->Close()); - return stream->Finish(); - } -}; - -Result> Expression::Serialize() const { - return SerializeImpl{}.ToBuffer(*this); -} - -struct DeserializeImpl { - Result> FromArray(const Array& array) const { - if (array.type_id() != Type::STRUCT || array.length() != 1) { - return Status::Invalid("can only deserialize expressions from unit-length", - " StructArray, got ", array); - } - const auto& struct_array = checked_cast(array); - - ARROW_ASSIGN_OR_RAISE(auto expression_type, GetExpressionType(struct_array)); - switch (expression_type) { - case ExpressionType::FIELD: { - ARROW_ASSIGN_OR_RAISE(auto name, GetView(struct_array, 0)); - return std::make_shared(std::string(name)); - } - - case ExpressionType::SCALAR: { - ARROW_ASSIGN_OR_RAISE(auto value, struct_array.field(0)->GetScalar(0)); - return scalar(std::move(value)); - } - - case ExpressionType::NOT: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return not_(std::move(operand)); - } - - case ExpressionType::CAST: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto is_like_expr, GetView(struct_array, 1)); - if (is_like_expr) { - ARROW_ASSIGN_OR_RAISE(auto like_expr, FromArray(*struct_array.field(2))); - return operand->CastLike(std::move(like_expr)).Copy(); - } - return 
operand->CastTo(struct_array.field(2)->type()).Copy(); - } - - case ExpressionType::AND: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return and_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::OR: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return or_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::COMPARISON: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - ARROW_ASSIGN_OR_RAISE(auto op, GetView(struct_array, 2)); - return std::make_shared(static_cast(op), - std::move(left_operand), - std::move(right_operand)); - } - - case ExpressionType::IS_VALID: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return std::make_shared(std::move(operand)); - } - - case ExpressionType::IN: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - if (struct_array.field(1)->type_id() != Type::LIST) { - return Status::TypeError("expected field 1 of ", struct_array, - " to have list type"); - } - auto set = checked_cast(*struct_array.field(1)).values(); - return std::make_shared(std::move(operand), std::move(set)); - } - - default: - break; - } - - return Status::Invalid("non-deserializable ExpressionType ", expression_type); - } - - template ::ArrayType> - static Result().GetView(0))> GetView(const StructArray& array, - int index) { - if (index >= array.num_fields()) { - return Status::IndexError("expected ", array, " to have a child at index ", index); - } - - const auto& child = *array.field(index); - if (child.type_id() != T::type_id) { - return Status::TypeError("expected child ", index, " of ", array, " to have type ", 
- T::type_id); - } - - return checked_cast(child).GetView(0); - } - - static Result GetExpressionType(const StructArray& array) { - if (array.struct_type()->num_fields() < 1) { - return Status::Invalid("StructArray didn't contain ExpressionType member"); - } - - ARROW_ASSIGN_OR_RAISE(auto expression_type, - GetView(array, array.num_fields() - 1)); - return static_cast(expression_type); - } - - Result> FromBuffer(const Buffer& serialized) { - io::BufferReader stream(serialized); - ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); - ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); - ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray()); - return FromArray(*array); - } -}; - -Result> Expression::Deserialize(const Buffer& serialized) { - return DeserializeImpl{}.FromBuffer(serialized); -} - -// Transform an array of counts to offsets which will divide a ListArray -// into an equal number of slices with corresponding lengths. -inline Result> CountsToOffsets( - std::shared_ptr counts) { - Int32Builder offset_builder; - RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1)); - offset_builder.UnsafeAppend(0); - - for (int64_t i = 0; i < counts->length(); ++i) { - DCHECK_NE(counts->Value(i), 0); - auto next_offset = static_cast(offset_builder[i] + counts->Value(i)); - offset_builder.UnsafeAppend(next_offset); - } - - std::shared_ptr offsets; - RETURN_NOT_OK(offset_builder.Finish(&offsets)); - return offsets; -} - -// Helper for simultaneous dictionary encoding of multiple arrays. -// -// The fused dictionary is the Cartesian product of the individual dictionaries. -// For example given two arrays A, B where A has unique values ["ex", "why"] -// and B has unique values [0, 1] the fused dictionary is the set of tuples -// [["ex", 0], ["ex", 1], ["why", 0], ["ex", 1]]. 
-// -// TODO(bkietz) this capability belongs in an Action of the hash kernels, where -// it can be used to group aggregates without materializing a grouped batch. -// For the purposes of writing we need the materialized grouped batch anyway -// since no Writers accept a selection vector. -class StructDictionary { - public: - struct Encoded { - std::shared_ptr indices; - std::shared_ptr dictionary; - }; - - static Result Encode(const ArrayVector& columns) { - Encoded out{nullptr, std::make_shared()}; - - for (const auto& column : columns) { - if (column->null_count() != 0) { - return Status::NotImplemented("Grouping on a field with nulls"); - } - - RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); - } - - return out; - } - - Result> Decode(std::shared_ptr fused_indices, - FieldVector fields) { - std::vector builders(dictionaries_.size()); - for (Int32Builder& b : builders) { - RETURN_NOT_OK(b.Resize(fused_indices->length())); - } - - std::vector codes(dictionaries_.size()); - for (int64_t i = 0; i < fused_indices->length(); ++i) { - Expand(fused_indices->Value(i), codes.data()); - - auto builder_it = builders.begin(); - for (int32_t index : codes) { - builder_it++->UnsafeAppend(index); - } - } - - ArrayVector columns(dictionaries_.size()); - for (size_t i = 0; i < dictionaries_.size(); ++i) { - std::shared_ptr indices; - RETURN_NOT_OK(builders[i].FinishInternal(&indices)); - - ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices)); - columns[i] = column.make_array(); - } - - return StructArray::Make(std::move(columns), std::move(fields)); - } - - private: - Status AddOne(Datum column, std::shared_ptr* fused_indices) { - ArrayData* encoded; - if (column.type()->id() != Type::DICTIONARY) { - ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column)); - } - encoded = column.mutable_array(); - - auto indices = - std::make_shared(encoded->length, std::move(encoded->buffers[1])); - - 
dictionaries_.push_back(MakeArray(std::move(encoded->dictionary))); - auto dictionary_size = static_cast(dictionaries_.back()->length()); - - if (*fused_indices == nullptr) { - *fused_indices = std::move(indices); - size_ = dictionary_size; - return Status::OK(); - } - - // It's useful to think about the case where each of dictionaries_ has size 10. - // In this case the decimal digit in the ones place is the code in dictionaries_[0], - // the tens place corresponds to dictionaries_[1], etc. - // The incumbent indices must be shifted to the hundreds place so as not to collide. - ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices, - compute::Multiply(indices, MakeScalar(size_))); - - ARROW_ASSIGN_OR_RAISE(new_fused_indices, - compute::Add(new_fused_indices, *fused_indices)); - - *fused_indices = checked_pointer_cast(new_fused_indices.make_array()); - - // XXX should probably cap this at 2**15 or so - ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_)); - return Status::OK(); - } - - // expand a fused code into component dict codes, order is in order of addition - void Expand(int32_t fused_code, int32_t* codes) { - for (size_t i = 0; i < dictionaries_.size(); ++i) { - auto dictionary_size = static_cast(dictionaries_[i]->length()); - codes[i] = fused_code % dictionary_size; - fused_code /= dictionary_size; - } - } - - int32_t size_; - ArrayVector dictionaries_; -}; - -Result> MakeGroupings(const StructArray& by) { - if (by.num_fields() == 0) { - return Status::NotImplemented("Grouping with no criteria"); - } - - ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields())); - - ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices)); - ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices)); - fused.indices = checked_pointer_cast(sorted.make_array()); - - ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values, - compute::ValueCounts(fused.indices)); - fused.indices.reset(); - - auto 
unique_fused_indices = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("values")); - ARROW_ASSIGN_OR_RAISE( - auto unique_rows, - fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields())); - - auto counts = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("counts")); - ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(counts))); - - ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices, - ListArray::FromArrays(*offsets, *sort_indices)); - - return StructArray::Make( - ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)}, - std::vector{"values", "groupings"}); -} - -Result> ApplyGroupings(const ListArray& groupings, - const Array& array) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(array, groupings.data()->child_data[0])); - - return std::make_shared(list(array.type()), groupings.length(), - groupings.value_offsets(), sorted.make_array()); -} - -Result ApplyGroupings(const ListArray& groupings, - const std::shared_ptr& batch) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(batch, groupings.data()->child_data[0])); - - const auto& sorted_batch = *sorted.record_batch(); - - RecordBatchVector out(static_cast(groupings.length())); - for (size_t i = 0; i < out.size(); ++i) { - out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); - } - - return out; -} - -} // namespace dataset -} // namespace arrow +// FIXME remove diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 750c67c048d..6dd39e25e1e 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -15,603 +15,5 @@ // specific language governing permissions and limitations // under the License. -// This API is EXPERIMENTAL. 
+// FIXME remove -#pragma once - -#include -#include -#include -#include -#include - -#include "arrow/compute/api_scalar.h" -#include "arrow/compute/cast.h" -#include "arrow/dataset/type_fwd.h" -#include "arrow/dataset/visibility.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/variant.h" - -namespace arrow { -namespace dataset { - -using compute::CastOptions; -using compute::CompareOperator; - -struct ExpressionType { - enum type { - /// a reference to a column within a record batch, will evaluate to an array - FIELD, - - /// a literal singular value encapsulated in a Scalar - SCALAR, - - /// a literal Array - // TODO(bkietz) ARRAY, - - /// an inversion of another expression - NOT, - - /// cast an expression to a given DataType - CAST, - - /// a conjunction of multiple expressions (true if all operands are true) - AND, - - /// a disjunction of multiple expressions (true if any operand is true) - OR, - - /// a comparison of two other expressions - COMPARISON, - - /// replace nulls with other expressions - /// currently only boolean expressions may be coalesced - // TODO(bkietz) COALESCE, - - /// extract validity as a boolean expression - IS_VALID, - - /// check each element for membership in a set - IN, - - /// custom user defined expression - CUSTOM, - }; -}; - -class InExpression; -class CastExpression; -class IsValidExpression; - -/// Represents an expression tree -class ARROW_DS_EXPORT Expression { - public: - explicit Expression(ExpressionType::type type) : type_(type) {} - - virtual ~Expression() = default; - - /// Returns true iff the expressions are identical; does not check for equivalence. - /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to - /// (false). 
- virtual bool Equals(const Expression& other) const = 0; - - bool Equals(const std::shared_ptr& other) const; - - /// Overload for the common case of checking for equality to a specific scalar. - template ()))> - bool Equals(T&& t) const; - - /// If true, this Expression is a ScalarExpression wrapping a null scalar. - bool IsNull() const; - - /// Validate this expression for execution against a schema. This will check that all - /// reference fields are present (fields not in the schema will be replaced with null) - /// and all subexpressions are executable. Returns the type to which this expression - /// will evaluate. - virtual Result> Validate(const Schema& schema) const = 0; - - /// \brief Simplify to an equivalent Expression given assumed constraints on input. - /// This can be used to do less filtering work using predicate push down. - /// - /// Both expressions must pass validation against a schema before Assume may be used. - /// - /// Two expressions can be considered equivalent for a given subset of possible inputs - /// if they yield identical results. Formally, if given.Evaluate(input).Equals(input) - /// then Assume guarantees that: - /// expr.Assume(given).Evaluate(input).Equals(expr.Evaluate(input)) - /// - /// For example if we are given that all inputs will - /// satisfy ("a"_ == 1) then the expression ("a"_ > 0 and "b"_ > 0) is equivalent to - /// ("b"_ > 0). It is impossible that the comparison ("a"_ > 0) will evaluate false - /// given ("a"_ == 1), so both expressions will yield identical results. Thus we can - /// write: - /// ("a"_ > 0 and "b"_ > 0).Assume("a"_ == 1).Equals("b"_ > 0) - /// - /// filter.Assume(partition) is trivial if filter and partition are disjoint or if - /// partition is a subset of filter. FIXME(bkietz) write this better - /// - If the two are disjoint, then (false) may be substituted for filter. - /// - If partition is a subset of filter then (true) may be substituted for filter. 
- /// - /// filter.Assume(partition) is straightforward if both filter and partition are simple - /// comparisons. - /// - filter may be a superset of partition, in which case the filter is - /// satisfied by all inputs: - /// ("a"_ > 0).Assume("a"_ == 1).Equals(true) - /// - filter may be disjoint with partition, in which case there are no inputs which - /// satisfy filter: - /// ("a"_ < 0).Assume("a"_ == 1).Equals(false) - /// - If neither of these is the case, partition provides no information which can - /// simplify filter: - /// ("a"_ == 1).Assume("a"_ > 0).Equals("a"_ == 1) - /// ("a"_ == 1).Assume("b"_ == 1).Equals("a"_ == 1) - /// - /// If filter is compound, Assume can be distributed across the boolean operator. To - /// prove this is valid, we again demonstrate that the simplified expression will yield - /// identical results. For conjunction of filters lhs and rhs: - /// (lhs.Assume(p) and rhs.Assume(p)).Evaluate(input) - /// == Intersection(lhs.Assume(p).Evaluate(input), rhs.Assume(p).Evaluate(input)) - /// == Intersection(lhs.Evaluate(input), rhs.Evaluate(input)) - /// == (lhs and rhs).Evaluate(input) - /// - The proof for disjunction is symmetric; just replace Intersection with Union. 
Thus - /// we can write: - /// (lhs and rhs).Assume(p).Equals(lhs.Assume(p) and rhs.Assume(p)) - /// (lhs or rhs).Assume(p).Equals(lhs.Assume(p) or rhs.Assume(p)) - /// - For negation: - /// (not e.Assume(p)).Evaluate(input) - /// == Difference(input, e.Assume(p).Evaluate(input)) - /// == Difference(input, e.Evaluate(input)) - /// == (not e).Evaluate(input) - /// - Thus we can write: - /// (not e).Assume(p).Equals(not e.Assume(p)) - /// - /// If the partition expression is a conjunction then each of its subexpressions is - /// true for all input and can be used independently: - /// filter.Assume(lhs).Assume(rhs).Evaluate(input) - /// == filter.Assume(lhs).Evaluate(input) - /// == filter.Evaluate(input) - /// - Thus we can write: - /// filter.Assume(lhs and rhs).Equals(filter.Assume(lhs).Assume(rhs)) - /// - /// FIXME(bkietz) disjunction proof - /// filter.Assume(lhs or rhs).Equals(filter.Assume(lhs) and filter.Assume(rhs)) - /// - This may not result in a simpler expression so it is only used when - /// filter.Assume(lhs).Equals(filter.Assume(rhs)) - /// - /// If the partition expression is a negation then we can use the above relations by - /// replacing comparisons with their complements and using the properties: - /// (not (a and b)).Equals(not a or not b) - /// (not (a or b)).Equals(not a and not b) - virtual std::shared_ptr Assume(const Expression& given) const; - - std::shared_ptr Assume(const std::shared_ptr& given) const { - return Assume(*given); - } - - /// Indicates if the expression is satisfiable. - /// - /// This is a shortcut to check if the expression is neither null nor false. - bool IsSatisfiable() const { return !IsNull() && !Equals(false); } - - /// Indicates if the expression is satisfiable given an other expression. - /// - /// This behaves like IsSatisfiable, but it simplifies the current expression - /// with the given `other` information. 
- bool IsSatisfiableWith(const Expression& other) const { - return Assume(other)->IsSatisfiable(); - } - - bool IsSatisfiableWith(const std::shared_ptr& other) const { - return Assume(other)->IsSatisfiable(); - } - - /// returns a debug string representing this expression - virtual std::string ToString() const = 0; - - /// serialize/deserialize an Expression. - Result> Serialize() const; - static Result> Deserialize(const Buffer&); - - /// \brief Return the expression's type identifier - ExpressionType::type type() const { return type_; } - - /// Copy this expression into a shared pointer. - virtual std::shared_ptr Copy() const = 0; - - InExpression In(std::shared_ptr set) const; - - IsValidExpression IsValid() const; - - CastExpression CastTo(std::shared_ptr type, - CastOptions options = CastOptions()) const; - - CastExpression CastLike(const Expression& expr, - CastOptions options = CastOptions()) const; - - CastExpression CastLike(std::shared_ptr expr, - CastOptions options = CastOptions()) const; - - protected: - ExpressionType::type type_; -}; - -/// Helper class which implements Copy and forwards construction -template -class ExpressionImpl : public Base { - public: - static constexpr ExpressionType::type expression_type = E; - - template - explicit ExpressionImpl(A0&& arg0, A&&... args) - : Base(expression_type, std::forward(arg0), std::forward(args)...) 
{} - - std::shared_ptr Copy() const override { - return std::make_shared(internal::checked_cast(*this)); - } -}; - -/// Base class for an expression with exactly one operand -class ARROW_DS_EXPORT UnaryExpression : public Expression { - public: - const std::shared_ptr& operand() const { return operand_; } - - bool Equals(const Expression& other) const override; - - protected: - UnaryExpression(ExpressionType::type type, std::shared_ptr operand) - : Expression(type), operand_(std::move(operand)) {} - - std::shared_ptr operand_; -}; - -/// Base class for an expression with exactly two operands -class ARROW_DS_EXPORT BinaryExpression : public Expression { - public: - const std::shared_ptr& left_operand() const { return left_operand_; } - - const std::shared_ptr& right_operand() const { return right_operand_; } - - bool Equals(const Expression& other) const override; - - protected: - BinaryExpression(ExpressionType::type type, std::shared_ptr left_operand, - std::shared_ptr right_operand) - : Expression(type), - left_operand_(std::move(left_operand)), - right_operand_(std::move(right_operand)) {} - - std::shared_ptr left_operand_, right_operand_; -}; - -class ARROW_DS_EXPORT ComparisonExpression final - : public ExpressionImpl { - public: - ComparisonExpression(CompareOperator op, std::shared_ptr left_operand, - std::shared_ptr right_operand) - : ExpressionImpl(std::move(left_operand), std::move(right_operand)), op_(op) {} - - std::string ToString() const override; - - bool Equals(const Expression& other) const override; - - std::shared_ptr Assume(const Expression& given) const override; - - CompareOperator op() const { return op_; } - - Result> Validate(const Schema& schema) const override; - - private: - std::shared_ptr AssumeGivenComparison( - const ComparisonExpression& given) const; - - CompareOperator op_; -}; - -class ARROW_DS_EXPORT AndExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const 
override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; -}; - -class ARROW_DS_EXPORT OrExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; -}; - -class ARROW_DS_EXPORT NotExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; -}; - -class ARROW_DS_EXPORT IsValidExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Assume(const Expression& given) const override; -}; - -class ARROW_DS_EXPORT InExpression final - : public ExpressionImpl { - public: - InExpression(std::shared_ptr operand, std::shared_ptr set) - : ExpressionImpl(std::move(operand)), set_(std::move(set)) {} - - std::string ToString() const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Assume(const Expression& given) const override; - - /// The set against which the operand will be compared - const std::shared_ptr& set() const { return set_; } - - private: - std::shared_ptr set_; -}; - -/// Explicitly cast an expression to a different type -class ARROW_DS_EXPORT CastExpression final - : public ExpressionImpl { - public: - CastExpression(std::shared_ptr operand, std::shared_ptr to, - CastOptions options) - : ExpressionImpl(std::move(operand)), - to_(std::move(to)), - options_(std::move(options)) {} - - /// The operand will be cast to whatever type `like` would evaluate to, given the same - /// schema. 
- CastExpression(std::shared_ptr operand, std::shared_ptr like, - CastOptions options) - : ExpressionImpl(std::move(operand)), - to_(std::move(like)), - options_(std::move(options)) {} - - std::string ToString() const override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; - - const CastOptions& options() const { return options_; } - - /// Return the target type of this CastTo expression, or nullptr if this is a - /// CastLike expression. - const std::shared_ptr& to_type() const; - - /// Return the target expression of this CastLike expression, or nullptr if - /// this is a CastTo expression. - const std::shared_ptr& like_expr() const; - - private: - util::Variant, std::shared_ptr> to_; - CastOptions options_; -}; - -/// Represents a scalar value; thin wrapper around arrow::Scalar -class ARROW_DS_EXPORT ScalarExpression final : public Expression { - public: - explicit ScalarExpression(const std::shared_ptr& value) - : Expression(ExpressionType::SCALAR), value_(std::move(value)) {} - - const std::shared_ptr& value() const { return value_; } - - std::string ToString() const override; - - bool Equals(const Expression& other) const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Copy() const override; - - private: - std::shared_ptr value_; -}; - -/// Represents a reference to a field. 
Stores only the field's name (type and other -/// information is known only when a Schema is provided) -class ARROW_DS_EXPORT FieldExpression final : public Expression { - public: - explicit FieldExpression(std::string name) - : Expression(ExpressionType::FIELD), name_(std::move(name)) {} - - std::string name() const { return name_; } - - std::string ToString() const override; - - bool Equals(const Expression& other) const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Copy() const override; - - private: - std::string name_; -}; - -class ARROW_DS_EXPORT CustomExpression : public Expression { - protected: - CustomExpression() : Expression(ExpressionType::CUSTOM) {} -}; - -inline std::shared_ptr scalar(std::shared_ptr value) { - return std::make_shared(std::move(value)); -} - -template -auto scalar(T&& value) -> decltype(scalar(MakeScalar(std::forward(value)))) { - return scalar(MakeScalar(std::forward(value))); -} - -template -bool Expression::Equals(T&& t) const { - if (type_ != ExpressionType::SCALAR) { - return false; - } - auto s = MakeScalar(std::forward(t)); - return internal::checked_cast(*this).value()->Equals(*s); -} - -template -auto VisitExpression(const Expression& expr, Visitor&& visitor) - -> decltype(visitor(expr)) { - switch (expr.type()) { - case ExpressionType::FIELD: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::SCALAR: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::IN: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::IS_VALID: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::AND: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::OR: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::NOT: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::CAST: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::COMPARISON: - return 
visitor(internal::checked_cast(expr)); - - case ExpressionType::CUSTOM: - default: - break; - } - return visitor(internal::checked_cast(expr)); -} - -/// \brief Visit each subexpression of an arbitrarily nested conjunction. -/// -/// | given | visit | -/// |--------------------------------|---------------------------------------------| -/// | a and b | visit(a), visit(b) | -/// | c | visit(c) | -/// | (a and b) and ((c or d) and e) | visit(a), visit(b), visit(c or d), visit(e) | -ARROW_DS_EXPORT Status VisitConjunctionMembers( - const Expression& expr, const std::function& visitor); - -/// \brief Insert CastExpressions where necessary to make a valid expression. -ARROW_DS_EXPORT Result> InsertImplicitCasts( - const Expression& expr, const Schema& schema); - -/// \brief Returns field names referenced in the expression. -ARROW_DS_EXPORT std::vector FieldsInExpression(const Expression& expr); - -ARROW_DS_EXPORT std::vector FieldsInExpression( - const std::shared_ptr& expr); - -/// Interface for evaluation of expressions against record batches. -class ARROW_DS_EXPORT ExpressionEvaluator { - public: - virtual ~ExpressionEvaluator() = default; - - /// Evaluate expr against each row of a RecordBatch. - /// Returned Datum will be of either SCALAR or ARRAY kind. - /// A return value of ARRAY kind will have length == batch.num_rows() - /// An return value of SCALAR kind is equivalent to an array of the same type whose - /// slots contain a single repeated value. - /// - /// expr must be validated against the schema of batch before calling this method. 
- virtual Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const = 0; - - Result Evaluate(const Expression& expr, const RecordBatch& batch) const { - return Evaluate(expr, batch, default_memory_pool()); - } - - virtual Result> Filter( - const Datum& selection, const std::shared_ptr& batch, - MemoryPool* pool) const = 0; - - Result> Filter( - const Datum& selection, const std::shared_ptr& batch) const { - return Filter(selection, batch, default_memory_pool()); - } - - /// \brief Wrap an iterator of record batches with a filter expression. The resulting - /// iterator will yield record batches filtered by the given expression. - /// - /// \note The ExpressionEvaluator must outlive the returned iterator. - RecordBatchIterator FilterBatches(RecordBatchIterator unfiltered, - std::shared_ptr filter, - MemoryPool* pool = default_memory_pool()); - - /// construct an Evaluator which evaluates all expressions to null and does no - /// filtering - static std::shared_ptr Null(); -}; - -/// construct an Evaluator which uses compute kernels to evaluate expressions and -/// filter record batches in depth first order -class ARROW_DS_EXPORT TreeEvaluator : public ExpressionEvaluator { - public: - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override; - - Result> Filter(const Datum& selection, - const std::shared_ptr& batch, - MemoryPool* pool) const override; - - protected: - struct Impl; -}; - -/// \brief Assemble lists of indices of identical rows. -/// -/// \param[in] by A StructArray whose columns will be used as grouping criteria. -/// \return A StructArray mapping unique rows (in field "values", represented as a -/// StructArray with the same fields as `by`) to lists of indices where -/// that row appears (in field "groupings"). -ARROW_DS_EXPORT -Result> MakeGroupings(const StructArray& by); - -/// \brief Produce slices of an Array which correspond to the provided groupings. 
-ARROW_DS_EXPORT -Result> ApplyGroupings(const ListArray& groupings, - const Array& array); -ARROW_DS_EXPORT -Result ApplyGroupings(const ListArray& groupings, - const std::shared_ptr& batch); - -} // namespace dataset -} // namespace arrow diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index fce2d9cfcd9..ef32b8a7bb6 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -15,487 +15,4 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/dataset/filter.h" - -#include -#include -#include -#include - -#include -#include - -#include "arrow/compute/api.h" -#include "arrow/dataset/test_util.h" -#include "arrow/record_batch.h" -#include "arrow/status.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/logging.h" - -namespace arrow { -namespace dataset { - -// clang-format off -using string_literals::operator"" _; -// clang-format on - -using internal::checked_cast; -using internal::checked_pointer_cast; - -// A frozen shared_ptr with behavior expected by GTest -struct TestExpression : util::EqualityComparable, - util::ToStringOstreamable { - // NOLINTNEXTLINE runtime/explicit - TestExpression(std::shared_ptr e) : expression(std::move(e)) {} - - // NOLINTNEXTLINE runtime/explicit - TestExpression(const Expression& e) : expression(e.Copy()) {} - - // NOLINTNEXTLINE runtime/explicit - TestExpression(const Expression2& e) : expression(e) {} - - std::shared_ptr expression; - - using util::EqualityComparable::operator==; - bool Equals(const TestExpression& other) const { - return expression->Equals(other.expression); - } - - std::string ToString() const { return expression->ToString(); } - - friend bool operator==(const std::shared_ptr& lhs, - const TestExpression& rhs) { - return TestExpression(lhs) == rhs; - } - - friend void PrintTo(const 
TestExpression& expr, std::ostream* os) { - *os << expr.ToString(); - } -}; - -using E = TestExpression; - -class ExpressionsTest : public ::testing::Test { - public: - void AssertSimplifiesTo(E expr, E given, E expected) { - ASSERT_OK_AND_ASSIGN(auto expr_type, expr.expression->Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto given_type, given.expression->Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto expected_type, expected.expression->Validate(*schema_)); - - EXPECT_TRUE(expr_type->Equals(expected_type)); - EXPECT_TRUE(given_type->Equals(boolean())); - - auto simplified = expr.expression->Assume(given.expression); - ASSERT_EQ(simplified, expected) - << " simplification of: " << expr.ToString() << std::endl - << " given: " << given.ToString() << std::endl; - } - - std::shared_ptr ns = timestamp(TimeUnit::NANO); - std::shared_ptr schema_ = - schema({field("a", int32()), field("b", int32()), field("f", float64()), - field("s", utf8()), field("ts", ns), - field("dict_b", dictionary(int32(), int32()))}); - std::shared_ptr always = scalar(true); - std::shared_ptr never = scalar(false); -}; - -TEST_F(ExpressionsTest, Equality) { - ASSERT_EQ(E{"a"_}, E{"a"_}); - ASSERT_NE(E{"a"_}, E{"b"_}); - - ASSERT_EQ(E{"b"_ == 3}, E{"b"_ == 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_ < 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_}); - - // ordering matters - ASSERT_EQ(E{"b"_ == 3}, E{"b"_ == 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_ < 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_}); - - ASSERT_EQ(E("b"_ > 2 and "b"_ < 3), E("b"_ > 2 and "b"_ < 3)); - ASSERT_NE(E("b"_ > 2 and "b"_ < 3), E("b"_ < 3 and "b"_ > 2)); -} - -TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); - - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 6, *never); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ == 3, *always); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); - 
AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); - - AssertSimplifiesTo("f"_ > 0.5 and "f"_ < 1.5, not_("f"_ < 0.0 or "f"_ > 1.0), - "f"_ > 0.5); - - AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); - - AssertSimplifiesTo("a"_ == 3 or "b"_ == 4, "a"_ == 0, "b"_ == 4); - - auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 3, "a"_ == 3); - AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 0, *never); - - AssertSimplifiesTo("a"_ == 0 or not_("b"_.IsValid()), "b"_ == 3, "a"_ == 0); -} - -TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { - AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); - AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); - - auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 3, *always); - AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 5, *never); - - auto dict_set_123 = - DictArrayFromJSON(dictionary(int32(), int32()), R"([1,2,0])", R"([1,2,3])"); - ASSERT_OK_AND_ASSIGN(auto b_dict, dict_set_123->GetScalar(0)); - AssertSimplifiesTo("b_dict"_.In(dict_set_123), "a"_ == 3 and "b_dict"_ == b_dict, - *always); -} - -class FilterTest : public ::testing::Test { - public: - FilterTest() { evaluator_ = std::make_shared(); } - - Result DoFilter(const Expression& expr, - std::vector> fields, - std::string batch_json, - std::shared_ptr* expected_mask = nullptr) { - // expected filter result is in the "in" field - fields.push_back(field("in", boolean())); - auto batch = RecordBatchFromJSON(schema(fields), batch_json); - if (expected_mask) { - *expected_mask = checked_pointer_cast(batch->GetColumnByName("in")); - } - - ARROW_ASSIGN_OR_RAISE(auto expr_type, expr.Validate(*batch->schema())); - EXPECT_TRUE(expr_type->Equals(boolean())); - - return 
evaluator_->Evaluate(expr, *batch); - } - - void AssertFilter(const std::shared_ptr& expr, - std::vector> fields, - const std::string& batch_json) { - AssertFilter(*expr, std::move(fields), batch_json); - } - - void AssertFilter(const Expression& expr, std::vector> fields, - const std::string& batch_json) { - std::shared_ptr expected_mask; - - ASSERT_OK_AND_ASSIGN(Datum mask, DoFilter(expr, std::move(fields), - std::move(batch_json), &expected_mask)); - ASSERT_TRUE(mask.type()->Equals(null()) || mask.type()->Equals(boolean())); - - if (mask.is_array()) { - AssertArraysEqual(*expected_mask, *mask.make_array(), /*verbose=*/true); - return; - } - - ASSERT_TRUE(mask.is_scalar()); - auto mask_scalar = mask.scalar(); - if (!mask_scalar->is_valid) { - ASSERT_EQ(expected_mask->null_count(), expected_mask->length()); - return; - } - - TypedBufferBuilder builder; - ASSERT_OK(builder.Append(expected_mask->length(), - checked_cast(*mask_scalar).value)); - - std::shared_ptr values; - ASSERT_OK(builder.Finish(&values)); - - ASSERT_ARRAYS_EQUAL(*expected_mask, BooleanArray(expected_mask->length(), values)); - } - - std::shared_ptr evaluator_; -}; - -TEST_F(FilterTest, Trivial) { - // Note that we should expect these trivial expressions will never be evaluated against - // record batches; since they're trivial, evaluation is not necessary. 
- AssertFilter(scalar(true), {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": 1}, - {"a": 0, "b": 1.0, "in": 1} - ])"); - - AssertFilter(scalar(false), {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 0, "b": 1.0, "in": 0} - ])"); - - AssertFilter(*scalar(std::shared_ptr(new BooleanScalar)), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": null}, - {"a": 0, "b": 0.3, "in": null}, - {"a": 1, "b": 0.2, "in": null}, - {"a": 2, "b": -0.1, "in": null}, - {"a": 0, "b": 0.1, "in": null}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": null} - ])"); -} - -TEST_F(FilterTest, Basics) { - AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0, - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": 0} - ])"); - - AssertFilter("a"_ != 0 and "b"_ > 0.1, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 0, "b": 1.0, "in": 0} - ])"); -} - -TEST_F(FilterTest, InExpression) { - auto hello_world = ArrayFromJSON(utf8(), R"(["hello", "world"])"); - - AssertFilter("s"_.In(hello_world), {field("s", utf8())}, R"([ - {"s": "hello", "in": 1}, - {"s": "world", "in": 1}, - {"s": "", "in": 0}, - {"s": null, "in": null}, - {"s": "foo", "in": 0}, - {"s": "hello", 
"in": 1}, - {"s": "bar", "in": 0} - ])"); -} - -TEST_F(FilterTest, IsValidExpression) { - AssertFilter("s"_.IsValid(), {field("s", utf8())}, R"([ - {"s": "hello", "in": 1}, - {"s": null, "in": 0}, - {"s": "", "in": 1}, - {"s": null, "in": 0}, - {"s": "foo", "in": 1}, - {"s": "hello", "in": 1}, - {"s": null, "in": 0} - ])"); -} - -TEST_F(FilterTest, Cast) { - ASSERT_RAISES(TypeError, ("a"_ == double(1.0)).Validate(Schema({field("a", int32())}))); - - AssertFilter("a"_.CastTo(float64()) == double(1.0), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 1, "b": 1.0, "in": 1} - ])"); - - AssertFilter("a"_ == scalar(0.6)->CastLike("a"_), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": 1}, - {"a": 1, "b": 1.0, "in": 0} - ])"); - - AssertFilter("a"_.CastLike("b"_) == "b"_, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.0, "in": 1}, - {"a": 1, "b": 1.0, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 2, "b": null, "in": null}, - {"a": 1, "b": 1.0, "in": 1} - ])"); -} - -TEST_F(ExpressionsTest, ImplicitCast) { - ASSERT_OK_AND_ASSIGN(auto filter, InsertImplicitCasts("a"_ == 0.0, *schema_)); - ASSERT_EQ(E{filter}, E{"a"_ == 0}); - - auto ns = timestamp(TimeUnit::NANO); - auto date = "1990-10-23 10:23:33"; - ASSERT_OK_AND_ASSIGN(filter, InsertImplicitCasts("ts"_ == date, *schema_)); - ASSERT_EQ(E{filter}, E{"ts"_ == *MakeScalar(date)->CastTo(ns)}); - - ASSERT_OK_AND_ASSIGN(filter, - InsertImplicitCasts("ts"_ == date and "b"_ == "3", *schema_)); - ASSERT_EQ(E{filter}, E{"ts"_ == *MakeScalar(date)->CastTo(ns) 
and "b"_ == 3}); - AssertSimplifiesTo(*filter, "b"_ == 2, *never); - AssertSimplifiesTo(*filter, "b"_ == 3, "ts"_ == *MakeScalar(date)->CastTo(ns)); - - // set is double but "a"_ is int32 - auto set_double = ArrayFromJSON(float64(), R"([1, 2, 3])"); - ASSERT_OK_AND_ASSIGN(filter, InsertImplicitCasts("a"_.In(set_double), *schema_)); - auto set_int32 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - ASSERT_EQ(E{filter}, E{"a"_.In(set_int32)}); - - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - testing::HasSubstr("Field named 'nope' not found"), - InsertImplicitCasts("nope"_ == 0.0, *schema_)); -} - -TEST_F(ExpressionsTest, ImplicitCastToDict) { - auto dict_type = dictionary(int32(), float64()); - ASSERT_OK_AND_ASSIGN(auto filter, - InsertImplicitCasts("a"_ == 1.5, Schema({field("a", dict_type)}))); - - auto encoded_scalar = std::make_shared( - DictionaryScalar::ValueType{MakeScalar(0), - ArrayFromJSON(float64(), "[1.5]")}, - dict_type); - - ASSERT_EQ(E{filter}, E{"a"_ == encoded_scalar}); - - for (int32_t i = 0; i < 5; ++i) { - auto partition_scalar = std::make_shared( - DictionaryScalar::ValueType{ - MakeScalar(i), ArrayFromJSON(float64(), "[0.0, 0.5, 1.0, 1.5, 2.0]")}, - dict_type); - ASSERT_EQ(E{filter->Assume("a"_ == partition_scalar)}, E{scalar(i == 3)}); - } - - auto set_f64 = ArrayFromJSON(float64(), "[0.0, 0.5, 1.0, 1.5, 2.0]"); - ASSERT_OK_AND_ASSIGN( - filter, InsertImplicitCasts("a"_.In(set_f64), Schema({field("a", dict_type)}))); -} - -TEST_F(FilterTest, ImplicitCast) { - ASSERT_OK_AND_ASSIGN(auto filter, - InsertImplicitCasts("a"_ >= "1", Schema({field("a", int32())}))); - - AssertFilter(*filter, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.0, "in": 0}, - {"a": 1, "b": 1.0, "in": 1}, - {"a": 2, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 2, "b": null, "in": 1}, - {"a": 1, "b": 1.0, "in": 1} - ])"); -} - -TEST_F(FilterTest, ConditionOnAbsentColumn) { - AssertFilter("a"_ == 0 and "b"_ > 0.0 
and "b"_ < 1.0 and "absent"_ == 0, - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": false}, - {"a": 0, "b": 0.3, "in": null}, - {"a": 1, "b": 0.2, "in": false}, - {"a": 2, "b": -0.1, "in": false}, - {"a": 0, "b": 0.1, "in": null}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": false} - ])"); -} - -TEST_F(FilterTest, KleeneTruthTables) { - AssertFilter("a"_ and "b"_, {field("a", boolean()), field("b", boolean())}, R"([ - {"a":null, "b":null, "in":null}, - {"a":null, "b":true, "in":null}, - {"a":null, "b":false, "in":false}, - - {"a":true, "b":true, "in":true}, - {"a":true, "b":false, "in":false}, - - {"a":false, "b":false, "in":false} - ])"); - - AssertFilter("a"_ or "b"_, {field("a", boolean()), field("b", boolean())}, R"([ - {"a":null, "b":null, "in":null}, - {"a":null, "b":true, "in":true}, - {"a":null, "b":false, "in":null}, - - {"a":true, "b":true, "in":true}, - {"a":true, "b":false, "in":true}, - - {"a":false, "b":false, "in":false} - ])"); -} - -void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, - const std::string& expected_json) { - FieldVector fields_with_ids = by_fields; - fields_with_ids.push_back(field("ids", list(int32()))); - auto expected = ArrayFromJSON(struct_(fields_with_ids), expected_json); - - FieldVector fields_with_id = by_fields; - fields_with_id.push_back(field("id", int32())); - auto batch = RecordBatchFromJSON(schema(fields_with_id), batch_json); - - ASSERT_OK_AND_ASSIGN(auto by, batch->RemoveColumn(batch->num_columns() - 1) - .Map([](std::shared_ptr by) { - return by->ToStructArray(); - })); - - ASSERT_OK_AND_ASSIGN(auto groupings_and_values, MakeGroupings(*by)); - - auto groupings = - checked_pointer_cast(groupings_and_values->GetFieldByName("groupings")); - - ASSERT_OK_AND_ASSIGN(std::shared_ptr grouped_ids, - ApplyGroupings(*groupings, *batch->GetColumnByName("id"))); - - ArrayVector columns = - checked_cast(*groupings_and_values->GetFieldByName("values")) - 
.fields(); - columns.push_back(grouped_ids); - - ASSERT_OK_AND_ASSIGN(auto actual, StructArray::Make(columns, fields_with_ids)); - - AssertArraysEqual(*expected, *actual, /*verbose=*/true); -} - -TEST(GroupTest, Basics) { - AssertGrouping({field("a", utf8()), field("b", int32())}, R"([ - {"a": "ex", "b": 0, "id": 0}, - {"a": "ex", "b": 0, "id": 1}, - {"a": "why", "b": 0, "id": 2}, - {"a": "ex", "b": 1, "id": 3}, - {"a": "why", "b": 0, "id": 4}, - {"a": "ex", "b": 1, "id": 5}, - {"a": "ex", "b": 0, "id": 6}, - {"a": "why", "b": 1, "id": 7} - ])", - R"([ - {"a": "ex", "b": 0, "ids": [0, 1, 6]}, - {"a": "why", "b": 0, "ids": [2, 4]}, - {"a": "ex", "b": 1, "ids": [3, 5]}, - {"a": "why", "b": 1, "ids": [7]} - ])"); -} - -} // namespace dataset -} // namespace arrow +// FIXME delete this diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 16a49b988a9..6fd61096392 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -26,10 +26,11 @@ #include "arrow/array/array_nested.h" #include "arrow/array/builder_dict.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/api_vector.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/filesystem/path_util.h" #include "arrow/scalar.h" +#include "arrow/util/int_util_internal.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" #include "arrow/util/string_view.h" @@ -84,7 +85,7 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr, return Status::OK(); } -inline std::shared_ptr ConjunctionFromGroupingRow(Scalar* row) { +inline Expression2 ConjunctionFromGroupingRow(Scalar* row) { ScalarVector* values = &checked_cast(row)->value; std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { @@ -529,5 +530,192 @@ Result> PartitioningOrFactory::GetOrInferSchema( return factory()->Inspect(paths); } +// Transform an array of counts to 
offsets which will divide a ListArray +// into an equal number of slices with corresponding lengths. +inline Result> CountsToOffsets( + std::shared_ptr counts) { + Int32Builder offset_builder; + RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1)); + offset_builder.UnsafeAppend(0); + + for (int64_t i = 0; i < counts->length(); ++i) { + DCHECK_NE(counts->Value(i), 0); + auto next_offset = static_cast(offset_builder[i] + counts->Value(i)); + offset_builder.UnsafeAppend(next_offset); + } + + std::shared_ptr offsets; + RETURN_NOT_OK(offset_builder.Finish(&offsets)); + return offsets; +} + +// Helper for simultaneous dictionary encoding of multiple arrays. +// +// The fused dictionary is the Cartesian product of the individual dictionaries. +// For example given two arrays A, B where A has unique values ["ex", "why"] +// and B has unique values [0, 1] the fused dictionary is the set of tuples +// [["ex", 0], ["ex", 1], ["why", 0], ["why", 1]]. +// +// TODO(bkietz) this capability belongs in an Action of the hash kernels, where +// it can be used to group aggregates without materializing a grouped batch. +// For the purposes of writing we need the materialized grouped batch anyway +// since no Writers accept a selection vector.
+class StructDictionary { + public: + struct Encoded { + std::shared_ptr indices; + std::shared_ptr dictionary; + }; + + static Result Encode(const ArrayVector& columns) { + Encoded out{nullptr, std::make_shared()}; + + for (const auto& column : columns) { + if (column->null_count() != 0) { + return Status::NotImplemented("Grouping on a field with nulls"); + } + + RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); + } + + return out; + } + + Result> Decode(std::shared_ptr fused_indices, + FieldVector fields) { + std::vector builders(dictionaries_.size()); + for (Int32Builder& b : builders) { + RETURN_NOT_OK(b.Resize(fused_indices->length())); + } + + std::vector codes(dictionaries_.size()); + for (int64_t i = 0; i < fused_indices->length(); ++i) { + Expand(fused_indices->Value(i), codes.data()); + + auto builder_it = builders.begin(); + for (int32_t index : codes) { + builder_it++->UnsafeAppend(index); + } + } + + ArrayVector columns(dictionaries_.size()); + for (size_t i = 0; i < dictionaries_.size(); ++i) { + std::shared_ptr indices; + RETURN_NOT_OK(builders[i].FinishInternal(&indices)); + + ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices)); + columns[i] = column.make_array(); + } + + return StructArray::Make(std::move(columns), std::move(fields)); + } + + private: + Status AddOne(Datum column, std::shared_ptr* fused_indices) { + ArrayData* encoded; + if (column.type()->id() != Type::DICTIONARY) { + ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column)); + } + encoded = column.mutable_array(); + + auto indices = + std::make_shared(encoded->length, std::move(encoded->buffers[1])); + + dictionaries_.push_back(MakeArray(std::move(encoded->dictionary))); + auto dictionary_size = static_cast(dictionaries_.back()->length()); + + if (*fused_indices == nullptr) { + *fused_indices = std::move(indices); + size_ = dictionary_size; + return Status::OK(); + } + + // It's useful to think about the case where each of dictionaries_ 
has size 10. + // In this case the decimal digit in the ones place is the code in dictionaries_[0], + // the tens place corresponds to dictionaries_[1], etc. + // The incoming indices must be shifted to the hundreds place so as not to collide. + ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices, + compute::Multiply(indices, MakeScalar(size_))); + + ARROW_ASSIGN_OR_RAISE(new_fused_indices, + compute::Add(new_fused_indices, *fused_indices)); + + *fused_indices = checked_pointer_cast(new_fused_indices.make_array()); + + // XXX should probably cap this at 2**15 or so + ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_)); + return Status::OK(); + } + + // expand a fused code into component dict codes, order is in order of addition + void Expand(int32_t fused_code, int32_t* codes) { + for (size_t i = 0; i < dictionaries_.size(); ++i) { + auto dictionary_size = static_cast(dictionaries_[i]->length()); + codes[i] = fused_code % dictionary_size; + fused_code /= dictionary_size; + } + } + + int32_t size_; + ArrayVector dictionaries_; +}; + +Result> MakeGroupings(const StructArray& by) { + if (by.num_fields() == 0) { + return Status::NotImplemented("Grouping with no criteria"); + } + + ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields())); + + ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices)); + ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices)); + fused.indices = checked_pointer_cast(sorted.make_array()); + + ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values, + compute::ValueCounts(fused.indices)); + fused.indices.reset(); + + auto unique_fused_indices = + checked_pointer_cast(fused_counts_and_values->GetFieldByName("values")); + ARROW_ASSIGN_OR_RAISE( + auto unique_rows, + fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields())); + + auto counts = + checked_pointer_cast(fused_counts_and_values->GetFieldByName("counts")); + ARROW_ASSIGN_OR_RAISE(auto
offsets, CountsToOffsets(std::move(counts))); + + ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices, + ListArray::FromArrays(*offsets, *sort_indices)); + + return StructArray::Make( + ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)}, + std::vector{"values", "groupings"}); +} + +Result> ApplyGroupings(const ListArray& groupings, + const Array& array) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(array, groupings.data()->child_data[0])); + + return std::make_shared(list(array.type()), groupings.length(), + groupings.value_offsets(), sorted.make_array()); +} + +Result ApplyGroupings(const ListArray& groupings, + const std::shared_ptr& batch) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(batch, groupings.data()->child_data[0])); + + const auto& sorted_batch = *sorted.record_batch(); + + RecordBatchVector out(static_cast(groupings.length())); + for (size_t i = 0; i < out.size(); ++i) { + out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); + } + + return out; +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index 1bad8956b65..a9f8c87bfc9 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -285,5 +285,22 @@ class ARROW_DS_EXPORT PartitioningOrFactory { std::shared_ptr partitioning_; }; +/// \brief Assemble lists of indices of identical rows. +/// +/// \param[in] by A StructArray whose columns will be used as grouping criteria. +/// \return A StructArray mapping unique rows (in field "values", represented as a +/// StructArray with the same fields as `by`) to lists of indices where +/// that row appears (in field "groupings"). +ARROW_DS_EXPORT +Result> MakeGroupings(const StructArray& by); + +/// \brief Produce slices of an Array which correspond to the provided groupings. 
+ARROW_DS_EXPORT +Result> ApplyGroupings(const ListArray& groupings, + const Array& array); +ARROW_DS_EXPORT +Result ApplyGroupings(const ListArray& groupings, + const std::shared_ptr& batch); + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 7bf7efa3761..0cdb5eb50b6 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -105,17 +105,20 @@ TEST_F(TestPartitioning, DirectoryPartitioning) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); - AssertParse("/0/hello", "alpha"_ == int32_t(0) and "beta"_ == "hello"); - AssertParse("/3", "alpha"_ == int32_t(3)); + AssertParse("/0/hello", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello")))); + AssertParse("/3", equal(field_ref("alpha"), literal(3))); AssertParseError("/world/0"); // reversed order AssertParseError("/0.0/foo"); // invalid alpha AssertParseError("/3.25"); // invalid alpha with missing beta AssertParse("", literal(true)); // no segments to parse // gotcha someday: - AssertParse("/0/dat.parquet", "alpha"_ == int32_t(0) and "beta"_ == "dat.parquet"); + AssertParse("/0/dat.parquet", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("dat.parquet")))); - AssertParse("/0/foo/ignored=2341", "alpha"_ == int32_t(0) and "beta"_ == "foo"); + AssertParse("/0/foo/ignored=2341", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("foo")))); } TEST_F(TestPartitioning, DirectoryPartitioningFormat) { @@ -124,20 +127,28 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormat) { written_schema_ = partitioning_->schema(); - AssertFormat("alpha"_ == int32_t(0) and "beta"_ == "hello", "0/hello"); - AssertFormat("beta"_ == "hello" and "alpha"_ == int32_t(0), "0/hello"); - AssertFormat("alpha"_ == int32_t(0), "0"); - AssertFormatError("beta"_ == "hello"); + 
AssertFormat(and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello"))), + "0/hello"); + AssertFormat(and_(equal(field_ref("beta"), literal("hello")), + equal(field_ref("alpha"), literal(0))), + "0/hello"); + AssertFormat(equal(field_ref("alpha"), literal(0)), "0"); + AssertFormatError(equal(field_ref("beta"), literal("hello"))); AssertFormat(literal(true), ""); ASSERT_OK_AND_ASSIGN(written_schema_, written_schema_->AddField(0, field("gamma", utf8()))); - AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == "hello", + AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), + equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello"))}), "0/hello"); // written_schema_ is incompatible with partitioning_'s schema written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); - AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); + AssertFormatError( + and_(equal(field_ref("alpha"), literal("0.0")), + equal(field_ref("beta"), literal("hello")))); } TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { @@ -147,7 +158,9 @@ TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { ASSERT_OK_AND_ASSIGN(auto day, StringScalar("2020-06-08").CastTo(temporal)); AssertParse("/2020/06/2020-06-08", - "year"_ == int32_t(2020) and "month"_ == int8_t(6) and "day"_ == day); + and_({equal(field_ref("year"), literal(2020)), + equal(field_ref("month"), literal(6)), + equal(field_ref("day"), literal(day))})); } } @@ -207,7 +220,7 @@ TEST_F(TestPartitioning, DictionaryHasUniqueValues) { std::make_shared(index_and_dictionary, alpha->type()); auto path = "/" + expected_dictionary->GetString(i); - AssertParse(path, "alpha"_ == dictionary_scalar); + AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar))); } AssertParseError("/yosemite"); // not in inspected dictionary @@ -223,17 +236,21 @@ TEST_F(TestPartitioning, HivePartitioning) { partitioning_ = std::make_shared( 
schema({field("alpha", int32()), field("beta", float32())})); - AssertParse("/alpha=0/beta=3.25", "alpha"_ == int32_t(0) and "beta"_ == 3.25f); - AssertParse("/beta=3.25/alpha=0", "beta"_ == 3.25f and "alpha"_ == int32_t(0)); - AssertParse("/alpha=0", "alpha"_ == int32_t(0)); - AssertParse("/beta=3.25", "beta"_ == 3.25f); + AssertParse("/alpha=0/beta=3.25", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); + AssertParse("/beta=3.25/alpha=0", and_(equal(field_ref("beta"), literal(3.25f)), + equal(field_ref("alpha"), literal(0)))); + AssertParse("/alpha=0", equal(field_ref("alpha"), literal(0))); + AssertParse("/beta=3.25", equal(field_ref("beta"), literal(3.25f))); AssertParse("", literal(true)); AssertParse("/alpha=0/unexpected/beta=3.25", - "alpha"_ == int32_t(0) and "beta"_ == 3.25f); + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); AssertParse("/alpha=0/beta=3.25/ignored=2341", - "alpha"_ == int32_t(0) and "beta"_ == 3.25f); + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); AssertParse("/ignored=2341", literal(true)); @@ -246,20 +263,28 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { written_schema_ = partitioning_->schema(); - AssertFormat("alpha"_ == int32_t(0) and "beta"_ == 3.25f, "alpha=0/beta=3.25"); - AssertFormat("beta"_ == 3.25f and "alpha"_ == int32_t(0), "alpha=0/beta=3.25"); - AssertFormat("alpha"_ == int32_t(0), "alpha=0"); - AssertFormat("beta"_ == 3.25f, "alpha/beta=3.25"); + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f))), + "alpha=0/beta=3.25"); + AssertFormat(and_(equal(field_ref("beta"), literal(3.25f)), + equal(field_ref("alpha"), literal(0))), + "alpha=0/beta=3.25"); + AssertFormat(equal(field_ref("alpha"), literal(0)), "alpha=0"); + AssertFormat(equal(field_ref("beta"), literal(3.25f)), "alpha/beta=3.25"); AssertFormat(literal(true), ""); 
ASSERT_OK_AND_ASSIGN(written_schema_, written_schema_->AddField(0, field("gamma", utf8()))); - AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, + AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), + equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f))}), "alpha=0/beta=3.25"); // written_schema_ is incompatible with partitioning_'s schema written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); - AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); + AssertFormatError( + and_(equal(field_ref("alpha"), literal("0.0")), + equal(field_ref("beta"), literal("hello")))); } TEST_F(TestPartitioning, DiscoverHiveSchema) { @@ -326,7 +351,7 @@ TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) { std::make_shared(index_and_dictionary, alpha->type()); auto path = "/alpha=" + expected_dictionary->GetString(i); - AssertParse(path, "alpha"_ == dictionary_scalar); + AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar))); } AssertParseError("/alpha=yosemite"); // not in inspected dictionary @@ -365,9 +390,12 @@ TEST_F(TestPartitioning, EtlThenHive) { }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", - "year"_ == int16_t(1999) and "month"_ == int8_t(12) and - "day"_ == int8_t(31) and "hour"_ == int8_t(0) and - ("alpha"_ == int32_t(0) and "beta"_ == 3.25f)); + and_({equal(field_ref("year"), literal(1999)), + equal(field_ref("month"), literal(12)), + equal(field_ref("day"), literal(31)), + equal(field_ref("hour"), literal(0)), + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))})); AssertParseError("/20X6/03/21/05/alpha=0/beta=3.25"); } @@ -382,7 +410,7 @@ TEST_F(TestPartitioning, Set) { auto schm = schema({field("x", int32())}); // An adhoc partitioning which parses segments like "/x in [1 4 5]" - // into ("x"_ == 1 or "x"_ == 4 or "x"_ == 5) + // into (field_ref("x") == 1 or field_ref("x") == 4 or field_ref("x") == 5) 
partitioning_ = std::make_shared( schm, [&](const std::string& path) -> Result { std::vector subexpressions; @@ -402,18 +430,21 @@ TEST_F(TestPartitioning, Set) { } subexpressions.push_back(call("is_in", {field_ref(matches[1])}, - compute::SetLookupOptions{ints(set), true})); + compute::SetLookupOptions{ints(set)})); } return and_(std::move(subexpressions)); }); - AssertParse("/x in [1]", "x"_.In(ints({1}))); - AssertParse("/x in [1 4 5]", "x"_.In(ints({1, 4, 5}))); - AssertParse("/x in []", "x"_.In(ints({}))); + auto x_in = [&](std::vector set) { + return call("is_in", {field_ref("x")}, compute::SetLookupOptions{ints(set)}); + }; + AssertParse("/x in [1]", x_in({1})); + AssertParse("/x in [1 4 5]", x_in({1, 4, 5})); + AssertParse("/x in []", x_in({})); } // An adhoc partitioning which parses segments like "/x=[-3.25, 0.0)" -// into ("x"_ >= -3.25 and "x" < 0.0) +// into (field_ref("x") >= -3.25 and "x" < 0.0) class RangePartitioning : public Partitioning { public: explicit RangePartitioning(std::shared_ptr s) : Partitioning(std::move(s)) {} @@ -477,8 +508,12 @@ TEST_F(TestPartitioning, Range) { schema({field("x", float64()), field("y", float64()), field("z", float64())})); AssertParse("/x=[-1.5 0.0)/y=[0.0 1.5)/z=(1.5 3.0]", - ("x"_ >= -1.5 and "x"_ < 0.0) and ("y"_ >= 0.0 and "y"_ < 1.5) and - ("z"_ > 1.5 and "z"_ <= 3.0)); + and_({and_(greater_equal(field_ref("x"), literal(-1.5)), + less(field_ref("x"), literal(0.0))), + and_(greater_equal(field_ref("y"), literal(0.0)), + less(field_ref("y"), literal(1.5))), + and_(greater(field_ref("z"), literal(1.5)), + less_equal(field_ref("z"), literal(3.0)))})); } TEST(TestStripPrefixAndFilename, Basic) { @@ -497,5 +532,57 @@ TEST(TestStripPrefixAndFilename, Basic) { "year=2019/month=12/day=01")); } +void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, + const std::string& expected_json) { + FieldVector fields_with_ids = by_fields; + fields_with_ids.push_back(field("ids", list(int32()))); + auto 
expected = ArrayFromJSON(struct_(fields_with_ids), expected_json); + + FieldVector fields_with_id = by_fields; + fields_with_id.push_back(field("id", int32())); + auto batch = RecordBatchFromJSON(schema(fields_with_id), batch_json); + + ASSERT_OK_AND_ASSIGN(auto by, batch->RemoveColumn(batch->num_columns() - 1) + .Map([](std::shared_ptr by) { + return by->ToStructArray(); + })); + + ASSERT_OK_AND_ASSIGN(auto groupings_and_values, MakeGroupings(*by)); + + auto groupings = + checked_pointer_cast(groupings_and_values->GetFieldByName("groupings")); + + ASSERT_OK_AND_ASSIGN(std::shared_ptr grouped_ids, + ApplyGroupings(*groupings, *batch->GetColumnByName("id"))); + + ArrayVector columns = + checked_cast(*groupings_and_values->GetFieldByName("values")) + .fields(); + columns.push_back(grouped_ids); + + ASSERT_OK_AND_ASSIGN(auto actual, StructArray::Make(columns, fields_with_ids)); + + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +TEST(GroupTest, Basics) { + AssertGrouping({field("a", utf8()), field("b", int32())}, R"([ + {"a": "ex", "b": 0, "id": 0}, + {"a": "ex", "b": 0, "id": 1}, + {"a": "why", "b": 0, "id": 2}, + {"a": "ex", "b": 1, "id": 3}, + {"a": "why", "b": 0, "id": 4}, + {"a": "ex", "b": 1, "id": 5}, + {"a": "ex", "b": 0, "id": 6}, + {"a": "why", "b": 1, "id": 7} + ])", + R"([ + {"a": "ex", "b": 0, "ids": [0, 1, 6]}, + {"a": "why", "b": 0, "ids": [2, 4]}, + {"a": "ex", "b": 1, "ids": [3, 5]}, + {"a": "why", "b": 1, "ids": [7]} + ])"); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 043f6695222..809d83b7e2b 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -23,7 +23,6 @@ #include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" #include "arrow/util/iterator.h" diff --git 
a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index 686be1e6dfb..eb11f326570 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -24,7 +24,6 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/scanner.h" diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 44d211260b6..76015fa365c 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -33,7 +33,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/discovery.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/path_util.h" @@ -59,37 +58,9 @@ const std::shared_ptr kBoringSchema = schema({ field("bool", boolean()), field("str", utf8()), field("dict_str", dictionary(int32(), utf8())), + field("ts_ns", timestamp(TimeUnit::NANO)), }); -inline namespace string_literals { -// clang-format off -inline FieldExpression operator"" _(const char* name, size_t name_length) { - // clang-format on - return FieldExpression({name, name_length}); -} -} // namespace string_literals - -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - template ::type>::value>::type> \ - ComparisonExpression operator OP(const Expression& lhs, T&& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), \ - scalar(std::forward(rhs))); \ - } \ - \ - inline ComparisonExpression operator OP(const Expression& lhs, \ - const Expression& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), rhs.Copy()); \ - } - -COMPARISON_FACTORY(EQUAL, equal, ==) -COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) -COMPARISON_FACTORY(GREATER, greater, >) 
-COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) -COMPARISON_FACTORY(LESS, less, <) -COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) -#undef COMPARISON_FACTORY - using fs::internal::GetAbstractPathExtension; using internal::checked_cast; using internal::checked_pointer_cast; diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 81441a18166..a80b333f0c8 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -68,7 +68,6 @@ class ParquetFileFragment; class ParquetFileWriter; class ParquetFileWriteOptions; -class Expression; class Expression2; class Partitioning; From e179f4275875da32bc6b016b3e778a5af5c88e07 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 10 Dec 2020 14:05:32 -0500 Subject: [PATCH 06/31] rename Expression2 -> Expression --- .../arrow/dataset-parquet-scan-example.cc | 4 +- cpp/src/arrow/dataset/dataset.cc | 14 +- cpp/src/arrow/dataset/dataset.h | 24 +-- cpp/src/arrow/dataset/dataset_internal.h | 2 +- cpp/src/arrow/dataset/discovery.h | 6 +- cpp/src/arrow/dataset/expression.cc | 180 +++++++++--------- cpp/src/arrow/dataset/expression.h | 94 ++++----- cpp/src/arrow/dataset/expression_internal.h | 42 ++-- cpp/src/arrow/dataset/expression_test.cc | 92 ++++----- cpp/src/arrow/dataset/file_base.cc | 10 +- cpp/src/arrow/dataset/file_base.h | 12 +- cpp/src/arrow/dataset/file_parquet.cc | 16 +- cpp/src/arrow/dataset/file_parquet.h | 14 +- cpp/src/arrow/dataset/file_parquet_test.cc | 4 +- cpp/src/arrow/dataset/file_test.cc | 8 +- cpp/src/arrow/dataset/partition.cc | 18 +- cpp/src/arrow/dataset/partition.h | 22 +-- cpp/src/arrow/dataset/partition_test.cc | 20 +- cpp/src/arrow/dataset/scanner.cc | 2 +- cpp/src/arrow/dataset/scanner.h | 4 +- cpp/src/arrow/dataset/scanner_internal.h | 10 +- cpp/src/arrow/dataset/test_util.h | 12 +- cpp/src/arrow/dataset/type_fwd.h | 2 +- 23 files changed, 306 insertions(+), 306 deletions(-) diff --git a/cpp/examples/arrow/dataset-parquet-scan-example.cc 
b/cpp/examples/arrow/dataset-parquet-scan-example.cc index 46778cebaa5..197ca5aa4c6 100644 --- a/cpp/examples/arrow/dataset-parquet-scan-example.cc +++ b/cpp/examples/arrow/dataset-parquet-scan-example.cc @@ -60,7 +60,7 @@ struct Configuration { // Indicates the filter by which rows will be filtered. This optimization can // make use of partition information and/or file metadata if possible. - ds::Expression2 filter = + ds::Expression filter = ds::greater(ds::field_ref("total_amount"), ds::literal(1000.0f)); ds::InspectOptions inspect_options{}; @@ -146,7 +146,7 @@ std::shared_ptr GetDatasetFromPath( std::shared_ptr GetScannerFromDataset(std::shared_ptr dataset, std::vector columns, - ds::Expression2 filter, + ds::Expression filter, bool use_threads) { auto scanner_builder = dataset->NewScan().ValueOrDie(); diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index bedecf1e3a5..e2386dddec7 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -31,7 +31,7 @@ namespace arrow { namespace dataset { -Fragment::Fragment(Expression2 partition_expression, +Fragment::Fragment(Expression partition_expression, std::shared_ptr physical_schema) : partition_expression_(std::move(partition_expression)), physical_schema_(std::move(physical_schema)) {} @@ -58,14 +58,14 @@ Result> InMemoryFragment::ReadPhysicalSchemaImpl() { InMemoryFragment::InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - Expression2 partition_expression) + Expression partition_expression) : Fragment(std::move(partition_expression), std::move(schema)), record_batches_(std::move(record_batches)) { DCHECK_NE(physical_schema_, nullptr); } InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches, - Expression2 partition_expression) + Expression partition_expression) : InMemoryFragment(record_batches.empty() ? 
schema({}) : record_batches[0]->schema(), std::move(record_batches), std::move(partition_expression)) {} @@ -92,7 +92,7 @@ Result InMemoryFragment::Scan(std::shared_ptr opt return MakeMapIterator(fn, std::move(batches_it)); } -Dataset::Dataset(std::shared_ptr schema, Expression2 partition_expression) +Dataset::Dataset(std::shared_ptr schema, Expression partition_expression) : schema_(std::move(schema)), partition_expression_(std::move(partition_expression)) {} @@ -110,7 +110,7 @@ Result Dataset::GetFragments() { return GetFragments(std::move(predicate)); } -Result Dataset::GetFragments(Expression2 predicate) { +Result Dataset::GetFragments(Expression predicate) { ARROW_ASSIGN_OR_RAISE( predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); return predicate.IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) @@ -154,7 +154,7 @@ Result> InMemoryDataset::ReplaceSchema( return std::make_shared(std::move(schema), get_batches_); } -Result InMemoryDataset::GetFragmentsImpl(Expression2) { +Result InMemoryDataset::GetFragmentsImpl(Expression) { auto schema = this->schema(); auto create_fragment = @@ -195,7 +195,7 @@ Result> UnionDataset::ReplaceSchema( new UnionDataset(std::move(schema), std::move(children))); } -Result UnionDataset::GetFragmentsImpl(Expression2 predicate) { +Result UnionDataset::GetFragmentsImpl(Expression predicate) { return GetFragmentsFromDatasets(children_, predicate); } diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 88782831670..4ad6ecbe74a 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -72,19 +72,19 @@ class ARROW_DS_EXPORT Fragment { /// \brief An expression which evaluates to true for all data viewed by this /// Fragment. 
- const Expression2& partition_expression() const { return partition_expression_; } + const Expression& partition_expression() const { return partition_expression_; } virtual ~Fragment() = default; protected: Fragment() = default; - explicit Fragment(Expression2 partition_expression, + explicit Fragment(Expression partition_expression, std::shared_ptr physical_schema); virtual Result> ReadPhysicalSchemaImpl() = 0; util::Mutex physical_schema_mutex_; - Expression2 partition_expression_ = literal(true); + Expression partition_expression_ = literal(true); std::shared_ptr physical_schema_; }; @@ -93,9 +93,9 @@ class ARROW_DS_EXPORT Fragment { class ARROW_DS_EXPORT InMemoryFragment : public Fragment { public: InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - Expression2 = literal(true)); + Expression = literal(true)); explicit InMemoryFragment(RecordBatchVector record_batches, - Expression2 = literal(true)); + Expression = literal(true)); Result Scan(std::shared_ptr options, std::shared_ptr context) override; @@ -122,14 +122,14 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. - Result GetFragments(Expression2 predicate); + Result GetFragments(Expression predicate); Result GetFragments(); const std::shared_ptr& schema() const { return schema_; } /// \brief An expression which evaluates to true for all data viewed by this Dataset. /// May be null, which indicates no information is available. 
- const Expression2& partition_expression() const { return partition_expression_; } + const Expression& partition_expression() const { return partition_expression_; } /// \brief The name identifying the kind of Dataset virtual std::string type_name() const = 0; @@ -146,12 +146,12 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { protected: explicit Dataset(std::shared_ptr schema) : schema_(std::move(schema)) {} - Dataset(std::shared_ptr schema, Expression2 partition_expression); + Dataset(std::shared_ptr schema, Expression partition_expression); - virtual Result GetFragmentsImpl(Expression2 predicate) = 0; + virtual Result GetFragmentsImpl(Expression predicate) = 0; std::shared_ptr schema_; - Expression2 partition_expression_ = literal(true); + Expression partition_expression_ = literal(true); }; /// \brief A Source which yields fragments wrapping a stream of record batches. @@ -181,7 +181,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2 predicate) override; + Expression predicate) override; std::shared_ptr get_batches_; }; @@ -206,7 +206,7 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2 predicate) override; + Expression predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 9e070192cd0..cb6d406fb70 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -36,7 +36,7 @@ namespace dataset { /// \brief GetFragmentsFromDatasets transforms a vector into a /// flattened FragmentIterator. 
inline Result GetFragmentsFromDatasets(const DatasetVector& datasets, - Expression2 predicate) { + Expression predicate) { // Iterator auto datasets_it = MakeVectorIterator(datasets); diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 0b1c94f5adc..b7786cd305f 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -90,8 +90,8 @@ class ARROW_DS_EXPORT DatasetFactory { virtual Result> Finish(FinishOptions options) = 0; /// \brief Optional root partition for the resulting Dataset. - const Expression2& root_partition() const { return root_partition_; } - Status SetRootPartition(Expression2 partition) { + const Expression& root_partition() const { return root_partition_; } + Status SetRootPartition(Expression partition) { root_partition_ = std::move(partition); return Status::OK(); } @@ -101,7 +101,7 @@ class ARROW_DS_EXPORT DatasetFactory { protected: DatasetFactory(); - Expression2 root_partition_; + Expression root_partition_; }; /// \brief DatasetFactory provides a way to inspect/discover a Dataset's diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index f1b17752bf2..6ff53d4190c 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -39,17 +39,17 @@ using internal::checked_pointer_cast; namespace dataset { -const Expression2::Call* Expression2::call() const { +const Expression::Call* Expression::call() const { return util::get_if(impl_.get()); } -const Datum* Expression2::literal() const { return util::get_if(impl_.get()); } +const Datum* Expression::literal() const { return util::get_if(impl_.get()); } -const FieldRef* Expression2::field_ref() const { +const FieldRef* Expression::field_ref() const { return util::get_if(impl_.get()); } -std::string Expression2::ToString() const { +std::string Expression::ToString() const { if (auto lit = literal()) { if (lit->is_scalar()) { return lit->scalar()->ToString(); @@ -118,14 
+118,14 @@ std::string Expression2::ToString() const { return out; } -void PrintTo(const Expression2& expr, std::ostream* os) { +void PrintTo(const Expression& expr, std::ostream* os) { *os << expr.ToString(); if (expr.IsBound()) { *os << "[bound]"; } } -bool Expression2::Equals(const Expression2& other) const { +bool Expression::Equals(const Expression& other) const { if (Identical(*this, other)) return true; if (impl_->index() != other.impl_->index()) { @@ -192,7 +192,7 @@ bool Expression2::Equals(const Expression2& other) const { return false; } -size_t Expression2::hash() const { +size_t Expression::hash() const { if (auto lit = literal()) { if (lit->is_scalar()) { return Scalar::Hash::hash(*lit->scalar()); @@ -213,7 +213,7 @@ size_t Expression2::hash() const { return out; } -bool Expression2::IsBound() const { +bool Expression::IsBound() const { if (descr_.type == nullptr) return false; if (auto lit = literal()) return true; @@ -222,14 +222,14 @@ bool Expression2::IsBound() const { auto call = CallNotNull(*this); - for (const Expression2& arg : call->arguments) { + for (const Expression& arg : call->arguments) { if (!arg.IsBound()) return false; } return call->kernel != nullptr; } -bool Expression2::IsScalarExpression() const { +bool Expression::IsScalarExpression() const { if (auto lit = literal()) { return lit->is_scalar(); } @@ -239,7 +239,7 @@ bool Expression2::IsScalarExpression() const { auto call = CallNotNull(*this); - for (const Expression2& arg : call->arguments) { + for (const Expression& arg : call->arguments) { if (!arg.IsScalarExpression()) return false; } @@ -259,7 +259,7 @@ bool Expression2::IsScalarExpression() const { return call->function_kind == compute::Function::SCALAR; } -bool Expression2::IsNullLiteral() const { +bool Expression::IsNullLiteral() const { if (auto lit = literal()) { if (lit->null_count() == lit->length()) { return true; @@ -269,7 +269,7 @@ bool Expression2::IsNullLiteral() const { return false; } -bool 
Expression2::IsSatisfiable() const { +bool Expression::IsSatisfiable() const { if (descr_.type && descr_.type->id() == Type::NA) { return false; } @@ -303,7 +303,7 @@ inline bool KernelStateIsImmutable(const std::string& function) { } Result> InitKernelState( - const Expression2::Call& call, compute::ExecContext* exec_context) { + const Expression::Call& call, compute::ExecContext* exec_context) { if (!call.kernel->init) return nullptr; compute::KernelContext kernel_context(exec_context); @@ -314,7 +314,7 @@ Result> InitKernelState( return std::move(kernel_state); } -Status MaybeInsertCast(std::shared_ptr to_type, Expression2* expr) { +Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { if (expr->descr().type->Equals(to_type)) { return Status::OK(); } @@ -334,15 +334,15 @@ Status MaybeInsertCast(std::shared_ptr to_type, Expression2* expr) { auto call_with_cast = *CallNotNull(with_cast); call_with_cast.arguments[0] = std::move(*expr); - *expr = Expression2(std::make_shared(std::move(call_with_cast)), + *expr = Expression(std::make_shared(std::move(call_with_cast)), ValueDescr{std::move(to_type), expr->descr().shape}); return Status::OK(); } -Status InsertImplicitCasts(Expression2::Call* call) { +Status InsertImplicitCasts(Expression::Call* call) { DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), - [](const Expression2& argument) { return argument.IsBound(); })); + [](const Expression& argument) { return argument.IsBound(); })); if (Comparison::Get(call->function)) { for (auto&& argument : call->arguments) { @@ -389,7 +389,7 @@ Status InsertImplicitCasts(Expression2::Call* call) { return Status::OK(); } -Result Expression2::Bind(ValueDescr in, +Result Expression::Bind(ValueDescr in, compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; @@ -426,15 +426,15 @@ Result Expression2::Bind(ValueDescr in, ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( 
&kernel_context, descrs)); - return Expression2(std::make_shared(std::move(bound_call)), std::move(descr)); + return Expression(std::make_shared(std::move(bound_call)), std::move(descr)); } -Result Expression2::Bind(const Schema& in_schema, +Result Expression::Bind(const Schema& in_schema, compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } -Result ExecuteScalarExpression(const Expression2& expr, const Datum& input, +Result ExecuteScalarExpression(const Expression& expr, const Datum& input, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; @@ -489,12 +489,12 @@ Result ExecuteScalarExpression(const Expression2& expr, const Datum& inpu return executor->WrapResults(arguments, listener->values()); } -std::array, 2> -ArgumentsAndFlippedArguments(const Expression2::Call& call) { +std::array, 2> +ArgumentsAndFlippedArguments(const Expression::Call& call) { DCHECK_EQ(call.arguments.size(), 2); - return {std::pair{call.arguments[0], + return {std::pair{call.arguments[0], call.arguments[1]}, - std::pair{call.arguments[1], + std::pair{call.arguments[1], call.arguments[0]}}; } @@ -511,14 +511,14 @@ util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { } util::optional GetNullHandling( - const Expression2::Call& call) { + const Expression::Call& call) { if (call.function_kind == compute::Function::SCALAR) { return static_cast(call.kernel)->null_handling; } return util::nullopt; } -bool DefinitelyNotNull(const Expression2& expr) { +bool DefinitelyNotNull(const Expression& expr) { DCHECK(expr.IsBound()); if (expr.literal()) { @@ -541,7 +541,7 @@ bool DefinitelyNotNull(const Expression2& expr) { return false; } -std::vector FieldsInExpression(const Expression2& expr) { +std::vector FieldsInExpression(const Expression& expr) { if (auto lit = expr.literal()) return {}; if (auto ref = expr.field_ref()) { @@ -549,20 +549,20 @@ std::vector 
FieldsInExpression(const Expression2& expr) { } std::vector fields; - for (const Expression2& arg : CallNotNull(expr)->arguments) { + for (const Expression& arg : CallNotNull(expr)->arguments) { auto argument_fields = FieldsInExpression(arg); std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields)); } return fields; } -Result FoldConstants(Expression2 expr) { +Result FoldConstants(Expression expr) { return Modify( - std::move(expr), [](Expression2 expr) { return expr; }, - [](Expression2 expr, ...) -> Result { + std::move(expr), [](Expression expr) { return expr; }, + [](Expression expr, ...) -> Result { auto call = CallNotNull(expr); if (std::all_of(call->arguments.begin(), call->arguments.end(), - [](const Expression2& argument) { return argument.literal(); })) { + [](const Expression& argument) { return argument.literal(); })) { // all arguments are literal; we can evaluate this subexpression *now* static const Datum ignored_input; ARROW_ASSIGN_OR_RAISE(Datum constant, @@ -616,8 +616,8 @@ Result FoldConstants(Expression2 expr) { }); } -inline std::vector GuaranteeConjunctionMembers( - const Expression2& guaranteed_true_predicate) { +inline std::vector GuaranteeConjunctionMembers( + const Expression& guaranteed_true_predicate) { auto guarantee = guaranteed_true_predicate.call(); if (!guarantee || guarantee->function != "and_kleene") { return {guaranteed_true_predicate}; @@ -628,11 +628,11 @@ inline std::vector GuaranteeConjunctionMembers( // Conjunction members which are represented in known_values are erased from // conjunction_members Status ExtractKnownFieldValuesImpl( - std::vector* conjunction_members, + std::vector* conjunction_members, std::unordered_map* known_values) { auto unconsumed_end = std::partition(conjunction_members->begin(), conjunction_members->end(), - [](const Expression2& expr) { + [](const Expression& expr) { // search for an equality conditions between a field and a literal auto call = expr.call(); if (!call) 
return true; @@ -670,24 +670,24 @@ Status ExtractKnownFieldValuesImpl( } Result> ExtractKnownFieldValues( - const Expression2& guaranteed_true_predicate) { + const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); return known_values; } -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2 expr) { + Expression expr) { if (!expr.IsBound()) { return Status::Invalid( - "ReplaceFieldsWithKnownValues called on an unbound Expression2"); + "ReplaceFieldsWithKnownValues called on an unbound Expression"); } return Modify( std::move(expr), - [&known_values](Expression2 expr) -> Result { + [&known_values](Expression expr) -> Result { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { @@ -698,10 +698,10 @@ Result ReplaceFieldsWithKnownValues( } return expr; }, - [](Expression2 expr, ...) { return expr; }); + [](Expression expr, ...) 
{ return expr; }); } -inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { +inline bool IsBinaryAssociativeCommutative(const Expression::Call& call) { static std::unordered_set binary_associative_commutative{ "and", "or", "and_kleene", "or_kleene", "xor", "multiply", "add", "multiply_checked", "add_checked"}; @@ -710,7 +710,7 @@ inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { return it != binary_associative_commutative.end(); } -Result Canonicalize(Expression2 expr, compute::ExecContext* exec_context) { +Result Canonicalize(Expression expr, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; return Canonicalize(std::move(expr), &exec_context); @@ -720,20 +720,20 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co // (for example, when reorganizing an associative chain), add expressions to this set to // avoid unnecessary work struct { - std::unordered_set set_; + std::unordered_set set_; - bool operator()(const Expression2& expr) const { + bool operator()(const Expression& expr) const { return set_.find(expr) != set_.end(); } - void Add(std::vector exprs) { + void Add(std::vector exprs) { std::move(exprs.begin(), exprs.end(), std::inserter(set_, set_.end())); } } AlreadyCanonicalized; return Modify( std::move(expr), - [&AlreadyCanonicalized, exec_context](Expression2 expr) -> Result { + [&AlreadyCanonicalized, exec_context](Expression expr) -> Result { auto call = expr.call(); if (!call) return expr; @@ -741,13 +741,13 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co if (IsBinaryAssociativeCommutative(*call)) { struct { - int Priority(const Expression2& operand) const { + int Priority(const Expression& operand) const { // order literals first, starting with nulls if (operand.IsNullLiteral()) return 0; if (operand.literal()) return 1; return 2; } - bool operator()(const Expression2& l, const Expression2& r) const { + bool 
operator()(const Expression& l, const Expression& r) const { return Priority(l) < Priority(r); } } CanonicalOrdering; @@ -766,10 +766,10 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co const auto& descr = expr.descr(); auto folded = FoldLeft( chain.fringe.begin(), chain.fringe.end(), - [call, &descr, &AlreadyCanonicalized](Expression2 l, Expression2 r) { + [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { auto ret = *call; ret.arguments = {std::move(l), std::move(r)}; - Expression2 expr(std::make_shared(std::move(ret)), + Expression expr(std::make_shared(std::move(ret)), descr); AlreadyCanonicalized.Add({expr}); return expr; @@ -792,22 +792,22 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); - return Expression2( - std::make_shared(std::move(flipped_call)), + return Expression( + std::make_shared(std::move(flipped_call)), expr.descr()); } } return expr; }, - [](Expression2 expr, ...) { return expr; }); + [](Expression expr, ...) { return expr; }); } -Result DirectComparisonSimplification(Expression2 expr, - const Expression2::Call& guarantee) { +Result DirectComparisonSimplification(Expression expr, + const Expression::Call& guarantee) { return Modify( - std::move(expr), [](Expression2 expr) { return expr; }, - [&guarantee](Expression2 expr, ...) -> Result { + std::move(expr), [](Expression expr) { return expr; }, + [&guarantee](Expression expr, ...) 
-> Result { auto call = expr.call(); if (!call) return expr; @@ -861,8 +861,8 @@ Result DirectComparisonSimplification(Expression2 expr, }); } -Result SimplifyWithGuarantee(Expression2 expr, - const Expression2& guaranteed_true_predicate) { +Result SimplifyWithGuarantee(Expression expr, + const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; @@ -896,7 +896,7 @@ Result SimplifyWithGuarantee(Expression2 expr, // Serialization is accomplished by converting expressions to KeyValueMetadata and storing // this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its // columns. Finally, the RecordBatch is written to an IPC file. -Result> Serialize(const Expression2& expr) { +Result> Serialize(const Expression& expr) { struct { std::shared_ptr metadata_ = std::make_shared(); ArrayVector columns_; @@ -908,7 +908,7 @@ Result> Serialize(const Expression2& expr) { return std::to_string(ret); } - Status Visit(const Expression2& expr) { + Status Visit(const Expression& expr) { if (auto lit = expr.literal()) { if (!lit->is_scalar()) { return Status::NotImplemented("Serialization of non-scalar literals"); @@ -943,7 +943,7 @@ Result> Serialize(const Expression2& expr) { return Status::OK(); } - Result> operator()(const Expression2& expr) { + Result> operator()(const Expression& expr) { RETURN_NOT_OK(Visit(expr)); FieldVector fields(columns_.size()); for (size_t i = 0; i < fields.size(); ++i) { @@ -962,16 +962,16 @@ Result> Serialize(const Expression2& expr) { return stream->Finish(); } -Result Deserialize(const Buffer& buffer) { +Result Deserialize(const Buffer& buffer) { io::BufferReader stream(buffer); ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); if (batch->schema()->metadata() == nullptr) { - return Status::Invalid("serialized Expression2's batch repr had 
null metadata"); + return Status::Invalid("serialized Expression's batch repr had null metadata"); } if (batch->num_rows() != 1) { return Status::Invalid( - "serialized Expression2's batch repr was not a single row - had ", + "serialized Expression's batch repr was not a single row - had ", batch->num_rows()); } @@ -992,9 +992,9 @@ Result Deserialize(const Buffer& buffer) { return batch_.column(column_index)->GetScalar(0); } - Result GetOne() { + Result GetOne() { if (index_ >= metadata().size()) { - return Status::Invalid("unterminated serialized Expression2"); + return Status::Invalid("unterminated serialized Expression"); } const std::string& key = metadata().key(index_); @@ -1011,17 +1011,17 @@ Result Deserialize(const Buffer& buffer) { } if (key != "call") { - return Status::Invalid("Unrecognized serialized Expression2 key ", key); + return Status::Invalid("Unrecognized serialized Expression key ", key); } - std::vector arguments; + std::vector arguments; while (metadata().key(index_) != "end") { if (metadata().key(index_) == "options") { ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_))); auto expr = call(value, std::move(arguments)); RETURN_NOT_OK(FunctionOptionsFromStructScalar( checked_cast(options_scalar.get()), - const_cast(expr.call()))); + const_cast(expr.call()))); index_ += 2; return expr; } @@ -1038,40 +1038,40 @@ Result Deserialize(const Buffer& buffer) { return FromRecordBatch{*batch, 0}.GetOne(); } -Expression2 project(std::vector values, std::vector names) { +Expression project(std::vector values, std::vector names) { return call("struct", std::move(values), compute::StructOptions{std::move(names)}); } -Expression2 equal(Expression2 lhs, Expression2 rhs) { +Expression equal(Expression lhs, Expression rhs) { return call("equal", {std::move(lhs), std::move(rhs)}); } -Expression2 not_equal(Expression2 lhs, Expression2 rhs) { +Expression not_equal(Expression lhs, Expression rhs) { return call("not_equal", 
{std::move(lhs), std::move(rhs)}); } -Expression2 less(Expression2 lhs, Expression2 rhs) { +Expression less(Expression lhs, Expression rhs) { return call("less", {std::move(lhs), std::move(rhs)}); } -Expression2 less_equal(Expression2 lhs, Expression2 rhs) { +Expression less_equal(Expression lhs, Expression rhs) { return call("less_equal", {std::move(lhs), std::move(rhs)}); } -Expression2 greater(Expression2 lhs, Expression2 rhs) { +Expression greater(Expression lhs, Expression rhs) { return call("greater", {std::move(lhs), std::move(rhs)}); } -Expression2 greater_equal(Expression2 lhs, Expression2 rhs) { +Expression greater_equal(Expression lhs, Expression rhs) { return call("greater_equal", {std::move(lhs), std::move(rhs)}); } -Expression2 and_(Expression2 lhs, Expression2 rhs) { +Expression and_(Expression lhs, Expression rhs) { return call("and_kleene", {std::move(lhs), std::move(rhs)}); } -Expression2 and_(const std::vector& operands) { - auto folded = FoldLeft(operands.begin(), +Expression and_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), operands.end(), and_); if (folded) { return std::move(*folded); @@ -1079,12 +1079,12 @@ Expression2 and_(const std::vector& operands) { return literal(true); } -Expression2 or_(Expression2 lhs, Expression2 rhs) { +Expression or_(Expression lhs, Expression rhs) { return call("or_kleene", {std::move(lhs), std::move(rhs)}); } -Expression2 or_(const std::vector& operands) { - auto folded = FoldLeft(operands.begin(), +Expression or_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), operands.end(), or_); if (folded) { return std::move(*folded); @@ -1092,13 +1092,13 @@ Expression2 or_(const std::vector& operands) { return literal(false); } -Expression2 not_(Expression2 operand) { return call("invert", {std::move(operand)}); } +Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } -Expression2 operator&&(Expression2 lhs, Expression2 rhs) { 
+Expression operator&&(Expression lhs, Expression rhs) { return and_(std::move(lhs), std::move(rhs)); } -Expression2 operator||(Expression2 lhs, Expression2 rhs) { +Expression operator||(Expression lhs, Expression rhs) { return or_(std::move(lhs), std::move(rhs)); } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index e6495433d23..a9ec5403068 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -45,11 +45,11 @@ namespace dataset { /// - A literal Datum. /// - A reference to a single (potentially nested) field of the input Datum. /// - A call to a compute function, with arguments specified by other Expressions. -class ARROW_DS_EXPORT Expression2 { +class ARROW_DS_EXPORT Expression { public: struct Call { std::string function; - std::vector arguments; + std::vector arguments; std::shared_ptr options; // post-Bind properties: @@ -60,24 +60,24 @@ class ARROW_DS_EXPORT Expression2 { }; std::string ToString() const; - bool Equals(const Expression2& other) const; + bool Equals(const Expression& other) const; size_t hash() const; struct Hash { - size_t operator()(const Expression2& expr) const { return expr.hash(); } + size_t operator()(const Expression& expr) const { return expr.hash(); } }; /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, compute::ExecContext* = NULLPTR) const; // XXX someday // Clone all KernelState in this bound expression. 
If any function referenced by this // expression has mutable KernelState, it is not safe to execute or apply simplification // passes to it (or copies of it!) from multiple threads. Cloning state produces new - // KernelStates where necessary to ensure that Expression2s may be manipulated safely + // KernelStates where necessary to ensure that Expressions may be manipulated safely // on multiple threads. // Result CloneState() const; // Status SetState(ExpressionState); @@ -107,10 +107,10 @@ class ARROW_DS_EXPORT Expression2 { using Impl = util::Variant; - explicit Expression2(std::shared_ptr impl, ValueDescr descr = {}) + explicit Expression(std::shared_ptr impl, ValueDescr descr = {}) : impl_(std::move(impl)), descr_(std::move(descr)) {} - Expression2() = default; + Expression() = default; private: std::shared_ptr impl_; @@ -118,86 +118,86 @@ class ARROW_DS_EXPORT Expression2 { // XXX someday // NullGeneralization::type evaluates_to_null_; - ARROW_EXPORT friend bool Identical(const Expression2& l, const Expression2& r); + ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r); - ARROW_EXPORT friend void PrintTo(const Expression2&, std::ostream*); + ARROW_EXPORT friend void PrintTo(const Expression&, std::ostream*); }; -inline bool operator==(const Expression2& l, const Expression2& r) { return l.Equals(r); } -inline bool operator!=(const Expression2& l, const Expression2& r) { +inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); } +inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); } // Factories -inline Expression2 call(std::string function, std::vector arguments, +inline Expression call(std::string function, std::vector arguments, std::shared_ptr options = NULLPTR) { - Expression2::Call call; + Expression::Call call; call.function = std::move(function); call.arguments = std::move(arguments); call.options = std::move(options); - return 
Expression2(std::make_shared(std::move(call))); + return Expression(std::make_shared(std::move(call))); } template ::value>::type> -Expression2 call(std::string function, std::vector arguments, +Expression call(std::string function, std::vector arguments, Options options) { return call(std::move(function), std::move(arguments), std::make_shared(std::move(options))); } template -Expression2 field_ref(Args&&... args) { - return Expression2( - std::make_shared(FieldRef(std::forward(args)...))); +Expression field_ref(Args&&... args) { + return Expression( + std::make_shared(FieldRef(std::forward(args)...))); } template -Expression2 literal(Arg&& arg) { +Expression literal(Arg&& arg) { Datum lit(std::forward(arg)); ValueDescr descr = lit.descr(); - return Expression2(std::make_shared(std::move(lit)), + return Expression(std::make_shared(std::move(lit)), std::move(descr)); } ARROW_DS_EXPORT -std::vector FieldsInExpression(const Expression2&); +std::vector FieldsInExpression(const Expression&); ARROW_DS_EXPORT Result> ExtractKnownFieldValues( - const Expression2& guaranteed_true_predicate); + const Expression& guaranteed_true_predicate); -/// \defgroup expression-passes Functions for modification of Expression2s +/// \defgroup expression-passes Functions for modification of Expressions /// /// @{ /// /// These operate on a bound expression and its bound state simultaneously, -/// ensuring that Call Expression2s' KernelState can be utilized or reassociated. +/// ensuring that Call Expressions' KernelState can be utilized or reassociated. /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. 
/// TODO this could be a strong canonicalization ARROW_DS_EXPORT -Result Canonicalize(Expression2, compute::ExecContext* = NULLPTR); +Result Canonicalize(Expression, compute::ExecContext* = NULLPTR); /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all /// calls whose arguments are entirely literal. ARROW_DS_EXPORT -Result FoldConstants(Expression2); +Result FoldConstants(Expression); ARROW_DS_EXPORT -Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, Expression2); +Result ReplaceFieldsWithKnownValues( + const std::unordered_map& known_values, Expression); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is /// used to remove redundant function calls from a filter expression or to replace a /// reference to a constant-value field with a literal. ARROW_DS_EXPORT -Result SimplifyWithGuarantee(Expression2, - const Expression2& guaranteed_true_predicate); +Result SimplifyWithGuarantee(Expression, + const Expression& guaranteed_true_predicate); /// @} @@ -205,39 +205,39 @@ Result SimplifyWithGuarantee(Expression2, /// Execute a scalar expression against the provided state and input Datum. This /// expression must be bound. 
-Result ExecuteScalarExpression(const Expression2&, const Datum& input, +Result ExecuteScalarExpression(const Expression&, const Datum& input, compute::ExecContext* = NULLPTR); // Serialization ARROW_DS_EXPORT -Result> Serialize(const Expression2&); +Result> Serialize(const Expression&); ARROW_DS_EXPORT -Result Deserialize(const Buffer&); +Result Deserialize(const Buffer&); // Convenience aliases for factories -ARROW_DS_EXPORT Expression2 project(std::vector values, +ARROW_DS_EXPORT Expression project(std::vector values, std::vector names); -ARROW_DS_EXPORT Expression2 equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 not_equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression not_equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 less(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression less(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 less_equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression less_equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 greater(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression greater(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 greater_equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression greater_equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 and_(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Expression2 and_(const std::vector&); -ARROW_DS_EXPORT Expression2 or_(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Expression2 or_(const std::vector&); -ARROW_DS_EXPORT Expression2 not_(Expression2 operand); +ARROW_DS_EXPORT Expression and_(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression and_(const std::vector&); +ARROW_DS_EXPORT Expression or_(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression or_(const std::vector&); +ARROW_DS_EXPORT Expression not_(Expression operand); 
} // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 0324eaedc77..8c26e037f16 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -33,15 +33,15 @@ using internal::checked_cast; namespace dataset { -bool Identical(const Expression2& l, const Expression2& r) { return l.impl_ == r.impl_; } +bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; } -const Expression2::Call* CallNotNull(const Expression2& expr) { +const Expression::Call* CallNotNull(const Expression& expr) { auto call = expr.call(); DCHECK_NE(call, nullptr); return call; } -inline void GetAllFieldRefs(const Expression2& expr, +inline void GetAllFieldRefs(const Expression& expr, std::unordered_set* refs) { if (auto lit = expr.literal()) return; @@ -50,12 +50,12 @@ inline void GetAllFieldRefs(const Expression2& expr, return; } - for (const Expression2& arg : CallNotNull(expr)->arguments) { + for (const Expression& arg : CallNotNull(expr)->arguments) { GetAllFieldRefs(arg, refs); } } -inline std::vector GetDescriptors(const std::vector& exprs) { +inline std::vector GetDescriptors(const std::vector& exprs) { std::vector descrs(exprs.size()); for (size_t i = 0; i < exprs.size(); ++i) { DCHECK(exprs[i].IsBound()); @@ -133,7 +133,7 @@ struct Comparison { return it != flipped_comparisons.end() ? 
&it->second : nullptr; } - static const type* Get(const Expression2& expr) { + static const type* Get(const Expression& expr) { if (auto call = expr.call()) { return Comparison::Get(call->function); } @@ -207,7 +207,7 @@ struct Comparison { } }; -inline const compute::CastOptions* GetCastOptions(const Expression2::Call& call) { +inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) { if (call.function != "cast") return nullptr; return checked_cast(call.options.get()); } @@ -217,17 +217,17 @@ inline bool IsSetLookup(const std::string& function) { } inline const compute::SetLookupOptions* GetSetLookupOptions( - const Expression2::Call& call) { + const Expression::Call& call) { if (!IsSetLookup(call.function)) return nullptr; return checked_cast(call.options.get()); } -inline const compute::StructOptions* GetStructOptions(const Expression2::Call& call) { +inline const compute::StructOptions* GetStructOptions(const Expression::Call& call) { if (call.function != "struct") return nullptr; return checked_cast(call.options.get()); } -inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression2::Call& call) { +inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call& call) { if (call.function != "strptime") return nullptr; return checked_cast(call.options.get()); } @@ -256,7 +256,7 @@ inline Status EnsureNotDictionary(Datum* datum) { return Status::OK(); } -inline Status EnsureNotDictionary(Expression2::Call* call) { +inline Status EnsureNotDictionary(Expression::Call* call) { if (auto options = GetSetLookupOptions(*call)) { auto new_options = *options; RETURN_NOT_OK(EnsureNotDictionary(&new_options.value_set)); @@ -266,7 +266,7 @@ inline Status EnsureNotDictionary(Expression2::Call* call) { } inline Result> FunctionOptionsToStructScalar( - const Expression2::Call& call) { + const Expression::Call& call) { if (call.options == nullptr) { return nullptr; } @@ -318,7 +318,7 @@ inline Result> 
FunctionOptionsToStructScalar( } inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, - Expression2::Call* call) { + Expression::Call* call) { if (repr == nullptr) { call->options = nullptr; return Status::OK(); @@ -359,9 +359,9 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, struct FlattenedAssociativeChain { bool was_left_folded = true; - std::vector exprs, fringe; + std::vector exprs, fringe; - explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} { + explicit FlattenedAssociativeChain(Expression expr) : exprs{std::move(expr)} { auto call = CallNotNull(exprs.back()); fringe = call->arguments; @@ -384,14 +384,14 @@ struct FlattenedAssociativeChain { // NB: no increment so we hit sub_call's first argument next iteration } - DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) { + DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression& expr) { return CallNotNull(expr)->options == nullptr; })); } }; inline Result> GetFunction( - const Expression2::Call& call, compute::ExecContext* exec_context) { + const Expression::Call& call, compute::ExecContext* exec_context) { if (call.function != "cast") { return exec_context->func_registry()->GetFunction(call.function); } @@ -401,9 +401,9 @@ inline Result> GetFunction( } template -Result Modify(Expression2 expr, const PreVisit& pre, +Result Modify(Expression expr, const PreVisit& pre, const PostVisitCall& post_call) { - ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); + ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); auto call = expr.call(); if (!call) return expr; @@ -423,8 +423,8 @@ Result Modify(Expression2 expr, const PreVisit& pre, if (at_least_one_modified) { // reconstruct the call expression with the modified arguments - auto modified_expr = Expression2( - std::make_shared(std::move(modified_call)), expr.descr()); + auto modified_expr = Expression( + std::make_shared(std::move(modified_call)), 
expr.descr()); return post_call(std::move(modified_expr), &expr); } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index f2b6951a8a3..4812d06b8c8 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -42,12 +42,12 @@ namespace dataset { #define EXPECT_OK ARROW_EXPECT_OK -Expression2 cast(Expression2 argument, std::shared_ptr to_type) { +Expression cast(Expression argument, std::shared_ptr to_type) { return call("cast", {std::move(argument)}, compute::CastOptions::Safe(std::move(to_type))); } -TEST(Expression2, ToString) { +TEST(Expression, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "FieldRef(alpha)"); EXPECT_EQ(literal(3).ToString(), "3"); @@ -81,7 +81,7 @@ TEST(Expression2, ToString) { ", {a,renamed_a,three,b})"); } -TEST(Expression2, Equality) { +TEST(Expression, Equality) { EXPECT_EQ(literal(1), literal(1)); EXPECT_NE(literal(1), literal(2)); @@ -110,8 +110,8 @@ TEST(Expression2, Equality) { call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32()))); } -TEST(Expression2, Hash) { - std::unordered_set set; +TEST(Expression, Hash) { + std::unordered_set set; EXPECT_TRUE(set.emplace(field_ref("alpha")).second); EXPECT_TRUE(set.emplace(field_ref("beta")).second); @@ -131,7 +131,7 @@ TEST(Expression2, Hash) { EXPECT_EQ(set.size(), 6); } -TEST(Expression2, IsScalarExpression) { +TEST(Expression, IsScalarExpression) { EXPECT_TRUE(literal(true).IsScalarExpression()); auto arr = ArrayFromJSON(int8(), "[]"); @@ -150,7 +150,7 @@ TEST(Expression2, IsScalarExpression) { EXPECT_FALSE(call("take", {field_ref("a"), literal(arr)}).IsScalarExpression()); } -TEST(Expression2, IsSatisfiable) { +TEST(Expression, IsSatisfiable) { EXPECT_TRUE(literal(true).IsSatisfiable()); EXPECT_FALSE(literal(false).IsSatisfiable()); @@ -164,7 +164,7 @@ TEST(Expression2, IsSatisfiable) { // NB: no constant folding here EXPECT_TRUE(equal(literal(0), literal(1)).IsSatisfiable()); - 
// When a top level conjunction contains an Expression2 which is certain to evaluate to + // When a top level conjunction contains an Expression which is certain to evaluate to // null, it can only evaluate to null or false. auto null_or_false = and_(literal(null), field_ref("a")); // This may appear in satisfiable filters if coalesced @@ -176,8 +176,8 @@ TEST(Expression2, IsSatisfiable) { // EXPECT_FALSE(null_or_false).IsSatisfiable()); } -TEST(Expression2, FieldsInExpression) { - auto ExpectFieldsAre = [](Expression2 expr, std::vector expected) { +TEST(Expression, FieldsInExpression) { + auto ExpectFieldsAre = [](Expression expr, std::vector expected) { EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected)); }; @@ -203,7 +203,7 @@ TEST(Expression2, FieldsInExpression) { {"a", "b", "c"}); } -TEST(Expression2, BindLiteral) { +TEST(Expression, BindLiteral) { for (Datum dat : { Datum(3), Datum(3.5), @@ -216,8 +216,8 @@ TEST(Expression2, BindLiteral) { } } -void ExpectBindsTo(Expression2 expr, Expression2 expected, - Expression2* bound_out = nullptr) { +void ExpectBindsTo(Expression expr, Expression expected, + Expression* bound_out = nullptr) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); EXPECT_TRUE(bound.IsBound()); @@ -229,7 +229,7 @@ void ExpectBindsTo(Expression2 expr, Expression2 expected, } } -TEST(Expression2, BindFieldRef) { +TEST(Expression, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set auto expr = field_ref("alpha"); EXPECT_EQ(expr.descr(), ValueDescr{}); @@ -254,7 +254,7 @@ TEST(Expression2, BindFieldRef) { EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); } -TEST(Expression2, BindCall) { +TEST(Expression, BindCall) { auto expr = call("add", {field_ref("a"), field_ref("b")}); EXPECT_FALSE(expr.IsBound()); @@ -268,7 +268,7 @@ TEST(Expression2, BindCall) { expr.Bind(Schema({field("a", int32()), field("b", int32())}))); } -TEST(Expression2, BindWithImplicitCasts) { +TEST(Expression, 
BindWithImplicitCasts) { for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { // cast arguments to same type ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")), @@ -299,7 +299,7 @@ TEST(Expression2, BindWithImplicitCasts) { call("is_in", {field_ref("str")}, Opts(utf8()))); } -TEST(Expression2, BindNestedCall) { +TEST(Expression, BindNestedCall) { auto expr = call("add", {field_ref("a"), call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}), @@ -313,7 +313,7 @@ TEST(Expression2, BindNestedCall) { EXPECT_TRUE(expr.IsBound()); } -TEST(Expression2, ExecuteFieldRef) { +TEST(Expression, ExecuteFieldRef) { auto AssertRefIs = [](FieldRef ref, Datum in, Datum expected) { auto expr = field_ref(ref); @@ -348,7 +348,7 @@ TEST(Expression2, ExecuteFieldRef) { MakeNullScalar(null())); } -Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& input) { +Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) { auto call = expr.call(); if (call == nullptr) { // already tested execution of field_ref, execution of literal is trivial @@ -371,7 +371,7 @@ Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& return function->Execute(arguments, call->options.get(), &exec_context); } -void AssertExecute(Expression2 expr, Datum in, Datum* actual_out = NULLPTR) { +void AssertExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) { if (in.is_value()) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); } else { @@ -389,7 +389,7 @@ void AssertExecute(Expression2 expr, Datum in, Datum* actual_out = NULLPTR) { } } -TEST(Expression2, ExecuteCall) { +TEST(Expression, ExecuteCall) { AssertExecute(call("add", {field_ref("a"), literal(3.5)}), ArrayFromJSON(struct_({field("a", float64())}), R"([ {"a": 6.125}, @@ -421,7 +421,7 @@ TEST(Expression2, ExecuteCall) { ])")); } -TEST(Expression2, ExecuteDictionaryTransparent) { +TEST(Expression, ExecuteDictionaryTransparent) { AssertExecute( 
equal(field_ref("a"), field_ref("b")), ArrayFromJSON( @@ -443,7 +443,7 @@ TEST(Expression2, ExecuteDictionaryTransparent) { } struct { - void operator()(Expression2 expr, Expression2 expected) { + void operator()(Expression expr, Expression expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); @@ -459,7 +459,7 @@ struct { } ExpectFoldsTo; -TEST(Expression2, FoldConstants) { +TEST(Expression, FoldConstants) { // literals are unchanged ExpectFoldsTo(literal(3), literal(3)); @@ -514,7 +514,7 @@ TEST(Expression2, FoldConstants) { call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123)); } -TEST(Expression2, FoldConstantsBoolean) { +TEST(Expression, FoldConstantsBoolean) { // test and_kleene/or_kleene-specific optimizations auto one = literal(1); auto two = literal(2); @@ -532,9 +532,9 @@ TEST(Expression2, FoldConstantsBoolean) { ExpectFoldsTo(or_(whatever, whatever), whatever); } -TEST(Expression2, ExtractKnownFieldValues) { +TEST(Expression, ExtractKnownFieldValues) { struct { - void operator()(Expression2 guarantee, + void operator()(Expression guarantee, std::unordered_map expected) { ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) @@ -582,11 +582,11 @@ TEST(Expression2, ExtractKnownFieldValues) { {{"i32", Datum(3)}, {"i32_req", Datum(1)}}); } -TEST(Expression2, ReplaceFieldsWithKnownValues) { +TEST(Expression, ReplaceFieldsWithKnownValues) { auto ExpectReplacesTo = - [](Expression2 expr, + [](Expression expr, std::unordered_map known_values, - Expression2 unbound_expected) { + Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto replaced, @@ -635,7 +635,7 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { } struct { - void operator()(Expression2 expr, Expression2 
unbound_expected) const { + void operator()(Expression expr, Expression unbound_expected) const { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); @@ -649,7 +649,7 @@ struct { } } ExpectCanonicalizesTo; -TEST(Expression2, CanonicalizeTrivial) { +TEST(Expression, CanonicalizeTrivial) { ExpectCanonicalizesTo(literal(1), literal(1)); ExpectCanonicalizesTo(field_ref("b"), field_ref("b")); @@ -658,7 +658,7 @@ TEST(Expression2, CanonicalizeTrivial) { equal(field_ref("i32"), field_ref("i32_req"))); } -TEST(Expression2, CanonicalizeAnd) { +TEST(Expression, CanonicalizeAnd) { // some aliases for brevity: auto true_ = literal(true); auto null_ = literal(std::make_shared()); @@ -686,7 +686,7 @@ TEST(Expression2, CanonicalizeAnd) { call("is_valid", {and_(true_, b)})); } -TEST(Expression2, CanonicalizeComparison) { +TEST(Expression, CanonicalizeComparison) { ExpectCanonicalizesTo(equal(literal(1), field_ref("i32")), equal(field_ref("i32"), literal(1))); @@ -701,12 +701,12 @@ TEST(Expression2, CanonicalizeComparison) { } struct Simplify { - Expression2 expr; + Expression expr; struct Expectable { - Expression2 expr, guarantee; + Expression expr, guarantee; - void Expect(Expression2 unbound_expected) { + void Expect(Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); @@ -724,10 +724,10 @@ struct Simplify { void Expect(bool constant) { Expect(literal(constant)); } }; - Expectable WithGuarantee(Expression2 guarantee) { return {expr, guarantee}; } + Expectable WithGuarantee(Expression guarantee) { return {expr, guarantee}; } }; -TEST(Expression2, SingleComparisonGuarantees) { +TEST(Expression, SingleComparisonGuarantees) { auto i32 = field_ref("i32"); // i32 is guaranteed equal to 3, so the projection can just materialize 
that constant @@ -840,7 +840,7 @@ TEST(Expression2, SingleComparisonGuarantees) { } } -TEST(Expression2, SimplifyWithGuarantee) { +TEST(Expression, SimplifyWithGuarantee) { // drop both members of a conjunctive filter Simplify{ and_(equal(field_ref("i32"), literal(2)), equal(field_ref("f32"), literal(3.5F)))} @@ -880,7 +880,7 @@ TEST(Expression2, SimplifyWithGuarantee) { compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); } -TEST(Expression2, SimplifyThenExecute) { +TEST(Expression, SimplifyThenExecute) { auto filter = or_({equal(field_ref("f32"), literal("0")), call("is_in", {field_ref("i64")}, @@ -908,8 +908,8 @@ TEST(Expression2, SimplifyThenExecute) { AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); } -TEST(Expression2, Filter) { - auto ExpectFilter = [](Expression2 filter, std::string batch_json) { +TEST(Expression, Filter) { + auto ExpectFilter = [](Expression filter, std::string batch_json) { ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean()))); auto batch = RecordBatchFromJSON(s, batch_json); auto expected_mask = batch->column(0); @@ -931,10 +931,10 @@ TEST(Expression2, Filter) { ])"); } -TEST(Expression2, SerializationRoundTrips) { - auto ExpectRoundTrips = [](const Expression2& expr) { +TEST(Expression, SerializationRoundTrips) { + auto ExpectRoundTrips = [](const Expression& expr) { ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(expr)); - ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, Deserialize(*serialized)); + ASSERT_OK_AND_ASSIGN(Expression roundtripped, Deserialize(*serialized)); EXPECT_EQ(expr, roundtripped); }; diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 86a262a662b..2c437ce8eec 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -60,12 +60,12 @@ Result> FileFormat::MakeFragment( } Result> FileFormat::MakeFragment( - FileSource source, Expression2 partition_expression) { + FileSource source, 
Expression partition_expression) { return MakeFragment(std::move(source), std::move(partition_expression), nullptr); } Result> FileFormat::MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr( new FileFragment(std::move(source), shared_from_this(), @@ -82,7 +82,7 @@ Result FileFragment::Scan(std::shared_ptr options } FileSystemDataset::FileSystemDataset(std::shared_ptr schema, - Expression2 root_partition, + Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) @@ -92,7 +92,7 @@ FileSystemDataset::FileSystemDataset(std::shared_ptr schema, fragments_(std::move(fragments)) {} Result> FileSystemDataset::Make( - std::shared_ptr schema, Expression2 root_partition, + std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) { return std::shared_ptr(new FileSystemDataset( @@ -136,7 +136,7 @@ std::string FileSystemDataset::ToString() const { return repr; } -Result FileSystemDataset::GetFragmentsImpl(Expression2 predicate) { +Result FileSystemDataset::GetFragmentsImpl(Expression predicate) { FragmentVector fragments; for (const auto& fragment : fragments_) { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index f6def56b7de..7b67e6f9bf5 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -143,11 +143,11 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this> MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema); Result> MakeFragment(FileSource source, - Expression2 partition_expression); + Expression partition_expression); Result> MakeFragment( FileSource source, std::shared_ptr physical_schema = NULLPTR); @@ -173,7 +173,7 @@ class 
ARROW_DS_EXPORT FileFragment : public Fragment { protected: FileFragment(FileSource source, std::shared_ptr format, - Expression2 partition_expression, std::shared_ptr physical_schema) + Expression partition_expression, std::shared_ptr physical_schema) : Fragment(std::move(partition_expression), std::move(physical_schema)), source_(std::move(source)), format_(std::move(format)) {} @@ -206,7 +206,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { /// /// \return A constructed dataset. static Result> Make( - std::shared_ptr schema, Expression2 root_partition, + std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); @@ -234,9 +234,9 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2 predicate) override; + Expression predicate) override; - FileSystemDataset(std::shared_ptr schema, Expression2 root_partition, + FileSystemDataset(std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 8ba5f67f6fc..94d2115358f 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -124,7 +124,7 @@ static Result> GetSchemaManifest( return manifest; } -static util::optional ColumnChunkStatisticsAsExpression( +static util::optional ColumnChunkStatisticsAsExpression( const SchemaField& schema_field, const parquet::RowGroupMetaData& metadata) { // For the remaining of this function, failure to extract/parse statistics // are ignored by returning nullptr. The goal is two fold. 
First @@ -331,7 +331,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptr> ParquetFileFormat::MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema, std::vector row_groups) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -339,7 +339,7 @@ Result> ParquetFileFormat::MakeFragment( } Result> ParquetFileFormat::MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -394,7 +394,7 @@ Status ParquetFileWriter::Finish() { return parquet_writer_->Close(); } ParquetFileFragment::ParquetFileFragment(FileSource source, std::shared_ptr format, - Expression2 partition_expression, + Expression partition_expression, std::shared_ptr physical_schema, util::optional> row_groups) : FileFragment(std::move(source), std::move(format), std::move(partition_expression), @@ -456,7 +456,7 @@ Status ParquetFileFragment::SetMetadata( return Status::OK(); } -Result ParquetFileFragment::SplitByRowGroup(Expression2 predicate) { +Result ParquetFileFragment::SplitByRowGroup(Expression predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); @@ -474,7 +474,7 @@ Result ParquetFileFragment::SplitByRowGroup(Expression2 predicat return fragments; } -Result> ParquetFileFragment::Subset(Expression2 predicate) { +Result> ParquetFileFragment::Subset(Expression predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); return Subset(std::move(row_groups)); @@ -491,7 +491,7 @@ Result> ParquetFileFragment::Subset( return new_fragment; } -inline void FoldingAnd(Expression2* l, Expression2 
r) { +inline void FoldingAnd(Expression* l, Expression r) { if (*l == literal(true)) { *l = std::move(r); } else { @@ -499,7 +499,7 @@ inline void FoldingAnd(Expression2* l, Expression2 r) { } } -Result> ParquetFileFragment::FilterRowGroups(Expression2 predicate) { +Result> ParquetFileFragment::FilterRowGroups(Expression predicate) { auto lock = physical_schema_mutex_.Lock(); DCHECK_NE(metadata_, nullptr); diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index 9eab3fe9449..ae0337994a0 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -117,12 +117,12 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// \brief Create a Fragment targeting all RowGroups. Result> MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) override; /// \brief Create a Fragment, restricted to the specified row groups. Result> MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema, std::vector row_groups); /// \brief Return a FileReader on the given source. @@ -150,7 +150,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// significant performance boost when scanning high latency file systems. class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { public: - Result SplitByRowGroup(Expression2 predicate); + Result SplitByRowGroup(Expression predicate); /// \brief Return the RowGroups selected by this fragment. const std::vector& row_groups() const { @@ -166,12 +166,12 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR); /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups. 
- Result> Subset(Expression2 predicate); + Result> Subset(Expression predicate); Result> Subset(std::vector row_group_ids); private: ParquetFileFragment(FileSource source, std::shared_ptr format, - Expression2 partition_expression, + Expression partition_expression, std::shared_ptr physical_schema, util::optional> row_groups); @@ -185,7 +185,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { } // Return a filtered subset of row group indices. - Result> FilterRowGroups(Expression2 predicate); + Result> FilterRowGroups(Expression predicate); ParquetFileFormat& parquet_format_; @@ -193,7 +193,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { // or util::nullopt if all row groups are selected. util::optional> row_groups_; - std::vector statistics_expressions_; + std::vector statistics_expressions_; std::vector statistics_expressions_complete_; std::shared_ptr metadata_; std::shared_ptr manifest_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 47a563b911f..e9182ddf696 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -159,7 +159,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { return Batches(std::move(scan_task_it)); } - void SetFilter(Expression2 filter) { + void SetFilter(Expression filter) { ASSERT_OK_AND_ASSIGN(opts_->filter2, filter.Bind(*schema_)); } @@ -191,7 +191,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { void CountRowGroupsInFragment(const std::shared_ptr& fragment, std::vector expected_row_groups, - Expression2 filter) { + Expression filter) { schema_ = opts_->schema(); ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); auto parquet_fragment = checked_pointer_cast(fragment); diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index fefd6911ac0..f0799e07a3a 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ 
b/cpp/src/arrow/dataset/file_test.cc @@ -122,7 +122,7 @@ TEST_F(TestFileSystemDataset, RootPartitionPruning) { auto root_partition = equal(field_ref("i32"), literal(5)); MakeDataset({fs::File("a"), fs::File("b")}, root_partition); - auto GetFragments = [&](Expression2 filter) { + auto GetFragments = [&](Expression filter) { return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); }; @@ -156,7 +156,7 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - std::vector partitions = { + std::vector partitions = { equal(field_ref("state"), literal("NY")), and_(equal(field_ref("state"), literal("NY")), @@ -186,7 +186,7 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { // Default filter should always return all data. AssertFragmentsAreFromPath(*dataset_->GetFragments(), all_cities); - auto GetFragments = [&](Expression2 filter) { + auto GetFragments = [&](Expression filter) { return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); }; @@ -212,7 +212,7 @@ TEST_F(TestFileSystemDataset, FragmentPartitions) { fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - std::vector partitions = { + std::vector partitions = { equal(field_ref("state"), literal("NY")), and_(equal(field_ref("state"), literal("NY")), diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 6fd61096392..ba4f0c46b8f 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -50,11 +50,11 @@ std::shared_ptr Partitioning::Default() { std::string type_name() const override { return "default"; } - Result Parse(const std::string& path) const override { + Result Parse(const std::string& path) const override { return literal(true); } - Result Format(const Expression2& expr) const override { + Result Format(const Expression& expr) const override { return Status::NotImplemented("formatting paths from ", type_name(), " 
Partitioning"); } @@ -68,7 +68,7 @@ std::shared_ptr Partitioning::Default() { return std::make_shared(); } -Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr, +Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr, RecordBatchProjector* projector) { ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); for (const auto& ref_value : known_values) { @@ -85,9 +85,9 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr, return Status::OK(); } -inline Expression2 ConjunctionFromGroupingRow(Scalar* row) { +inline Expression ConjunctionFromGroupingRow(Scalar* row) { ScalarVector* values = &checked_cast(row)->value; - std::vector equality_expressions(values->size()); + std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { const std::string& name = row->type->field(static_cast(i))->name(); equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); @@ -135,7 +135,7 @@ Result KeyValuePartitioning::Partition( return out; } -Result KeyValuePartitioning::ConvertKey(const Key& key) const { +Result KeyValuePartitioning::ConvertKey(const Key& key) const { ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_)); if (!match) { return literal(true); @@ -178,8 +178,8 @@ Result KeyValuePartitioning::ConvertKey(const Key& key) const { return equal(field_ref(field->name()), literal(std::move(converted))); } -Result KeyValuePartitioning::Parse(const std::string& path) const { - std::vector expressions; +Result KeyValuePartitioning::Parse(const std::string& path) const { + std::vector expressions; for (const Key& key : ParseKeys(path)) { ARROW_ASSIGN_OR_RAISE(auto expr, ConvertKey(key)); @@ -190,7 +190,7 @@ Result KeyValuePartitioning::Parse(const std::string& path) const { return and_(std::move(expressions)); } -Result KeyValuePartitioning::Format(const Expression2& expr) const { +Result 
KeyValuePartitioning::Format(const Expression& expr) const { std::vector values{static_cast(schema_->num_fields()), nullptr}; ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index a9f8c87bfc9..8975f565b19 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -63,15 +63,15 @@ class ARROW_DS_EXPORT Partitioning { /// produce sub-batches which satisfy mutually exclusive Expressions. struct PartitionedBatches { RecordBatchVector batches; - std::vector expressions; + std::vector expressions; }; virtual Result Partition( const std::shared_ptr& batch) const = 0; /// \brief Parse a path into a partition expression - virtual Result Parse(const std::string& path) const = 0; + virtual Result Parse(const std::string& path) const = 0; - virtual Result Format(const Expression2& expr) const = 0; + virtual Result Format(const Expression& expr) const = 0; /// \brief A default Partitioning which always yields scalar(true) static std::shared_ptr Default(); @@ -122,15 +122,15 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { std::string name, value; }; - static Status SetDefaultValuesFromKeys(const Expression2& expr, + static Status SetDefaultValuesFromKeys(const Expression& expr, RecordBatchProjector* projector); Result Partition( const std::shared_ptr& batch) const override; - Result Parse(const std::string& path) const override; + Result Parse(const std::string& path) const override; - Result Format(const Expression2& expr) const override; + Result Format(const Expression& expr) const override; protected: KeyValuePartitioning(std::shared_ptr schema, ArrayVector dictionaries) @@ -145,7 +145,7 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { virtual Result FormatValues(const std::vector& values) const = 0; /// Convert a Key to a full expression. 
- Result ConvertKey(const Key& key) const; + Result ConvertKey(const Key& key) const; ArrayVector dictionaries_; }; @@ -207,9 +207,9 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { /// \brief Implementation provided by lambda or other callable class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { public: - using ParseImpl = std::function(const std::string&)>; + using ParseImpl = std::function(const std::string&)>; - using FormatImpl = std::function(const Expression2&)>; + using FormatImpl = std::function(const Expression&)>; FunctionPartitioning(std::shared_ptr schema, ParseImpl parse_impl, FormatImpl format_impl = NULLPTR, std::string name = "function") @@ -220,11 +220,11 @@ class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { std::string type_name() const override { return name_; } - Result Parse(const std::string& path) const override { + Result Parse(const std::string& path) const override { return parse_impl_(path); } - Result Format(const Expression2& expr) const override { + Result Format(const Expression& expr) const override { if (format_impl_) { return format_impl_(expr); } diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 0cdb5eb50b6..840e2f0303f 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -43,17 +43,17 @@ class TestPartitioning : public ::testing::Test { ASSERT_RAISES(Invalid, partitioning_->Parse(path)); } - void AssertParse(const std::string& path, Expression2 expected) { + void AssertParse(const std::string& path, Expression expected) { ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path)); ASSERT_EQ(parsed, expected); } template - void AssertFormatError(Expression2 expr) { + void AssertFormatError(Expression expr) { ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } - void AssertFormat(Expression2 expr, const std::string& expected) { + void AssertFormat(Expression expr, 
const std::string& expected) { // formatted partition expressions are bound to the schema of the dataset being // written ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr)); @@ -61,7 +61,7 @@ class TestPartitioning : public ::testing::Test { // ensure the formatted path round trips the relevant components of the partition // expression: roundtripped should be a subset of expr - ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, partitioning_->Parse(formatted)); + ASSERT_OK_AND_ASSIGN(Expression roundtripped, partitioning_->Parse(formatted)); ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*written_schema_)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(roundtripped, expr)); @@ -370,7 +370,7 @@ TEST_F(TestPartitioning, EtlThenHive) { field("hour", int8()), field("alpha", int32()), field("beta", float32())}); partitioning_ = std::make_shared( - schm, [&](const std::string& path) -> Result { + schm, [&](const std::string& path) -> Result { auto segments = fs::internal::SplitAbstractPath(path); if (segments.size() < etl_fields.size() + alphabeta_fields.size()) { return Status::Invalid("path ", path, " can't be parsed"); @@ -412,8 +412,8 @@ TEST_F(TestPartitioning, Set) { // An adhoc partitioning which parses segments like "/x in [1 4 5]" // into (field_ref("x") == 1 or field_ref("x") == 4 or field_ref("x") == 5) partitioning_ = std::make_shared( - schm, [&](const std::string& path) -> Result { - std::vector subexpressions; + schm, [&](const std::string& path) -> Result { + std::vector subexpressions; for (auto segment : fs::internal::SplitAbstractPath(path)) { std::smatch matches; @@ -451,8 +451,8 @@ class RangePartitioning : public Partitioning { std::string type_name() const override { return "range"; } - Result Parse(const std::string& path) const override { - std::vector ranges; + Result Parse(const std::string& path) const override { + std::vector ranges; for (auto segment : fs::internal::SplitAbstractPath(path)) { auto key = 
HivePartitioning::ParseKey(segment); @@ -496,7 +496,7 @@ class RangePartitioning : public Partitioning { return Status::OK(); } - Result Format(const Expression2&) const override { return ""; } + Result Format(const Expression&) const override { return ""; } Result Partition( const std::shared_ptr&) const override { return Status::OK(); diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 809d83b7e2b..0c501c9f5b3 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -121,7 +121,7 @@ Status ScannerBuilder::Project(std::vector columns) { return Status::OK(); } -Status ScannerBuilder::Filter(const Expression2& filter) { +Status ScannerBuilder::Filter(const Expression& filter) { for (const auto& ref : FieldsInExpression(filter)) { RETURN_NOT_OK(ref.FindOne(*schema())); } diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index e2f9892af38..5902f759ec3 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -63,7 +63,7 @@ class ARROW_DS_EXPORT ScanOptions { std::shared_ptr ReplaceSchema(std::shared_ptr schema) const; // Filter - Expression2 filter2 = literal(true); + Expression filter2 = literal(true); // Schema to which record batches will be reconciled const std::shared_ptr& schema() const { return projector.schema(); } @@ -220,7 +220,7 @@ class ARROW_DS_EXPORT ScannerBuilder { /// /// \return Failure if any referenced columns does not exist in the dataset's /// Schema. 
- Status Filter(const Expression2& filter); + Status Filter(const Expression& filter); /// \brief Indicate if the Scanner should make use of the available /// ThreadPool found in ScanContext; diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index eb11f326570..cd8fffd2a71 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -30,7 +30,7 @@ namespace arrow { namespace dataset { -inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression2 filter, +inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression filter, MemoryPool* pool) { return MakeMaybeMapIterator( [=](std::shared_ptr in) -> Result> { @@ -70,7 +70,7 @@ inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, class FilterAndProjectScanTask : public ScanTask { public: - explicit FilterAndProjectScanTask(std::shared_ptr task, Expression2 partition) + explicit FilterAndProjectScanTask(std::shared_ptr task, Expression partition) : ScanTask(task->options(), task->context()), task_(std::move(task)), partition_(std::move(partition)), @@ -80,7 +80,7 @@ class FilterAndProjectScanTask : public ScanTask { Result Execute() override { ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute()); - ARROW_ASSIGN_OR_RAISE(Expression2 simplified_filter, + ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, SimplifyWithGuarantee(filter_, partition_)); RecordBatchIterator filter_it = @@ -94,8 +94,8 @@ class FilterAndProjectScanTask : public ScanTask { private: std::shared_ptr task_; - Expression2 partition_; - Expression2 filter_; + Expression partition_; + Expression filter_; RecordBatchProjector projector_; }; diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 76015fa365c..c5904b45e06 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -199,7 +199,7 @@ class DatasetFixtureMixin : public ::testing::Test { 
SetFilter(literal(true)); } - void SetFilter(Expression2 filter) { + void SetFilter(Expression filter) { ASSERT_OK_AND_ASSIGN(options_->filter2, filter.Bind(*schema_)); } @@ -336,8 +336,8 @@ struct MakeFileSystemDatasetMixin { } void MakeDataset(const std::vector& infos, - Expression2 root_partition = literal(true), - std::vector partitions = {}, + Expression root_partition = literal(true), + std::vector partitions = {}, std::shared_ptr s = kBoringSchema) { auto n_fragments = infos.size(); if (partitions.empty()) { @@ -397,8 +397,8 @@ void AssertFragmentsAreFromPath(FragmentIterator it, std::vector ex testing::UnorderedElementsAreArray(expected)); } -static std::vector PartitionExpressionsOf(const FragmentVector& fragments) { - std::vector partition_expressions; +static std::vector PartitionExpressionsOf(const FragmentVector& fragments) { + std::vector partition_expressions; std::transform(fragments.begin(), fragments.end(), std::back_inserter(partition_expressions), [](const std::shared_ptr& fragment) { @@ -408,7 +408,7 @@ static std::vector PartitionExpressionsOf(const FragmentVector& fra } void AssertFragmentsHavePartitionExpressions(std::shared_ptr dataset, - std::vector expected) { + std::vector expected) { ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); for (auto& expr : expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*dataset->schema())); diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index a80b333f0c8..66fed352d0f 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -68,7 +68,7 @@ class ParquetFileFragment; class ParquetFileWriter; class ParquetFileWriteOptions; -class Expression2; +class Expression; class Partitioning; class PartitioningFactory; From c3e6be30e645a5b2db57ff16edbf12d63fd67752 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 10 Dec 2020 17:29:35 -0500 Subject: [PATCH 07/31] first pass at repairing bindings --- cpp/src/arrow/dataset/expression.h | 
25 +++++++++--------- cpp/src/arrow/dataset/expression_test.cc | 3 +++ cpp/src/arrow/datum.h | 4 +-- cpp/src/arrow/util/variant.h | 32 ++++++++++++++---------- cpp/src/arrow/util/variant_test.cc | 13 ++++++++++ r/src/compute.cpp | 2 +- r/src/dataset.cpp | 5 +--- r/src/expression.cpp | 31 +++++++++++++---------- 8 files changed, 69 insertions(+), 46 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index a9ec5403068..0ab98327619 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -70,8 +70,7 @@ class ARROW_DS_EXPORT Expression { /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, - compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, compute::ExecContext* = NULLPTR) const; // XXX someday // Clone all KernelState in this bound expression. 
If any function referenced by this @@ -124,14 +123,12 @@ class ARROW_DS_EXPORT Expression { }; inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); } -inline bool operator!=(const Expression& l, const Expression& r) { - return !l.Equals(r); -} +inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); } // Factories inline Expression call(std::string function, std::vector arguments, - std::shared_ptr options = NULLPTR) { + std::shared_ptr options = NULLPTR) { Expression::Call call; call.function = std::move(function); call.arguments = std::move(arguments); @@ -142,7 +139,7 @@ inline Expression call(std::string function, std::vector arguments, template ::value>::type> Expression call(std::string function, std::vector arguments, - Options options) { + Options options) { return call(std::move(function), std::move(arguments), std::make_shared(std::move(options))); } @@ -157,8 +154,11 @@ template Expression literal(Arg&& arg) { Datum lit(std::forward(arg)); ValueDescr descr = lit.descr(); - return Expression(std::make_shared(std::move(lit)), - std::move(descr)); + return Expression(std::make_shared(std::move(lit)), std::move(descr)); +} + +inline Expression literal(std::shared_ptr scalar) { + return literal(Datum(std::move(scalar))); } ARROW_DS_EXPORT @@ -172,8 +172,7 @@ Result> ExtractKnownFieldVal /// /// @{ /// -/// These operate on a bound expression and its bound state simultaneously, -/// ensuring that Call Expressions' KernelState can be utilized or reassociated. +/// These operate on bound expressions. /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. @@ -197,7 +196,7 @@ Result ReplaceFieldsWithKnownValues( /// reference to a constant-value field with a literal. 
ARROW_DS_EXPORT Result SimplifyWithGuarantee(Expression, - const Expression& guaranteed_true_predicate); + const Expression& guaranteed_true_predicate); /// @} @@ -219,7 +218,7 @@ Result Deserialize(const Buffer&); // Convenience aliases for factories ARROW_DS_EXPORT Expression project(std::vector values, - std::vector names); + std::vector names); ARROW_DS_EXPORT Expression equal(Expression lhs, Expression rhs); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 4812d06b8c8..d9738798259 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -52,6 +52,9 @@ TEST(Expression, ToString) { EXPECT_EQ(literal(3).ToString(), "3"); + auto ts = *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO)); + EXPECT_EQ(literal(ts).ToString(), "656677413000000000"); + EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), "add(3,FieldRef(beta))"); diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 9c620279c9e..3d1f545e5f7 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -118,8 +118,8 @@ struct ARROW_EXPORT Datum { /// \brief Empty datum, to be populated elsewhere Datum() = default; - Datum(const Datum& other) noexcept = default; - Datum& operator=(const Datum& other) noexcept = default; + Datum(const Datum& other) = default; + Datum& operator=(const Datum& other) = default; Datum(Datum&& other) noexcept = default; Datum& operator=(Datum&& other) noexcept = default; diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h index 89f39ab8917..04ac5e2bb45 100644 --- a/cpp/src/arrow/util/variant.h +++ b/cpp/src/arrow/util/variant.h @@ -96,7 +96,7 @@ template struct all : conditional_t, std::false_type> {}; struct delete_copy_constructor { - template + template struct type { type() = default; type(const type& other) = delete; @@ -105,11 +105,13 @@ struct delete_copy_constructor { }; struct explicit_copy_constructor { 
- template + template struct type { type() = default; - type(const type& other) { static_cast(other).copy_to(this); } - type& operator=(const type& other) { + type(const type& other) noexcept(NoexceptCopyable) { + static_cast(other).copy_to(this); + } + type& operator=(const type& other) noexcept(NoexceptCopyable) { static_cast(this)->destroy(); static_cast(other).copy_to(this); return *this; @@ -117,11 +119,18 @@ struct explicit_copy_constructor { }; }; +template +using CopyConstructionImpl = + typename conditional_t::value && + std::is_copy_assignable::value)...>::value, + explicit_copy_constructor, delete_copy_constructor>:: + template type::value...>::value>; + template struct VariantStorage { VariantStorage() = default; - VariantStorage(const VariantStorage&) {} - VariantStorage& operator=(const VariantStorage&) { return *this; } + VariantStorage(const VariantStorage&) noexcept {} + VariantStorage& operator=(const VariantStorage&) noexcept { return *this; } VariantStorage(VariantStorage&&) noexcept {} VariantStorage& operator=(VariantStorage&&) noexcept { return *this; } ~VariantStorage() { @@ -192,7 +201,8 @@ struct VariantImpl, H, T...> : VariantImpl, T...> { // Templated to avoid instantiation in case H is not copy constructible template - void copy_to(Void* generic_target) const { + void copy_to(Void* generic_target) const + noexcept(noexcept(H(std::declval()))) { const auto target = static_cast(generic_target); try { if (this->index_ == kIndex) { @@ -243,11 +253,7 @@ struct VariantImpl, H, T...> : VariantImpl, T...> { template class Variant : detail::VariantImpl, T...>, - detail::conditional_t< - detail::all<(std::is_copy_constructible::value && - std::is_copy_assignable::value)...>::value, - detail::explicit_copy_constructor, - detail::delete_copy_constructor>::template type> { + detail::CopyConstructionImpl, T...> { template static constexpr uint8_t index_of() { return Impl::index_of(detail::type_constant{}); @@ -338,7 +344,7 @@ class Variant : 
detail::VariantImpl, T...>, this->index_ = 0; } - template + template friend struct detail::explicit_copy_constructor::type; template diff --git a/cpp/src/arrow/util/variant_test.cc b/cpp/src/arrow/util/variant_test.cc index 9e36f2eb9cf..5d83e00185f 100644 --- a/cpp/src/arrow/util/variant_test.cc +++ b/cpp/src/arrow/util/variant_test.cc @@ -125,6 +125,19 @@ TEST(Variant, CopyConstruction) { EXPECT_NO_THROW(AssertCopyConstruction(CopyAssignThrows{})); } +TEST(Variant, Noexcept) { + struct CopyThrows { + CopyThrows() = default; + CopyThrows(const CopyThrows&) { throw 42; } + CopyThrows& operator=(const CopyThrows&) { throw 42; } + }; + static_assert(!std::is_nothrow_copy_constructible::value, ""); + static_assert(std::is_nothrow_copy_constructible::value, ""); + static_assert(std::is_nothrow_copy_constructible>::value, ""); + static_assert( + !std::is_nothrow_copy_constructible>::value, ""); +} + TEST(Variant, Emplace) { using variant_type = Variant, int>; variant_type v; diff --git a/r/src/compute.cpp b/r/src/compute.cpp index a456ec4711b..12b23dad798 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -191,7 +191,7 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list auto datum_args = arrow::r::from_r_list(args); auto out = ValueOrStop( arrow::compute::CallFunction(func_name, datum_args, opts.get(), gc_context())); - return from_datum(out); + return from_datum(std::move(out)); } #endif diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp index 2ad59677eb0..8d8aa9cff6d 100644 --- a/r/src/dataset.cpp +++ b/r/src/dataset.cpp @@ -312,10 +312,7 @@ void dataset___ScannerBuilder__Project(const std::shared_ptr // [[arrow::export]] void dataset___ScannerBuilder__Filter(const std::shared_ptr& sb, const std::shared_ptr& expr) { - // Expressions converted from R's expressions are typed with R's native type, - // i.e. double, int64_t and bool. 
- auto cast_filter = ValueOrStop(InsertImplicitCasts(*expr, *sb->schema())); - StopIfNotOk(sb->Filter(cast_filter)); + StopIfNotOk(sb->Filter(*expr)); } // [[arrow::export]] diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 575dddc7da7..b1116d27a2f 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -19,93 +19,98 @@ #if defined(ARROW_R_WITH_ARROW) +#include #include namespace ds = ::arrow::dataset; +std::shared_ptr Share(ds::Expression expr) { + return std::make_shared(std::move(expr)); +} + // [[arrow::export]] std::shared_ptr dataset___expr__field_ref(std::string name) { - return ds::field_ref(std::move(name)); + return Share(ds::field_ref(std::move(name))); } // [[arrow::export]] std::shared_ptr dataset___expr__equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::equal(lhs, rhs); + return Share(ds::equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__not_equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::not_equal(lhs, rhs); + return Share(ds::not_equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__greater( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::greater(lhs, rhs); + return Share(ds::greater(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__greater_equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::greater_equal(lhs, rhs); + return Share(ds::greater_equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__less( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::less(lhs, rhs); + return Share(ds::less(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__less_equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::less_equal(lhs, rhs); + return Share(ds::less_equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__in( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { 
- return lhs->In(rhs).Copy(); + return Share(ds::call("is_in", {*lhs}, arrow::compute::SetLookupOptions{rhs})); } // [[arrow::export]] std::shared_ptr dataset___expr__and( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::and_(lhs, rhs); + return Share(ds::and_(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__or( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::or_(lhs, rhs); + return Share(ds::or_(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__not( const std::shared_ptr& lhs) { - return ds::not_(lhs); + return Share(ds::not_(*lhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__is_valid( const std::shared_ptr& lhs) { - return lhs->IsValid().Copy(); + return Share(ds::call("is_valid", {*lhs})); } // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { - return ds::scalar(x); + return Share(ds::literal(x)); } // [[arrow::export]] From 27ebb463a9ce0772ce93bb479ecca1d3db0136cd Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 11 Dec 2020 13:30:46 -0500 Subject: [PATCH 08/31] add more scalar cast implementations --- .../compute/kernels/scalar_cast_internal.h | 8 ++ .../compute/kernels/scalar_cast_temporal.cc | 75 +++++++++++++------ .../compute/kernels/scalar_set_lookup.cc | 29 ++----- .../arrow/compute/kernels/util_internal.cc | 23 ++++++ cpp/src/arrow/compute/kernels/util_internal.h | 10 +++ cpp/src/arrow/dataset/expression.cc | 47 ++++++------ cpp/src/arrow/dataset/expression_internal.h | 10 ++- cpp/src/arrow/dataset/expression_test.cc | 56 +++++++++----- cpp/src/arrow/dataset/test_util.h | 2 + 9 files changed, 169 insertions(+), 91 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h index 333601cf216..15769ce5f8f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h @@ -21,6 +21,7 
@@ #include "arrow/compute/cast.h" // IWYU pragma: export #include "arrow/compute/cast_internal.h" // IWYU pragma: export #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" namespace arrow { @@ -62,6 +63,13 @@ void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) { CastFunctor::Exec)); } +template +void AddSimpleArrayOnlyCast(InputType in_ty, OutputType out_ty, CastFunction* func) { + DCHECK_OK(func->AddKernel( + InType::type_id, {in_ty}, out_ty, + TrivialScalarUnaryAsArraysExec(CastFunctor::Exec))); +} + void ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out); void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index 10b38d75095..7cff41f267e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -126,7 +126,6 @@ struct CastFunctor< enable_if_t<(is_timestamp_type::value && is_timestamp_type::value) || (is_duration_type::value && is_duration_type::value)>> { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -147,7 +146,6 @@ struct CastFunctor< template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -170,7 +168,6 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const CastOptions& options = checked_cast(*ctx->state()).options; @@ -224,7 +221,6 @@ struct 
CastFunctor::value && is_time_type:: using out_t = typename O::c_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -245,7 +241,6 @@ struct CastFunctor::value && is_time_type:: template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); ShiftTime(ctx, util::MULTIPLY, kMillisecondsInDay, @@ -256,7 +251,6 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); ShiftTime(ctx, util::DIVIDE, kMillisecondsInDay, *batch[0].array(), @@ -264,6 +258,38 @@ struct CastFunctor { } }; +// ---------------------------------------------------------------------- +// date32, date64 to timestamp + +template <> +struct CastFunctor { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const auto& out_type = checked_cast(*out->type()); + // get conversion SECOND -> unit + auto conversion = util::GetTimestampConversion(TimeUnit::SECOND, out_type.unit()); + DCHECK_EQ(conversion.first, util::MULTIPLY); + + // multiply to achieve days -> unit + conversion.second *= kMillisecondsInDay / 1000; + ShiftTime(ctx, util::MULTIPLY, conversion.second, *batch[0].array(), + out->mutable_array()); + } +}; + +template <> +struct CastFunctor { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const auto& out_type = checked_cast(*out->type()); + auto conversion = util::GetTimestampConversion(TimeUnit::MILLI, out_type.unit()); + ShiftTime(ctx, util::MULTIPLY, conversion.second, *batch[0].array(), + 
out->mutable_array()); + } +}; + // ---------------------------------------------------------------------- // String to Timestamp @@ -307,11 +333,11 @@ std::shared_ptr GetDate32Cast() { AddZeroCopyCast(Type::INT32, int32(), date32(), func.get()); // date64 -> date32 - AddSimpleCast(date64(), date32(), func.get()); + AddSimpleArrayOnlyCast(date64(), date32(), func.get()); // timestamp -> date32 - AddSimpleCast(InputType(Type::TIMESTAMP), date32(), - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIMESTAMP), date32(), + func.get()); return func; } @@ -324,11 +350,11 @@ std::shared_ptr GetDate64Cast() { AddZeroCopyCast(Type::INT64, int64(), date64(), func.get()); // date32 -> date64 - AddSimpleCast(date32(), date64(), func.get()); + AddSimpleArrayOnlyCast(date32(), date64(), func.get()); // timestamp -> date64 - AddSimpleCast(InputType(Type::TIMESTAMP), date64(), - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIMESTAMP), date64(), + func.get()); return func; } @@ -358,8 +384,8 @@ std::shared_ptr GetTime32Cast() { AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get()); // time64 -> time32 - AddSimpleCast(InputType(Type::TIME64), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIME64), + kOutputTargetType, func.get()); // time32 -> time32 AddCrossUnitCast(func.get()); @@ -375,8 +401,8 @@ std::shared_ptr GetTime64Cast() { AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); // time32 -> time64 - AddSimpleCast(InputType(Type::TIME32), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIME32), + kOutputTargetType, func.get()); // Between durations AddCrossUnitCast(func.get()); @@ -392,17 +418,18 @@ std::shared_ptr GetTimestampCast() { AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); // From date types - // TODO: ARROW-8876, these casts are not implemented - // AddSimpleCast(InputType(Type::DATE32), - // 
kOutputTargetType, func.get()); - // AddSimpleCast(InputType(Type::DATE64), - // kOutputTargetType, func.get()); + // TODO: ARROW-8876, these casts are not directly tested + AddSimpleArrayOnlyCast(InputType(Type::DATE32), + kOutputTargetType, func.get()); + AddSimpleArrayOnlyCast(InputType(Type::DATE64), + kOutputTargetType, func.get()); // string -> timestamp - AddSimpleCast(utf8(), kOutputTargetType, func.get()); + AddSimpleArrayOnlyCast(utf8(), kOutputTargetType, + func.get()); // large_string -> timestamp - AddSimpleCast(large_utf8(), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(large_utf8(), kOutputTargetType, + func.get()); // From one timestamp to another AddCrossUnitCast(func.get()); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index d43083c7538..722f3173a34 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -19,10 +19,10 @@ #include "arrow/array/builder_primitive.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/hashing.h" -#include "arrow/util/optional.h" #include "arrow/visitor_inline.h" namespace arrow { @@ -255,25 +255,8 @@ struct IndexInVisitor { } }; -void ExecArrayOrScalar(KernelContext* ctx, const Datum& in, Datum* out, - std::function array_impl) { - if (in.is_array()) { - KERNEL_RETURN_IF_ERROR(ctx, array_impl(*in.array())); - return; - } - - std::shared_ptr in_array; - std::shared_ptr out_scalar; - KERNEL_RETURN_IF_ERROR(ctx, MakeArrayFromScalar(*in.scalar(), 1).Value(&in_array)); - KERNEL_RETURN_IF_ERROR(ctx, array_impl(*in_array->data())); - KERNEL_RETURN_IF_ERROR(ctx, out->make_array()->GetScalar(0).Value(&out_scalar)); - *out = std::move(out_scalar); -} - void ExecIndexIn(KernelContext* ctx, const ExecBatch& 
batch, Datum* out) { - ExecArrayOrScalar(ctx, batch[0], out, [&](const ArrayData& in) { - return IndexInVisitor(ctx, in, out).Execute(); - }); + KERNEL_RETURN_IF_ERROR(ctx, IndexInVisitor(ctx, *batch[0].array(), out).Execute()); } // ---------------------------------------------------------------------- @@ -360,9 +343,7 @@ struct IsInVisitor { }; void ExecIsIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExecArrayOrScalar(ctx, batch[0], out, [&](const ArrayData& in) { - return IsInVisitor(ctx, in, out).Execute(); - }); + KERNEL_RETURN_IF_ERROR(ctx, IsInVisitor(ctx, *batch[0].array(), out).Execute()); } // Unary set lookup kernels available for the following input types @@ -451,7 +432,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel isin_base; isin_base.init = InitSetLookup; - isin_base.exec = ExecIsIn; + isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn); auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); @@ -468,7 +449,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel index_in_base; index_in_base.init = InitSetLookup; - index_in_base.exec = ExecIndexIn; + index_in_base.exec = TrivialScalarUnaryAsArraysExec(ExecIndexIn); index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto index_in = diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc index 32c6317a104..21b6b706a50 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ -57,6 +57,29 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { return arg; } +ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec) { + return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (out->is_array()) { + return exec(ctx, batch, 
out); + } + + if (!batch[0].scalar()->is_valid) { + out->scalar()->is_valid = false; + return; + } + + Datum array_in, array_out; + KERNEL_RETURN_IF_ERROR( + ctx, MakeArrayFromScalar(*batch[0].scalar(), 1).As().Value(&array_in)); + KERNEL_RETURN_IF_ERROR( + ctx, MakeArrayFromScalar(*out->scalar(), 1).As().Value(&array_out)); + + exec(ctx, ExecBatch{{std::move(array_in)}, 1}, &array_out); + KERNEL_RETURN_IF_ERROR(ctx, + array_out.make_array()->GetScalar(0).As().Value(out)); + }; +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h index 7ab59965752..4aad3804366 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.h +++ b/cpp/src/arrow/compute/kernels/util_internal.h @@ -18,8 +18,12 @@ #pragma once #include +#include +#include "arrow/array/util.h" #include "arrow/buffer.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/type_fwd.h" namespace arrow { namespace compute { @@ -50,6 +54,12 @@ int GetBitWidth(const DataType& type); // rather than duplicating compiled code to do all these in each kernel. PrimitiveArg GetPrimitiveArg(const ArrayData& arr); +// Augment a unary ArrayKernelExec which supports only array-like inputs with support for +// scalar inputs. Scalars will be transformed to 1-long arrays which are passed to the +// original exec. This could be far more efficient, but instead of optimizing this it'd be +// better to support scalar inputs "upstream" in original exec. 
+ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 6ff53d4190c..cc142cbb2d4 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -335,7 +335,7 @@ Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { auto call_with_cast = *CallNotNull(with_cast); call_with_cast.arguments[0] = std::move(*expr); *expr = Expression(std::make_shared(std::move(call_with_cast)), - ValueDescr{std::move(to_type), expr->descr().shape}); + ValueDescr{std::move(to_type), expr->descr().shape}); return Status::OK(); } @@ -344,7 +344,7 @@ Status InsertImplicitCasts(Expression::Call* call) { DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - if (Comparison::Get(call->function)) { + if (IsSameTypesBinary(call->function)) { for (auto&& argument : call->arguments) { if (auto value_type = GetDictionaryValueType(argument.descr().type)) { RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); @@ -390,7 +390,7 @@ Status InsertImplicitCasts(Expression::Call* call) { } Result Expression::Bind(ValueDescr in, - compute::ExecContext* exec_context) const { + compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; return Bind(std::move(in), &exec_context); @@ -430,7 +430,7 @@ Result Expression::Bind(ValueDescr in, } Result Expression::Bind(const Schema& in_schema, - compute::ExecContext* exec_context) const { + compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } @@ -493,9 +493,9 @@ std::array, 2> ArgumentsAndFlippedArguments(const Expression::Call& call) { DCHECK_EQ(call.arguments.size(), 2); return {std::pair{call.arguments[0], - call.arguments[1]}, + 
call.arguments[1]}, std::pair{call.arguments[1], - call.arguments[0]}}; + call.arguments[0]}}; } template Canonicalize(Expression expr, compute::ExecContext* exec_cont // fold the chain back up const auto& descr = expr.descr(); - auto folded = FoldLeft( - chain.fringe.begin(), chain.fringe.end(), - [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { - auto ret = *call; - ret.arguments = {std::move(l), std::move(r)}; - Expression expr(std::make_shared(std::move(ret)), - descr); - AlreadyCanonicalized.Add({expr}); - return expr; - }); + auto folded = + FoldLeft(chain.fringe.begin(), chain.fringe.end(), + [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { + auto ret = *call; + ret.arguments = {std::move(l), std::move(r)}; + Expression expr( + std::make_shared(std::move(ret)), descr); + AlreadyCanonicalized.Add({expr}); + return expr; + }); return std::move(*folded); } @@ -792,9 +792,8 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); - return Expression( - std::make_shared(std::move(flipped_call)), - expr.descr()); + return Expression(std::make_shared(std::move(flipped_call)), + expr.descr()); } } @@ -804,7 +803,7 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont } Result DirectComparisonSimplification(Expression expr, - const Expression::Call& guarantee) { + const Expression::Call& guarantee) { return Modify( std::move(expr), [](Expression expr) { return expr; }, [&guarantee](Expression expr, ...) 
-> Result { @@ -862,7 +861,7 @@ Result DirectComparisonSimplification(Expression expr, } Result SimplifyWithGuarantee(Expression expr, - const Expression& guaranteed_true_predicate) { + const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; @@ -1072,7 +1071,7 @@ Expression and_(Expression lhs, Expression rhs) { Expression and_(const std::vector& operands) { auto folded = FoldLeft(operands.begin(), - operands.end(), and_); + operands.end(), and_); if (folded) { return std::move(*folded); } @@ -1084,8 +1083,8 @@ Expression or_(Expression lhs, Expression rhs) { } Expression or_(const std::vector& operands) { - auto folded = FoldLeft(operands.begin(), - operands.end(), or_); + auto folded = + FoldLeft(operands.begin(), operands.end(), or_); if (folded) { return std::move(*folded); } diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 8c26e037f16..90f28cd23ef 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -216,6 +216,14 @@ inline bool IsSetLookup(const std::string& function) { return function == "is_in" || function == "index_in"; } +inline bool IsSameTypesBinary(const std::string& function) { + if (Comparison::Get(function)) return true; + + static std::unordered_set set{"add", "subtract", "multiply", "divide"}; + + return set.find(function) != set.end(); +} + inline const compute::SetLookupOptions* GetSetLookupOptions( const Expression::Call& call) { if (!IsSetLookup(call.function)) return nullptr; @@ -402,7 +410,7 @@ inline Result> GetFunction( template Result Modify(Expression expr, const PreVisit& pre, - const PostVisitCall& post_call) { + const PostVisitCall& post_call) { ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); auto call = expr.call(); diff --git a/cpp/src/arrow/dataset/expression_test.cc 
b/cpp/src/arrow/dataset/expression_test.cc index d9738798259..91e9dcdc4f1 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -219,30 +219,36 @@ TEST(Expression, BindLiteral) { } } -void ExpectBindsTo(Expression expr, Expression expected, +void ExpectBindsTo(Expression expr, util::optional expected, Expression* bound_out = nullptr) { + if (!expected) { + expected = expr; + } + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); EXPECT_TRUE(bound.IsBound()); - ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); - EXPECT_EQ(bound, expected) << " unbound: " << expr.ToString(); + ASSERT_OK_AND_ASSIGN(expected, expected->Bind(*kBoringSchema)); + EXPECT_EQ(bound, *expected) << " unbound: " << expr.ToString(); if (bound_out) { *bound_out = bound; } } +const auto no_change = util::nullopt; + TEST(Expression, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set auto expr = field_ref("alpha"); EXPECT_EQ(expr.descr(), ValueDescr{}); EXPECT_FALSE(expr.IsBound()); - ExpectBindsTo(field_ref("i32"), field_ref("i32"), &expr); + ExpectBindsTo(field_ref("i32"), no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); // if the field is not found, a null scalar will be emitted - ExpectBindsTo(field_ref("no such field"), field_ref("no such field"), &expr); + ExpectBindsTo(field_ref("no such field"), no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); // referencing a field by name is not supported if that name is not unique @@ -258,17 +264,19 @@ TEST(Expression, BindFieldRef) { } TEST(Expression, BindCall) { - auto expr = call("add", {field_ref("a"), field_ref("b")}); + auto expr = call("add", {field_ref("i32"), field_ref("i32_req")}); EXPECT_FALSE(expr.IsBound()); - ASSERT_OK_AND_ASSIGN(expr, - expr.Bind(Schema({field("a", int32()), field("b", int32())}))); + ExpectBindsTo(expr, no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - 
EXPECT_TRUE(expr.IsBound()); - expr = call("add", {field_ref("a"), literal(3.5)}); - ASSERT_RAISES(NotImplemented, - expr.Bind(Schema({field("a", int32()), field("b", int32())}))); + // literal(3) may be safely cast to float32, so binding this expr casts that literal: + ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}), + call("add", {field_ref("f32"), literal(3.0F)})); + + // literal(3.5) may not be safely cast to int32, so binding this expr fails: + ASSERT_RAISES(Invalid, + call("add", {field_ref("i32"), literal(3.5)}).Bind(*kBoringSchema)); } TEST(Expression, BindWithImplicitCasts) { @@ -284,11 +292,9 @@ TEST(Expression, BindWithImplicitCasts) { } // scalars are directly cast when possible: - ExpectBindsTo( - equal(field_ref("ts_ns"), literal("1990-10-23 10:23:33")), - equal(field_ref("ts_ns"), - literal( - *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO))))); + auto ts_scalar = MakeScalar("1990-10-23")->CastTo(timestamp(TimeUnit::NANO)); + ExpectBindsTo(equal(field_ref("ts_ns"), literal("1990-10-23")), + equal(field_ref("ts_ns"), literal(*ts_scalar))); // cast value_set to argument type auto Opts = [](std::shared_ptr type) { @@ -509,7 +515,10 @@ TEST(Expression, FoldConstants) { literal(2), })); - compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]"), true); + compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]")); + + ExpectFoldsTo(call("is_in", {literal(2)}, in_123), literal(true)); + ExpectFoldsTo( call("is_in", {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})}, @@ -932,6 +941,17 @@ TEST(Expression, Filter) { {"i32": 0, "f32": null, "in": 1}, {"i32": 0, "f32": 1.0, "in": 1} ])"); + + ExpectFilter( + greater(call("multiply", {field_ref("f32"), field_ref("f64")}), literal(0)), R"([ + {"f64": 0.3, "f32": 0.1, "in": 1}, + {"f64": -0.1, "f32": 0.3, "in": 0}, + {"f64": 0.1, "f32": 0.2, "in": 1}, + {"f64": 0.0, "f32": -0.1, "in": 0}, + {"f64": 1.0, "f32": 0.1, "in": 1}, + {"f64": 
-2.0, "f32": null, "in": null}, + {"f64": 3.0, "f32": 1.0, "in": 1} + ])"); } TEST(Expression, SerializationRoundTrips) { diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index c5904b45e06..1c7c471d3ca 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -53,8 +53,10 @@ const std::shared_ptr kBoringSchema = schema({ field("i32", int32()), field("i32_req", int32(), /*nullable=*/false), field("i64", int64()), + field("date64", date64()), field("f32", float32()), field("f32_req", float32(), /*nullable=*/false), + field("f64", float64()), field("bool", boolean()), field("str", utf8()), field("dict_str", dictionary(int32(), utf8())), From 9245f24d31f17d00e467a445c5311c8a0c9510e6 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 14 Dec 2020 11:13:26 -0500 Subject: [PATCH 09/31] use dataset___expr__call to create Call expressions --- r/R/arrowExports.R | 48 +--------- r/R/expression.R | 34 ++++--- r/src/arrowExports.cpp | 197 ++++------------------------------------- r/src/compute.cpp | 11 ++- r/src/expression.cpp | 92 ++++--------------- 5 files changed, 63 insertions(+), 319 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 5e519080cee..5bd9c0bcf7f 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -736,52 +736,12 @@ FixedSizeListType__list_size <- function(type){ .Call(`_arrow_FixedSizeListType__list_size` , type) } -dataset___expr__field_ref <- function(name){ - .Call(`_arrow_dataset___expr__field_ref` , name) -} - -dataset___expr__equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__equal` , lhs, rhs) -} - -dataset___expr__not_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__not_equal` , lhs, rhs) -} - -dataset___expr__greater <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__greater` , lhs, rhs) -} - -dataset___expr__greater_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__greater_equal` , lhs, rhs) -} - 
-dataset___expr__less <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__less` , lhs, rhs) +dataset___expr__call <- function(func_name, argument_list, options){ + .Call(`_arrow_dataset___expr__call` , func_name, argument_list, options) } -dataset___expr__less_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__less_equal` , lhs, rhs) -} - -dataset___expr__in <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__in` , lhs, rhs) -} - -dataset___expr__and <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__and` , lhs, rhs) -} - -dataset___expr__or <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__or` , lhs, rhs) -} - -dataset___expr__not <- function(lhs){ - .Call(`_arrow_dataset___expr__not` , lhs) -} - -dataset___expr__is_valid <- function(lhs){ - .Call(`_arrow_dataset___expr__is_valid` , lhs) +dataset___expr__field_ref <- function(name){ + .Call(`_arrow_dataset___expr__field_ref` , name) } dataset___expr__scalar <- function(x){ diff --git a/r/R/expression.R b/r/R/expression.R index d5623fb7786..8c7e4bca59f 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -182,36 +182,44 @@ Expression$field_ref <- function(name) { Expression$scalar <- function(x) { dataset___expr__scalar(Scalar$create(x)) } +Expression$call <- function(name, arguments, options = list()) { + dataset___expr__call(name, arguments, options) +} Expression$compare <- function(OP, e1, e2) { comp_func <- comparison_function_map[[OP]] if (is.null(comp_func)) { stop(OP, " is not a supported comparison function", call. 
= FALSE) } - comp_func(e1, e2) + Expression$call(comp_func, list(e1, e2)) } comparison_function_map <- list( - "==" = dataset___expr__equal, - "!=" = dataset___expr__not_equal, - ">" = dataset___expr__greater, - ">=" = dataset___expr__greater_equal, - "<" = dataset___expr__less, - "<=" = dataset___expr__less_equal + "==" = "equal", + "!=" = "not_equal", + ">" = "greater", + ">=" = "greater_equal", + "<" = "less", + "<=" = "less_equal" ) Expression$in_ <- function(x, set) { - dataset___expr__in(x, Array$create(set)) + Expression$call("is_in", list(x), + options = list(value_set = Array$create(set), + skip_nulls = TRUE)) } Expression$and <- function(e1, e2) { - dataset___expr__and(e1, e2) + Expression$call("and_kleene", list(e1, e2)) } Expression$or <- function(e1, e2) { - dataset___expr__or(e1, e2) + Expression$call("or_kleene", list(e1, e2)) } Expression$not <- function(e1) { - dataset___expr__not(e1) + Expression$call("invert", list(e1)) } Expression$is_valid <- function(e1) { - dataset___expr__is_valid(e1) + Expression$call("is_valid", list(e1)) +} +Expression$is_null <- function(e1) { + Expression$call("is_null", list(e1)) } #' @export @@ -253,4 +261,4 @@ make_expression <- function(operator, e1, e2) { } #' @export -is.na.Expression <- function(x) !Expression$is_valid(x) +is.na.Expression <- function(x) Expression$is_null(x) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 975ff72f7f1..26bd57e1e28 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2846,190 +2846,33 @@ extern "C" SEXP _arrow_FixedSizeListType__list_size(SEXP type_sexp){ // expression.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__field_ref(std::string name); -extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ +std::shared_ptr dataset___expr__call(std::string func_name, cpp11::list argument_list, cpp11::list options); +extern "C" SEXP _arrow_dataset___expr__call(SEXP func_name_sexp, SEXP argument_list_sexp, SEXP 
options_sexp){ BEGIN_CPP11 - arrow::r::Input::type name(name_sexp); - return cpp11::as_sexp(dataset___expr__field_ref(name)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ - Rf_error("Cannot call dataset___expr__field_ref(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__equal(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__not_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__not_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__not_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__not_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__not_equal(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__greater(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__greater(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__greater(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__greater(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__greater(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__greater_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__greater_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__greater_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__greater_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__greater_equal(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__less(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__less(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__less(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__less(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__less(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__less_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__less_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__less_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__less_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__less_equal(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__in(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__in(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__in(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__in(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__in(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__and(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__and(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__and(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__and(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__and(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__or(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__or(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__or(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__or(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__or(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__not(const std::shared_ptr& lhs); -extern "C" SEXP _arrow_dataset___expr__not(SEXP lhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - return cpp11::as_sexp(dataset___expr__not(lhs)); + arrow::r::Input::type func_name(func_name_sexp); + arrow::r::Input::type argument_list(argument_list_sexp); + arrow::r::Input::type options(options_sexp); + return cpp11::as_sexp(dataset___expr__call(func_name, argument_list, options)); END_CPP11 } #else -extern "C" SEXP _arrow_dataset___expr__not(SEXP lhs_sexp){ - Rf_error("Cannot call dataset___expr__not(). Please use arrow::install_arrow() to install required runtime libraries. "); +extern "C" SEXP _arrow_dataset___expr__call(SEXP func_name_sexp, SEXP argument_list_sexp, SEXP options_sexp){ + Rf_error("Cannot call dataset___expr__call(). Please use arrow::install_arrow() to install required runtime libraries. 
"); } #endif // expression.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__is_valid(const std::shared_ptr& lhs); -extern "C" SEXP _arrow_dataset___expr__is_valid(SEXP lhs_sexp){ +std::shared_ptr dataset___expr__field_ref(std::string name); +extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - return cpp11::as_sexp(dataset___expr__is_valid(lhs)); + arrow::r::Input::type name(name_sexp); + return cpp11::as_sexp(dataset___expr__field_ref(name)); END_CPP11 } #else -extern "C" SEXP _arrow_dataset___expr__is_valid(SEXP lhs_sexp){ - Rf_error("Cannot call dataset___expr__is_valid(). Please use arrow::install_arrow() to install required runtime libraries. "); +extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ + Rf_error("Cannot call dataset___expr__field_ref(). Please use arrow::install_arrow() to install required runtime libraries. "); } #endif @@ -6590,18 +6433,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_FixedSizeListType__value_field", (DL_FUNC) &_arrow_FixedSizeListType__value_field, 1}, { "_arrow_FixedSizeListType__value_type", (DL_FUNC) &_arrow_FixedSizeListType__value_type, 1}, { "_arrow_FixedSizeListType__list_size", (DL_FUNC) &_arrow_FixedSizeListType__list_size, 1}, + { "_arrow_dataset___expr__call", (DL_FUNC) &_arrow_dataset___expr__call, 3}, { "_arrow_dataset___expr__field_ref", (DL_FUNC) &_arrow_dataset___expr__field_ref, 1}, - { "_arrow_dataset___expr__equal", (DL_FUNC) &_arrow_dataset___expr__equal, 2}, - { "_arrow_dataset___expr__not_equal", (DL_FUNC) &_arrow_dataset___expr__not_equal, 2}, - { "_arrow_dataset___expr__greater", (DL_FUNC) &_arrow_dataset___expr__greater, 2}, - { "_arrow_dataset___expr__greater_equal", (DL_FUNC) &_arrow_dataset___expr__greater_equal, 2}, - { "_arrow_dataset___expr__less", (DL_FUNC) &_arrow_dataset___expr__less, 2}, - { "_arrow_dataset___expr__less_equal", (DL_FUNC) &_arrow_dataset___expr__less_equal, 2}, - 
{ "_arrow_dataset___expr__in", (DL_FUNC) &_arrow_dataset___expr__in, 2}, - { "_arrow_dataset___expr__and", (DL_FUNC) &_arrow_dataset___expr__and, 2}, - { "_arrow_dataset___expr__or", (DL_FUNC) &_arrow_dataset___expr__or, 2}, - { "_arrow_dataset___expr__not", (DL_FUNC) &_arrow_dataset___expr__not, 1}, - { "_arrow_dataset___expr__is_valid", (DL_FUNC) &_arrow_dataset___expr__is_valid, 1}, { "_arrow_dataset___expr__scalar", (DL_FUNC) &_arrow_dataset___expr__scalar, 1}, { "_arrow_dataset___expr__ToString", (DL_FUNC) &_arrow_dataset___expr__ToString, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 12b23dad798..8a75a251d8e 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -126,7 +126,6 @@ arrow::Datum as_cpp(SEXP x) { // This assumes that R objects have already been converted to Arrow objects; // that seems right but should we do the wrapping here too/instead? cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x))); - return arrow::Datum(); } } // namespace cpp11 @@ -151,9 +150,7 @@ SEXP from_datum(arrow::Datum datum) { break; } - auto str = datum.ToString(); - cpp11::stop("from_datum: Not implemented for Datum %s", str.c_str()); - return R_NilValue; + cpp11::stop("from_datum: Not implemented for Datum %s", datum.ToString().c_str()); } std::shared_ptr make_compute_options( @@ -182,6 +179,12 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "is_in" || func_name == "index_in") { + using Options = arrow::compute::SetLookupOptions; + return std::make_shared(cpp11::as_cpp(options["value_set"]), + cpp11::as_cpp(options["skip_nulls"])); + } + return nullptr; } diff --git a/r/src/expression.cpp b/r/src/expression.cpp index b1116d27a2f..ddb1e72c309 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -23,94 +23,34 @@ #include namespace ds = ::arrow::dataset; -std::shared_ptr Share(ds::Expression expr) { - return 
std::make_shared(std::move(expr)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__field_ref(std::string name) { - return Share(ds::field_ref(std::move(name))); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::equal(*lhs, *rhs)); -} +std::shared_ptr make_compute_options( + std::string func_name, cpp11::list options); // [[arrow::export]] -std::shared_ptr dataset___expr__not_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::not_equal(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__greater( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::greater(*lhs, *rhs)); -} +std::shared_ptr dataset___expr__call(std::string func_name, + cpp11::list argument_list, + cpp11::list options) { + std::vector arguments; + for (SEXP argument : argument_list) { + auto argument_ptr = cpp11::as_cpp>(argument); + arguments.push_back(*argument_ptr); + } -// [[arrow::export]] -std::shared_ptr dataset___expr__greater_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::greater_equal(*lhs, *rhs)); -} + auto options_ptr = make_compute_options(func_name, options); -// [[arrow::export]] -std::shared_ptr dataset___expr__less( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::less(*lhs, *rhs)); + return std::make_shared( + ds::call(std::move(func_name), std::move(arguments), std::move(options_ptr))); } // [[arrow::export]] -std::shared_ptr dataset___expr__less_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::less_equal(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__in( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::call("is_in", {*lhs}, arrow::compute::SetLookupOptions{rhs})); -} - -// [[arrow::export]] -std::shared_ptr 
dataset___expr__and( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::and_(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__or( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::or_(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__not( - const std::shared_ptr& lhs) { - return Share(ds::not_(*lhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__is_valid( - const std::shared_ptr& lhs) { - return Share(ds::call("is_valid", {*lhs})); +std::shared_ptr dataset___expr__field_ref(std::string name) { + return std::make_shared(ds::field_ref(std::move(name))); } // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { - return Share(ds::literal(x)); + return std::make_shared(ds::literal(std::move(x))); } // [[arrow::export]] From 7fc23ddf002ea0562b90aaee012b3c02403508c4 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 14 Dec 2020 11:16:04 -0500 Subject: [PATCH 10/31] lint fixes --- cpp/src/arrow/dataset/dataset.h | 9 +- cpp/src/arrow/dataset/expression_internal.h | 4 +- cpp/src/arrow/dataset/expression_test.cc | 1 - cpp/src/arrow/dataset/file_base.h | 3 +- cpp/src/arrow/dataset/file_parquet_test.cc | 3 +- cpp/src/arrow/dataset/filter.cc | 2 + cpp/src/arrow/dataset/filter.h | 1 - python/pyarrow/_dataset.pyx | 89 +++++++++----------- python/pyarrow/includes/libarrow_dataset.pxd | 89 +++++--------------- 9 files changed, 69 insertions(+), 132 deletions(-) diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 4ad6ecbe74a..c92381d78c5 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -94,8 +94,7 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment { public: InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, Expression = literal(true)); - explicit InMemoryFragment(RecordBatchVector record_batches, - Expression = 
literal(true)); + explicit InMemoryFragment(RecordBatchVector record_batches, Expression = literal(true)); Result Scan(std::shared_ptr options, std::shared_ptr context) override; @@ -180,8 +179,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { std::shared_ptr schema) const override; protected: - Result GetFragmentsImpl( - Expression predicate) override; + Result GetFragmentsImpl(Expression predicate) override; std::shared_ptr get_batches_; }; @@ -205,8 +203,7 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { std::shared_ptr schema) const override; protected: - Result GetFragmentsImpl( - Expression predicate) override; + Result GetFragmentsImpl(Expression predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 90f28cd23ef..306f219859a 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -179,7 +179,7 @@ struct Comparison { return GREATER_EQUAL; case GREATER_EQUAL: return LESS_EQUAL; - }; + } DCHECK(false); return NA; } @@ -201,7 +201,7 @@ struct Comparison { return "less_equal"; case GREATER_EQUAL: return "greater_equal"; - }; + } DCHECK(false); return "na"; } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 91e9dcdc4f1..e5301677ed7 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -465,7 +465,6 @@ struct { EXPECT_TRUE(Identical(folded, expr)); } } - } ExpectFoldsTo; TEST(Expression, FoldConstants) { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 7b67e6f9bf5..bb2aa86ba9b 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -233,8 +233,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { 
std::string ToString() const; protected: - Result GetFragmentsImpl( - Expression predicate) override; + Result GetFragmentsImpl(Expression predicate) override; FileSystemDataset(std::shared_ptr schema, Expression root_partition, std::shared_ptr format, diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index e9182ddf696..67d1fb17120 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -190,8 +190,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { } void CountRowGroupsInFragment(const std::shared_ptr& fragment, - std::vector expected_row_groups, - Expression filter) { + std::vector expected_row_groups, Expression filter) { schema_ = opts_->schema(); ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); auto parquet_fragment = checked_pointer_cast(fragment); diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 9852f8c0808..2357896dd7a 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -15,4 +15,6 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/dataset/filter.h" + // FIXME remove diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 6dd39e25e1e..9852f8c0808 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -16,4 +16,3 @@ // under the License. 
// FIXME remove - diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index bbe0485fa69..e235d50c74a 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -86,24 +86,22 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=None): cdef class Expression(_Weakrefable): cdef: - shared_ptr[CExpression] wrapped - CExpression* expr + CExpression expr def __init__(self): _forbid_instantiation(self.__class__) - cdef void init(self, const shared_ptr[CExpression]& sp): - self.wrapped = sp - self.expr = sp.get() + cdef void init(self, const CExpression& sp): + self.expr = sp @staticmethod - cdef wrap(const shared_ptr[CExpression]& sp): + cdef wrap(const CExpression& sp): cdef Expression self = Expression.__new__(Expression) self.init(sp) return self - cdef inline shared_ptr[CExpression] unwrap(self): - return self.wrapped + cdef inline CExpression unwrap(self): + return self.expr def equals(self, Expression other): return self.expr.Equals(other.unwrap()) @@ -119,7 +117,7 @@ cdef class Expression(_Weakrefable): @staticmethod def _deserialize(Buffer buffer not None): c_buffer = pyarrow_unwrap_buffer(buffer) - c_expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) + #c_expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) return Expression.wrap(move(c_expr)) def __reduce__(self): @@ -157,16 +155,16 @@ cdef class Expression(_Weakrefable): return Expression.wrap(CMakeNotExpression(self.unwrap())) @staticmethod - cdef shared_ptr[CExpression] _expr_or_scalar(object expr) except *: + cdef CExpression _expr_or_scalar(object expr) except *: if isinstance(expr, Expression): return ( expr).unwrap() return ( Expression._scalar(expr)).unwrap() def __richcmp__(self, other, int op): cdef: - shared_ptr[CExpression] c_expr - shared_ptr[CExpression] c_left - shared_ptr[CExpression] c_right + CExpression c_expr + CExpression c_left + CExpression c_right c_left = self.unwrap() c_right = 
Expression._expr_or_scalar(other) @@ -212,10 +210,17 @@ cdef class Expression(_Weakrefable): def isin(self, values): """Checks whether the expression is contained in values""" + cdef: + vector[CExpression] arguments + arguments.push_back(self.expr) + if not isinstance(values, pa.Array): values = pa.array(values) c_values = pyarrow_unwrap_array(values) - return Expression.wrap(self.expr.In(c_values).Copy()) + return Expression.wrap(CMakeCallExpression( + tobytes("is_in"), + move(arguments), + move(SetLookupOptions(values).set_lookup_options))) @staticmethod def _field(str name not None): @@ -231,11 +236,7 @@ cdef class Expression(_Weakrefable): else: scalar = pa.scalar(value) - return Expression.wrap( - shared_ptr[CExpression]( - new CScalarExpression(move(scalar.unwrap())) - ) - ) + return Expression.wrap(CMakeScalarExpression(scalar.unwrap())) _deserialize = Expression._deserialize @@ -288,12 +289,7 @@ cdef class Dataset(_Weakrefable): An Expression which evaluates to true for all data viewed by this Dataset. 
""" - cdef shared_ptr[CExpression] expr - expr = self.dataset.partition_expression() - if expr.get() == nullptr: - return None - else: - return Expression.wrap(expr) + return Expression.wrap(self.dataset.partition_expression()) def replace_schema(self, Schema schema not None): """ @@ -322,13 +318,13 @@ cdef class Dataset(_Weakrefable): fragments : iterator of Fragment """ cdef: - shared_ptr[CExpression] c_filter + CExpression c_filter CFragmentIterator c_iterator if filter is None: c_fragments = self.dataset.GetFragments() else: - c_filter = _insert_implicit_casts(filter, self.schema) + c_filter = _bind(filter, self.schema) c_fragments = self.dataset.GetFragments(c_filter) for maybe_fragment in c_fragments: @@ -593,19 +589,14 @@ cdef class FileSystemDataset(Dataset): return FileFormat.wrap(self.filesystem_dataset.format()) -cdef shared_ptr[CExpression] _insert_implicit_casts(Expression filter, - Schema schema) except *: +cdef CExpression _bind(Expression filter, Schema schema) except *: assert schema is not None if filter is None: return _true.unwrap() - return GetResultValue( - CInsertImplicitCasts( - deref(filter.unwrap().get()), - deref(pyarrow_unwrap_schema(schema).get()) - ) - ) + return GetResultValue(filter.unwrap().Bind( + deref(pyarrow_unwrap_schema(schema).get()))) cdef class FileWriteOptions(_Weakrefable): @@ -1035,11 +1026,11 @@ cdef class ParquetFileFragment(FileFragment): """ cdef: vector[shared_ptr[CFragment]] c_fragments - shared_ptr[CExpression] c_filter + CExpression c_filter shared_ptr[CFragment] c_fragment schema = schema or self.physical_schema - c_filter = _insert_implicit_casts(filter, schema) + c_filter = _bind(filter, schema) with nogil: c_fragments = move(GetResultValue( self.parquet_file_fragment.SplitByRowGroup(move(c_filter)))) @@ -1072,7 +1063,7 @@ cdef class ParquetFileFragment(FileFragment): ParquetFileFragment """ cdef: - shared_ptr[CExpression] c_filter + CExpression c_filter vector[int] c_row_group_ids shared_ptr[CFragment] 
c_fragment @@ -1083,7 +1074,7 @@ cdef class ParquetFileFragment(FileFragment): if filter is not None: schema = schema or self.physical_schema - c_filter = _insert_implicit_casts(filter, schema) + c_filter = _bind(filter, schema) with nogil: c_fragment = move(GetResultValue( self.parquet_file_fragment.SubsetWithFilter( @@ -1408,7 +1399,7 @@ cdef class Partitioning(_Weakrefable): return self.wrapped def parse(self, path): - cdef CResult[shared_ptr[CExpression]] result + cdef CResult[CExpression] result result = self.partitioning.Parse(tobytes(path)) return Expression.wrap(GetResultValue(result)) @@ -1643,11 +1634,7 @@ cdef class DatasetFactory(_Weakrefable): @property def root_partition(self): - cdef shared_ptr[CExpression] expr = self.factory.root_partition() - if expr.get() == nullptr: - return None - else: - return Expression.wrap(expr) + return Expression.wrap(self.factory.root_partition()) @root_partition.setter def root_partition(self, Expression expr): @@ -2121,14 +2108,14 @@ cdef void _populate_builder(const shared_ptr[CScannerBuilder]& ptr, int batch_size=_DEFAULT_BATCH_SIZE) except *: cdef: CScannerBuilder *builder - builder = ptr.get() - if columns is not None: - check_status(builder.Project([tobytes(c) for c in columns])) - check_status(builder.Filter(_insert_implicit_casts( + check_status(builder.Filter(_bind( filter, pyarrow_wrap_schema(builder.schema())))) + if columns is not None: + check_status(builder.Project([tobytes(c) for c in columns])) + check_status(builder.BatchSize(batch_size)) @@ -2296,12 +2283,12 @@ def _get_partition_keys(Expression partition_expression): is converted to {'part': 'a', 'year': 2016} """ cdef: - shared_ptr[CExpression] expr = partition_expression.unwrap() + CExpression expr = partition_expression.unwrap() pair[c_string, shared_ptr[CScalar]] name_val return { frombytes(name_val.first): pyarrow_wrap_scalar(name_val.second).as_py() - for name_val in GetResultValue(CGetPartitionKeys(deref(expr.get()))) + for name_val in 
GetResultValue(CGetPartitionKeys(expr)) } diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index a81042920d5..53e48da1703 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -36,63 +36,19 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CExpression "arrow::dataset::Expression": c_bool Equals(const CExpression& other) const - c_bool Equals(const shared_ptr[CExpression]& other) const - CResult[shared_ptr[CDataType]] Validate(const CSchema& schema) const - shared_ptr[CExpression] Assume(const CExpression& given) const - shared_ptr[CExpression] Assume( - const shared_ptr[CExpression]& given) const c_string ToString() const - shared_ptr[CExpression] Copy() const + CResult[CExpression] Bind(const CSchema&) - const CExpression& In(shared_ptr[CArray]) const - const CExpression& IsValid() const - const CExpression& CastTo(shared_ptr[CDataType], CCastOptions) const - const CExpression& CastLike(shared_ptr[CExpression], - CCastOptions) const + cdef CExpression CMakeScalarExpression \ + "arrow::dataset::literal"(shared_ptr[CScalar] value) - @staticmethod - CResult[shared_ptr[CExpression]] Deserialize(const CBuffer& buffer) - CResult[shared_ptr[CBuffer]] Serialize() const - - ctypedef vector[shared_ptr[CExpression]] CExpressionVector \ - "arrow::dataset::ExpressionVector" - - cdef cppclass CScalarExpression \ - "arrow::dataset::ScalarExpression"(CExpression): - CScalarExpression(const shared_ptr[CScalar]& value) - - cdef shared_ptr[CExpression] CMakeFieldExpression \ + cdef CExpression CMakeFieldExpression \ "arrow::dataset::field_ref"(c_string name) - cdef shared_ptr[CExpression] CMakeNotExpression \ - "arrow::dataset::not_"(shared_ptr[CExpression] operand) - cdef shared_ptr[CExpression] CMakeAndExpression \ - "arrow::dataset::and_"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] 
CMakeOrExpression \ - "arrow::dataset::or_"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeEqualExpression \ - "arrow::dataset::equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeNotEqualExpression \ - "arrow::dataset::not_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeGreaterExpression \ - "arrow::dataset::greater"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeGreaterEqualExpression \ - "arrow::dataset::greater_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeLessExpression \ - "arrow::dataset::less"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeLessEqualExpression \ - "arrow::dataset::less_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - - cdef CResult[shared_ptr[CExpression]] CInsertImplicitCasts \ - "arrow::dataset::InsertImplicitCasts"( - const CExpression &, const CSchema&) + + cdef CExpression CMakeCallExpression \ + "arrow::dataset::call"(c_string function, + vector[CExpression] arguments, + shared_ptr[CFunctionOptions] options) cdef cppclass CRecordBatchProjector "arrow::dataset::RecordBatchProjector": pass @@ -119,7 +75,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: shared_ptr[CScanOptions] options, shared_ptr[CScanContext] context) c_bool splittable() const c_string type_name() const - const shared_ptr[CExpression]& partition_expression() const + const CExpression& partition_expression() const ctypedef vector[shared_ptr[CFragment]] CFragmentVector \ "arrow::dataset::FragmentVector" @@ -130,7 +86,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CInMemoryFragment "arrow::dataset::InMemoryFragment"( CFragment): CInMemoryFragment(vector[shared_ptr[CRecordBatch]] record_batches, - shared_ptr[CExpression] 
partition_expression) + CExpression partition_expression) cdef cppclass CScanner "arrow::dataset::Scanner": CScanner(shared_ptr[CDataset], shared_ptr[CScanOptions], @@ -148,8 +104,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment], shared_ptr[CScanContext] scan_context) CStatus Project(const vector[c_string]& columns) - CStatus Filter(const CExpression& filter) - CStatus Filter(shared_ptr[CExpression] filter) + CStatus Filter(CExpression filter) CStatus UseThreads(c_bool use_threads) CStatus BatchSize(int64_t batch_size) CResult[shared_ptr[CScanner]] Finish() @@ -161,8 +116,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CDataset "arrow::dataset::Dataset": const shared_ptr[CSchema] & schema() CFragmentIterator GetFragments() - CFragmentIterator GetFragments(shared_ptr[CExpression] predicate) - const shared_ptr[CExpression] & partition_expression() + CFragmentIterator GetFragments(CExpression predicate) + const CExpression & partition_expression() c_string type_name() CResult[shared_ptr[CDataset]] ReplaceSchema(shared_ptr[CSchema]) @@ -193,8 +148,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[shared_ptr[CDataset]] FinishWithSchema "Finish"( const shared_ptr[CSchema]& schema) CResult[shared_ptr[CDataset]] Finish() - const shared_ptr[CExpression]& root_partition() - CStatus SetRootPartition(shared_ptr[CExpression] partition) + const CExpression& root_partition() + CStatus SetRootPartition(CExpression partition) cdef cppclass CUnionDatasetFactory "arrow::dataset::UnionDatasetFactory": @staticmethod @@ -221,7 +176,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[shared_ptr[CSchema]] Inspect(const CFileSource&) const CResult[shared_ptr[CFileFragment]] MakeFragment( CFileSource source, - shared_ptr[CExpression] partition_expression, + CExpression partition_expression, 
shared_ptr[CSchema] physical_schema) shared_ptr[CFileWriteOptions] DefaultWriteOptions() @@ -240,9 +195,9 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: const vector[int]& row_groups() const shared_ptr[CFileMetaData] metadata() const CResult[vector[shared_ptr[CFragment]]] SplitByRowGroup( - shared_ptr[CExpression] predicate) + CExpression predicate) CResult[shared_ptr[CFragment]] SubsetWithFilter "Subset"( - shared_ptr[CExpression] predicate) + CExpression predicate) CResult[shared_ptr[CFragment]] SubsetWithIds "Subset"( vector[int] row_group_ids) CStatus EnsureCompleteMetadata() @@ -260,7 +215,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: @staticmethod CResult[shared_ptr[CDataset]] Make( shared_ptr[CSchema] schema, - shared_ptr[CExpression] source_partition, + CExpression source_partition, shared_ptr[CFileFormat] format, shared_ptr[CFileSystem] filesystem, vector[shared_ptr[CFileFragment]] fragments) @@ -287,7 +242,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CParquetFileFormatReaderOptions reader_options CResult[shared_ptr[CFileFragment]] MakeFragment( CFileSource source, - shared_ptr[CExpression] partition_expression, + CExpression partition_expression, shared_ptr[CSchema] physical_schema, vector[int] row_groups) @@ -305,7 +260,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CPartitioning "arrow::dataset::Partitioning": c_string type_name() const - CResult[shared_ptr[CExpression]] Parse(const c_string & path) const + CResult[CExpression] Parse(const c_string & path) const const shared_ptr[CSchema] & schema() cdef cppclass CPartitioningFactoryOptions \ From d9ec4bfae9470b659d7c419ee6a9e8034d0926a4 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 15 Dec 2020 13:25:24 -0800 Subject: [PATCH 11/31] Refactor dataset expression code to follow array expression pattern --- r/R/dplyr.R | 2 +- r/R/expression.R | 99 
++++++++++++++---------------------------------- 2 files changed, 29 insertions(+), 72 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index f3dd078c952..b4b6fc4dab1 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -248,7 +248,7 @@ filter_mask <- function(.data) { if (query_on_dataset(.data)) { comp_func <- function(operator) { force(operator) - function(e1, e2) make_expression(operator, e1, e2) + function(e1, e2) build_dataset_expression(operator, e1, e2) } var_binder <- function(x) Expression$field_ref(x) } else { diff --git a/r/R/expression.R b/r/R/expression.R index 8c7e4bca59f..f601730f905 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -173,6 +173,31 @@ Expression <- R6Class("Expression", inherit = ArrowObject, ToString = function() dataset___expr__ToString(self) ) ) +Expression$create <- function(name, arguments, options = list()) { + dataset___expr__call(name, arguments, options) +} + +build_dataset_expression <- function(.Generic, e1, e2, ...) { + if (.Generic %in% names(.unary_function_map)) { + expr <- Expression$create(.unary_function_map[[.Generic]], list(e1)) + } else if (.Generic == "%in%") { + # Special-case %in%, which is different from the Array function name + expr <- Expression$create("is_in", list(e1), + options = list( + value_set = Array$create(e2), + skip_nulls = TRUE) + ) + } else { + if (!inherits(e1, "Expression")) { + e1 <- Expression$scalar(e1) + } + if (!inherits(e2, "Expression")) { + e2 <- Expression$scalar(e2) + } + expr <- Expression$create(.binary_function_map[[.Generic]], list(e1, e2), ...) 
+ } + expr +} Expression$field_ref <- function(name) { assert_is(name, "character") @@ -182,83 +207,15 @@ Expression$field_ref <- function(name) { Expression$scalar <- function(x) { dataset___expr__scalar(Scalar$create(x)) } -Expression$call <- function(name, arguments, options = list()) { - dataset___expr__call(name, arguments, options) -} -Expression$compare <- function(OP, e1, e2) { - comp_func <- comparison_function_map[[OP]] - if (is.null(comp_func)) { - stop(OP, " is not a supported comparison function", call. = FALSE) - } - Expression$call(comp_func, list(e1, e2)) -} - -comparison_function_map <- list( - "==" = "equal", - "!=" = "not_equal", - ">" = "greater", - ">=" = "greater_equal", - "<" = "less", - "<=" = "less_equal" -) -Expression$in_ <- function(x, set) { - Expression$call("is_in", list(x), - options = list(value_set = Array$create(set), - skip_nulls = TRUE)) -} -Expression$and <- function(e1, e2) { - Expression$call("and_kleene", list(e1, e2)) -} -Expression$or <- function(e1, e2) { - Expression$call("or_kleene", list(e1, e2)) -} -Expression$not <- function(e1) { - Expression$call("invert", list(e1)) -} -Expression$is_valid <- function(e1) { - Expression$call("is_valid", list(e1)) -} -Expression$is_null <- function(e1) { - Expression$call("is_null", list(e1)) -} #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { - return(Expression$not(e1)) - } - make_expression(.Generic, e1, e2) -} - -make_expression <- function(operator, e1, e2) { - if (operator == "%in%") { - # In doesn't take Scalar, it takes Array - return(Expression$in_(e1, e2)) - } - - # Handle unary functions before touching e2 - if (operator == "is.na") { - return(is.na(e1)) - } - if (operator == "!") { - return(Expression$not(e1)) - } - - # Check for non-expressions and convert to Expressions - if (!inherits(e1, "Expression")) { - e1 <- Expression$scalar(e1) - } - if (!inherits(e2, "Expression")) { - e2 <- Expression$scalar(e2) - } - if (operator == "&") { - 
Expression$and(e1, e2) - } else if (operator == "|") { - Expression$or(e1, e2) + build_dataset_expression(.Generic, e1) } else { - Expression$compare(operator, e1, e2) + build_dataset_expression(.Generic, e1, e2) } } #' @export -is.na.Expression <- function(x) Expression$is_null(x) +is.na.Expression <- function(x) Expression$create("is_null", list(x)) From 35b9e1fadd82bc2163b3579ebefd9c20437795fe Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 15 Dec 2020 13:54:33 -0800 Subject: [PATCH 12/31] A little more --- r/R/expression.R | 49 +++++++++++++++++++-------------------------- r/man/Expression.Rd | 14 ++----------- 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index f601730f905..f9e09c2fadd 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -153,18 +153,8 @@ print.array_expression <- function(x, ...) { #' `Expression$field_ref(name)` is used to construct an `Expression` which #' evaluates to the named column in the `Dataset` against which it is evaluated. #' -#' `Expression$compare(OP, e1, e2)` takes two `Expression` operands, constructing -#' an `Expression` which will evaluate these operands then compare them with the -#' relation specified by OP (e.g. "==", "!=", ">", etc.) For example, to filter -#' down to rows where the column named "alpha" is less than 5: -#' `Expression$compare("<", Expression$field_ref("alpha"), Expression$scalar(5))` -#' -#' `Expression$and(e1, e2)`, `Expression$or(e1, e2)`, and `Expression$not(e1)` -#' construct an `Expression` combining their arguments with Boolean operators. -#' -#' `Expression$is_valid(x)` is essentially (an inversion of) `is.na()` for `Expression`s. -#' -#' `Expression$in_(x, set)` evaluates x and returns whether or not it is a member of the set. +#' `Expression$create(function_name, ..., options)` builds a function-call +#' `Expression` containing one or more `Expression`s. 
#' @name Expression #' @rdname Expression #' @export @@ -173,20 +163,32 @@ Expression <- R6Class("Expression", inherit = ArrowObject, ToString = function() dataset___expr__ToString(self) ) ) -Expression$create <- function(name, arguments, options = list()) { - dataset___expr__call(name, arguments, options) +Expression$create <- function(function_name, + ..., + args = list(...), + options = empty_named_list()) { + assert_that(is.string(function_name)) + dataset___expr__call(function_name, args, options) +} +Expression$field_ref <- function(name) { + assert_that(is.string(name)) + dataset___expr__field_ref(name) +} +Expression$scalar <- function(x) { + dataset___expr__scalar(Scalar$create(x)) } build_dataset_expression <- function(.Generic, e1, e2, ...) { if (.Generic %in% names(.unary_function_map)) { - expr <- Expression$create(.unary_function_map[[.Generic]], list(e1)) + expr <- Expression$create(.unary_function_map[[.Generic]], e1) } else if (.Generic == "%in%") { # Special-case %in%, which is different from the Array function name - expr <- Expression$create("is_in", list(e1), + expr <- Expression$create("is_in", e1, options = list( value_set = Array$create(e2), - skip_nulls = TRUE) + skip_nulls = TRUE ) + ) } else { if (!inherits(e1, "Expression")) { e1 <- Expression$scalar(e1) @@ -194,20 +196,11 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) { if (!inherits(e2, "Expression")) { e2 <- Expression$scalar(e2) } - expr <- Expression$create(.binary_function_map[[.Generic]], list(e1, e2), ...) + expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...) 
} expr } -Expression$field_ref <- function(name) { - assert_is(name, "character") - assert_that(length(name) == 1) - dataset___expr__field_ref(name) -} -Expression$scalar <- function(x) { - dataset___expr__scalar(Scalar$create(x)) -} - #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { @@ -218,4 +211,4 @@ Ops.Expression <- function(e1, e2) { } #' @export -is.na.Expression <- function(x) Expression$create("is_null", list(x)) +is.na.Expression <- function(x) Expression$create("is_null", x) diff --git a/r/man/Expression.Rd b/r/man/Expression.Rd index 14570eba892..58a6a44c0c0 100644 --- a/r/man/Expression.Rd +++ b/r/man/Expression.Rd @@ -13,16 +13,6 @@ the provided scalar (length-1) R value. \code{Expression$field_ref(name)} is used to construct an \code{Expression} which evaluates to the named column in the \code{Dataset} against which it is evaluated. -\code{Expression$compare(OP, e1, e2)} takes two \code{Expression} operands, constructing -an \code{Expression} which will evaluate these operands then compare them with the -relation specified by OP (e.g. "==", "!=", ">", etc.) For example, to filter -down to rows where the column named "alpha" is less than 5: -\code{Expression$compare("<", Expression$field_ref("alpha"), Expression$scalar(5))} - -\code{Expression$and(e1, e2)}, \code{Expression$or(e1, e2)}, and \code{Expression$not(e1)} -construct an \code{Expression} combining their arguments with Boolean operators. - -\code{Expression$is_valid(x)} is essentially (an inversion of) \code{is.na()} for \code{Expression}s. - -\code{Expression$in_(x, set)} evaluates x and returns whether or not it is a member of the set. +\code{Expression$create(function_name, ..., options)} builds a function-call +\code{Expression} containing one or more \code{Expression}s. 
} From ec6ee9f51209bae4012ca1da19dd0e70dd168788 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 15 Dec 2020 16:57:45 -0500 Subject: [PATCH 13/31] revamp Expression::ToString, make cast-from-string error more informative --- .../compute/kernels/scalar_cast_numeric.cc | 4 +- .../compute/kernels/scalar_cast_temporal.cc | 3 +- .../arrow/compute/kernels/scalar_string.cc | 4 +- cpp/src/arrow/dataset/expression.cc | 93 ++++++++++++------- cpp/src/arrow/dataset/expression_internal.h | 22 +++++ cpp/src/arrow/dataset/expression_test.cc | 37 ++++++-- r/tests/testthat/test-dataset.R | 4 +- r/tests/testthat/test-expression.R | 2 +- 8 files changed, 120 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 62665d4ea44..6e550fb12c0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -282,7 +282,9 @@ struct ParseString { OutValue Call(KernelContext* ctx, Arg0Value val) const { OutValue result = OutValue(0); if (ARROW_PREDICT_FALSE(!ParseValue(val.data(), val.size(), &result))) { - ctx->SetStatus(Status::Invalid("Failed to parse string: ", val)); + ctx->SetStatus(Status::Invalid("Failed to parse string: '", val, + "' as a scalar of type ", + TypeTraits::type_singleton()->ToString())); } return result; } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index 7cff41f267e..f1702b7fd51 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -298,7 +298,8 @@ struct ParseTimestamp { OutValue Call(KernelContext* ctx, Arg0Value val) const { OutValue result = 0; if (ARROW_PREDICT_FALSE(!ParseValue(type, val.data(), val.size(), &result))) { - ctx->SetStatus(Status::Invalid("Failed to parse string: ", val)); + ctx->SetStatus(Status::Invalid("Failed to parse 
string: '", val, + "' as a scalar of type ", type.ToString())); } return result; } diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index eff60b80481..2f4a0d45ed2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1207,7 +1207,9 @@ struct ParseStrptime { int64_t Call(KernelContext* ctx, util::string_view val) const { int64_t result = 0; if (!(*parser)(val.data(), val.size(), unit, &result)) { - ctx->SetStatus(Status::Invalid("Failed to parse string ", val)); + ctx->SetStatus(Status::Invalid("Failed to parse string: '", val, + "' as a scalar of type ", + TimestampType(unit).ToString())); } return result; } diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index cc142cbb2d4..f4600196e78 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -52,6 +52,21 @@ const FieldRef* Expression::field_ref() const { std::string Expression::ToString() const { if (auto lit = literal()) { if (lit->is_scalar()) { + switch (lit->type()->id()) { + case Type::STRING: + case Type::LARGE_STRING: + return '"' + + Escape(util::string_view(*lit->scalar_as().value)) + + '"'; + + case Type::BINARY: + case Type::FIXED_SIZE_BINARY: + case Type::LARGE_BINARY: + return '"' + lit->scalar_as().value->ToHexString() + '"'; + + default: + break; + } return lit->scalar()->ToString(); } return lit->ToString(); @@ -59,7 +74,7 @@ std::string Expression::ToString() const { if (auto ref = field_ref()) { if (auto name = ref->name()) { - return "FieldRef(" + *name + ")"; + return *name; } if (auto path = ref->field_path()) { return path->ToString(); @@ -68,54 +83,62 @@ std::string Expression::ToString() const { } auto call = CallNotNull(*this); + if (auto cmp = Comparison::Get(call->function)) { + return "(" + call->arguments[0].ToString() + " " + Comparison::GetOp(*cmp) + " " + + call->arguments[1].ToString() + 
")"; + } + + if (auto options = GetStructOptions(*call)) { + std::string out = "{"; + auto argument = call->arguments.begin(); + for (const auto& field_name : options->field_names) { + out += field_name + "=" + argument++->ToString() + ", "; + } + out.resize(out.size() - 1); + out.back() = '}'; + return out; + } - // FIXME represent FunctionOptions std::string out = call->function + "("; for (const auto& arg : call->arguments) { - out += arg.ToString() + ","; + out += arg.ToString() + ", "; } - if (!call->options) { + if (call->options == nullptr) { + out.resize(out.size() - 1); out.back() = ')'; return out; } - if (call->options) { - out += " {"; - - if (auto options = GetSetLookupOptions(*call)) { - DCHECK_EQ(options->value_set.kind(), Datum::ARRAY); - out += "value_set:" + options->value_set.make_array()->ToString(); - if (options->skip_nulls) { - out += ",skip_nulls"; - } - } - - if (auto options = GetCastOptions(*call)) { - out += "to_type:" + options->to_type->ToString(); - if (options->allow_int_overflow) out += ",allow_int_overflow"; - if (options->allow_time_truncate) out += ",allow_time_truncate"; - if (options->allow_time_overflow) out += ",allow_time_overflow"; - if (options->allow_decimal_truncate) out += ",allow_decimal_truncate"; - if (options->allow_float_truncate) out += ",allow_float_truncate"; - if (options->allow_invalid_utf8) out += ",allow_invalid_utf8"; - } - - if (auto options = GetStructOptions(*call)) { - for (const auto& field_name : options->field_names) { - out += field_name + ","; - } - out.resize(out.size() - 1); + if (auto options = GetSetLookupOptions(*call)) { + DCHECK_EQ(options->value_set.kind(), Datum::ARRAY); + out += "value_set=" + options->value_set.make_array()->ToString(); + if (options->skip_nulls) { + out += ", skip_nulls"; } + return out + ")"; + } - if (auto options = GetStrptimeOptions(*call)) { - out += "format:" + options->format; - out += ",unit:" + internal::ToString(options->unit); + if (auto options = 
GetCastOptions(*call)) { + if (options->to_type == nullptr) { + return out + "to_type=)"; } + out += "to_type=" + options->to_type->ToString(); + if (options->allow_int_overflow) out += ", allow_int_overflow"; + if (options->allow_time_truncate) out += ", allow_time_truncate"; + if (options->allow_time_overflow) out += ", allow_time_overflow"; + if (options->allow_decimal_truncate) out += ", allow_decimal_truncate"; + if (options->allow_float_truncate) out += ", allow_float_truncate"; + if (options->allow_invalid_utf8) out += ", allow_invalid_utf8"; + return out + ")"; + } - out += "})"; + if (auto options = GetStrptimeOptions(*call)) { + return out + "format=" + options->format + + ", unit=" + internal::ToString(options->unit) + ")"; } - return out; + + return out + "{NON-REPRESENTABLE OPTIONS})"; } void PrintTo(const Expression& expr, std::ostream* os) { diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 306f219859a..c91bbb4b62f 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -205,6 +205,28 @@ struct Comparison { DCHECK(false); return "na"; } + + static std::string GetOp(type op) { + switch (op) { + case NA: + DCHECK(false) << "unreachable"; + break; + case EQUAL: + return "=="; + case LESS: + return "<"; + case GREATER: + return ">"; + case NOT_EQUAL: + return "!="; + case LESS_EQUAL: + return "<="; + case GREATER_EQUAL: + return ">="; + } + DCHECK(false); + return ""; + } }; inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) { diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index e5301677ed7..738fab72237 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -48,23 +48,45 @@ Expression cast(Expression argument, std::shared_ptr to_type) { } TEST(Expression, ToString) { - EXPECT_EQ(field_ref("alpha").ToString(), 
"FieldRef(alpha)"); + EXPECT_EQ(field_ref("alpha").ToString(), "alpha"); EXPECT_EQ(literal(3).ToString(), "3"); + EXPECT_EQ(literal("a").ToString(), "\"a\""); + EXPECT_EQ(literal("a\nb").ToString(), "\"a\\nb\""); + EXPECT_EQ(literal(std::make_shared()).ToString(), "null"); + EXPECT_EQ(literal(std::make_shared(Buffer::FromString("az"))).ToString(), + "\"617A\""); auto ts = *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO)); EXPECT_EQ(literal(ts).ToString(), "656677413000000000"); - EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), - "add(3,FieldRef(beta))"); + EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), "add(3, beta)"); auto in_12 = call("index_in", {field_ref("beta")}, compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")}); - EXPECT_EQ(in_12.ToString(), "index_in(FieldRef(beta), {value_set:[\n 1,\n 2\n]})"); + EXPECT_EQ(in_12.ToString(), "index_in(beta, value_set=[\n 1,\n 2\n])"); - EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), - "cast(FieldRef(a), {to_type:int32})"); + EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), "cast(a, to_type=int32)"); + EXPECT_EQ(cast(field_ref("a"), nullptr).ToString(), + "cast(a, to_type=)"); + + struct WidgetifyOptions : compute::FunctionOptions { + bool really; + }; + + // NB: corrupted for nullary functions but we don't have any of those + EXPECT_EQ(call("widgetify", {}).ToString(), "widgetif)"); + EXPECT_EQ( + call("widgetify", {literal(1)}, std::make_shared()).ToString(), + "widgetify(1, {NON-REPRESENTABLE OPTIONS})"); + + EXPECT_EQ(equal(field_ref("a"), literal(1)).ToString(), "(a == 1)"); + EXPECT_EQ(less(field_ref("a"), literal(2)).ToString(), "(a < 2)"); + EXPECT_EQ(greater(field_ref("a"), literal(3)).ToString(), "(a > 3)"); + EXPECT_EQ(not_equal(field_ref("a"), literal("a")).ToString(), "(a != \"a\")"); + EXPECT_EQ(less_equal(field_ref("a"), literal("b")).ToString(), "(a <= \"b\")"); + EXPECT_EQ(greater_equal(field_ref("a"), literal("c")).ToString(), "(a 
>= \"c\")"); EXPECT_EQ(project( { @@ -80,8 +102,7 @@ TEST(Expression, ToString) { "b", }) .ToString(), - "struct(FieldRef(a),FieldRef(a),3," + in_12.ToString() + - ", {a,renamed_a,three,b})"); + "{a=a, renamed_a=a, three=3, b=" + in_12.ToString() + "}"); } TEST(Expression, Equality) { diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 73d654eb5a1..a2858115622 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -499,7 +499,7 @@ test_that("filter scalar validation doesn't crash (ARROW-7772)", { ds %>% filter(int == "fff", part == 1) %>% collect(), - "error parsing 'fff' as scalar of type int32" + "Failed to parse string: 'fff' as a scalar of type int32" ) }) @@ -654,7 +654,7 @@ test_that("Dataset and query print methods", { "lgl: bool", "integer: int32", "", - "* Filter: (int == 6:double)", + "* Filter: (int == 6)", "* Grouped by lgl", "See $.data for the source Arrow object", sep = "\n" diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 1bf08595758..0c5ef4c12da 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -60,7 +60,7 @@ test_that("C++ expressions", { expect_is(!(f < 4), "Expression") expect_output( print(f > 4), - 'Expression\n(f > 4:double)', + 'Expression\n(f > 4)', fixed = TRUE ) # Interprets that as a list type From 8bb9b90467cafa38d01069c676a1e9db487afa76 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 15 Dec 2020 17:13:17 -0500 Subject: [PATCH 14/31] print and_kleene, or_kleene as binary ops --- cpp/src/arrow/dataset/expression.cc | 16 ++++++++++++++-- cpp/src/arrow/dataset/expression_test.cc | 4 ++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index f4600196e78..693f4b67d58 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -83,9 +83,21 @@ std::string 
Expression::ToString() const { } auto call = CallNotNull(*this); - if (auto cmp = Comparison::Get(call->function)) { - return "(" + call->arguments[0].ToString() + " " + Comparison::GetOp(*cmp) + " " + + auto binary = [&](std::string op) { + return "(" + call->arguments[0].ToString() + " " + op + " " + call->arguments[1].ToString() + ")"; + }; + + if (auto cmp = Comparison::Get(call->function)) { + return binary(Comparison::GetOp(*cmp)); + } + + if (call->function == "and_kleene") { + return binary("and"); + } + + if (call->function == "or_kleene") { + return binary("or"); } if (auto options = GetStructOptions(*call)) { diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 738fab72237..709e19791a8 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -67,6 +67,10 @@ TEST(Expression, ToString) { EXPECT_EQ(in_12.ToString(), "index_in(beta, value_set=[\n 1,\n 2\n])"); + EXPECT_EQ(and_(field_ref("a"), field_ref("b")).ToString(), "(a and b)"); + EXPECT_EQ(or_(field_ref("a"), field_ref("b")).ToString(), "(a or b)"); + EXPECT_EQ(not_(field_ref("a")).ToString(), "invert(a)"); + EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), "cast(a, to_type=int32)"); EXPECT_EQ(cast(field_ref("a"), nullptr).ToString(), "cast(a, to_type=)"); From 7c88e351cb61f082720dcf707cb3ac80cef067e7 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 16 Dec 2020 14:22:42 -0500 Subject: [PATCH 15/31] clean up Expression class, extract Parameter --- cpp/src/arrow/dataset/expression.cc | 159 ++++++++++++-------- cpp/src/arrow/dataset/expression.h | 62 ++++---- cpp/src/arrow/dataset/expression_internal.h | 29 ++-- cpp/src/arrow/dataset/expression_test.cc | 36 ++--- cpp/src/arrow/dataset/partition_test.cc | 2 +- 5 files changed, 155 insertions(+), 133 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 693f4b67d58..4c5b693ffb9 100644 --- 
a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -39,14 +39,54 @@ using internal::checked_pointer_cast; namespace dataset { -const Expression::Call* Expression::call() const { - return util::get_if(impl_.get()); +Expression::Expression(Call call) : impl_(std::make_shared(std::move(call))) {} + +Expression::Expression(Datum literal) + : impl_(std::make_shared(std::move(literal))) {} + +Expression::Expression(Parameter parameter) + : impl_(std::make_shared(std::move(parameter))) {} + +Expression literal(Datum lit) { return Expression(std::move(lit)); } + +Expression field_ref(FieldRef ref) { + return Expression(Expression::Parameter{std::move(ref), {}}); +} + +Expression call(std::string function, std::vector arguments, + std::shared_ptr options) { + Expression::Call call; + call.function_name = std::move(function); + call.arguments = std::move(arguments); + call.options = std::move(options); + return Expression(std::move(call)); } const Datum* Expression::literal() const { return util::get_if(impl_.get()); } const FieldRef* Expression::field_ref() const { - return util::get_if(impl_.get()); + if (auto parameter = util::get_if(impl_.get())) { + return ¶meter->ref; + } + return nullptr; +} + +const Expression::Call* Expression::call() const { + return util::get_if(impl_.get()); +} + +ValueDescr Expression::descr() const { + if (impl_ == nullptr) return {}; + + if (auto lit = literal()) { + return lit->descr(); + } + + if (auto parameter = util::get_if(impl_.get())) { + return parameter->descr; + } + + return CallNotNull(*this)->descr; } std::string Expression::ToString() const { @@ -88,16 +128,14 @@ std::string Expression::ToString() const { call->arguments[1].ToString() + ")"; }; - if (auto cmp = Comparison::Get(call->function)) { + if (auto cmp = Comparison::Get(call->function_name)) { return binary(Comparison::GetOp(*cmp)); } - if (call->function == "and_kleene") { - return binary("and"); - } - - if (call->function == "or_kleene") 
{ - return binary("or"); + constexpr util::string_view kleene = "_kleene"; + if (util::string_view{call->function_name}.ends_with(kleene)) { + auto op = call->function_name.substr(0, call->function_name.size() - kleene.size()); + return binary(std::move(op)); } if (auto options = GetStructOptions(*call)) { @@ -111,7 +149,7 @@ std::string Expression::ToString() const { return out; } - std::string out = call->function + "("; + std::string out = call->function_name + "("; for (const auto& arg : call->arguments) { out += arg.ToString() + ", "; } @@ -178,7 +216,8 @@ bool Expression::Equals(const Expression& other) const { auto call = CallNotNull(*this); auto other_call = CallNotNull(other); - if (call->function != other_call->function || call->kernel != other_call->kernel) { + if (call->function_name != other_call->function_name || + call->kernel != other_call->kernel) { return false; } @@ -223,7 +262,7 @@ bool Expression::Equals(const Expression& other) const { } ARROW_LOG(WARNING) << "comparing unknown FunctionOptions for function " - << call->function; + << call->function_name; return false; } @@ -241,7 +280,7 @@ size_t Expression::hash() const { auto call = CallNotNull(*this); - size_t out = std::hash{}(call->function); + size_t out = std::hash{}(call->function_name); for (const auto& arg : call->arguments) { out ^= arg.hash(); } @@ -249,7 +288,7 @@ size_t Expression::hash() const { } bool Expression::IsBound() const { - if (descr_.type == nullptr) return false; + if (descr().type == nullptr) return false; if (auto lit = literal()) return true; @@ -278,20 +317,20 @@ bool Expression::IsScalarExpression() const { if (!arg.IsScalarExpression()) return false; } - if (call->kernel == nullptr) { - // this expression is not bound; make a best guess based on - // the default function registry - if (auto function = compute::GetFunctionRegistry() - ->GetFunction(call->function) - .ValueOr(nullptr)) { - return function->kind() == compute::Function::SCALAR; - } + if 
(call->function) { + return call->function->kind() == compute::Function::SCALAR; + } - // unknown function or other error; conservatively return false - return false; + // this expression is not bound; make a best guess based on + // the default function registry + if (auto function = compute::GetFunctionRegistry() + ->GetFunction(call->function_name) + .ValueOr(nullptr)) { + return function->kind() == compute::Function::SCALAR; } - return call->function_kind == compute::Function::SCALAR; + // unknown function or other error; conservatively return false + return false; } bool Expression::IsNullLiteral() const { @@ -305,7 +344,7 @@ bool Expression::IsNullLiteral() const { } bool Expression::IsSatisfiable() const { - if (descr_.type && descr_.type->id() == Type::NA) { + if (descr().type && descr().type->id() == Type::NA) { return false; } @@ -369,9 +408,9 @@ Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { auto call_with_cast = *CallNotNull(with_cast); call_with_cast.arguments[0] = std::move(*expr); - *expr = Expression(std::make_shared(std::move(call_with_cast)), - ValueDescr{std::move(to_type), expr->descr().shape}); + call_with_cast.descr = ValueDescr{std::move(to_type), expr->descr().shape}; + *expr = Expression(std::move(call_with_cast)); return Status::OK(); } @@ -379,7 +418,7 @@ Status InsertImplicitCasts(Expression::Call* call) { DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - if (IsSameTypesBinary(call->function)) { + if (IsSameTypesBinary(call->function_name)) { for (auto&& argument : call->arguments) { if (auto value_type = GetDictionaryValueType(argument.descr().type)) { RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); @@ -389,9 +428,10 @@ Status InsertImplicitCasts(Expression::Call* call) { if (call->arguments[0].descr().shape == ValueDescr::SCALAR) { // argument 0 is scalar so casting is cheap return 
MaybeInsertCast(call->arguments[1].descr().type, &call->arguments[0]); - } else { - return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); } + + // cast argument 1 unconditionally + return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); } if (auto options = GetSetLookupOptions(*call)) { @@ -435,15 +475,13 @@ Result Expression::Bind(ValueDescr in, if (auto ref = field_ref()) { ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); - auto out = *this; - out.descr_ = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return out; + auto descr = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); + return Expression{Parameter{*ref, std::move(descr)}}; } auto bound_call = *CallNotNull(*this); - ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(bound_call, exec_context)); - bound_call.function_kind = function->kind(); + ARROW_ASSIGN_OR_RAISE(bound_call.function, GetFunction(bound_call, exec_context)); for (auto&& argument : bound_call.arguments) { ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); @@ -451,17 +489,18 @@ Result Expression::Bind(ValueDescr in, RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); auto descrs = GetDescriptors(bound_call.arguments); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel, bound_call.function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, InitKernelState(bound_call, exec_context)); kernel_context.SetState(bound_call.kernel_state.get()); - ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( - &kernel_context, descrs)); + ARROW_ASSIGN_OR_RAISE( + bound_call.descr, + bound_call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); - return Expression(std::make_shared(std::move(bound_call)), std::move(descr)); + return 
Expression(std::move(bound_call)); } Result Expression::Bind(const Schema& in_schema, @@ -547,7 +586,7 @@ util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { util::optional GetNullHandling( const Expression::Call& call) { - if (call.function_kind == compute::Function::SCALAR) { + if (call.function && call.function->kind() == compute::Function::SCALAR) { return static_cast(call.kernel)->null_handling; } return util::nullopt; @@ -619,7 +658,7 @@ Result FoldConstants(Expression expr) { } } - if (call->function == "and_kleene") { + if (call->function_name == "and_kleene") { for (auto args : ArgumentsAndFlippedArguments(*call)) { // true and x == x if (args.first == literal(true)) return args.second; @@ -633,7 +672,7 @@ Result FoldConstants(Expression expr) { return expr; } - if (call->function == "or_kleene") { + if (call->function_name == "or_kleene") { for (auto args : ArgumentsAndFlippedArguments(*call)) { // false or x == x if (args.first == literal(false)) return args.second; @@ -654,7 +693,7 @@ Result FoldConstants(Expression expr) { inline std::vector GuaranteeConjunctionMembers( const Expression& guaranteed_true_predicate) { auto guarantee = guaranteed_true_predicate.call(); - if (!guarantee || guarantee->function != "and_kleene") { + if (!guarantee || guarantee->function_name != "and_kleene") { return {guaranteed_true_predicate}; } return FlattenedAssociativeChain(guaranteed_true_predicate).fringe; @@ -672,7 +711,7 @@ Status ExtractKnownFieldValuesImpl( auto call = expr.call(); if (!call) return true; - if (call->function == "equal") { + if (call->function_name == "equal") { auto ref = call->arguments[0].field_ref(); auto lit = call->arguments[1].literal(); return !(ref && lit); @@ -741,7 +780,7 @@ inline bool IsBinaryAssociativeCommutative(const Expression::Call& call) { "and", "or", "and_kleene", "or_kleene", "xor", "multiply", "add", "multiply_checked", "add_checked"}; - auto it = binary_associative_commutative.find(call.function); + auto it = 
binary_associative_commutative.find(call.function_name); return it != binary_associative_commutative.end(); } @@ -798,37 +837,35 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering); // fold the chain back up - const auto& descr = expr.descr(); auto folded = FoldLeft(chain.fringe.begin(), chain.fringe.end(), - [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { - auto ret = *call; - ret.arguments = {std::move(l), std::move(r)}; - Expression expr( - std::make_shared(std::move(ret)), descr); + [call, &AlreadyCanonicalized](Expression l, Expression r) { + auto canonicalized_call = *call; + canonicalized_call.arguments = {std::move(l), std::move(r)}; + Expression expr(std::move(canonicalized_call)); AlreadyCanonicalized.Add({expr}); return expr; }); return std::move(*folded); } - if (auto cmp = Comparison::Get(call->function)) { + if (auto cmp = Comparison::Get(call->function_name)) { if (call->arguments[0].literal() && !call->arguments[1].literal()) { // ensure that literals are on comparisons' RHS auto flipped_call = *call; - flipped_call.function = Comparison::GetName(Comparison::GetFlipped(*cmp)); + flipped_call.function_name = + Comparison::GetName(Comparison::GetFlipped(*cmp)); // look up the flipped kernel // TODO extract a helper for use here and in Bind ARROW_ASSIGN_OR_RAISE( auto function, - exec_context->func_registry()->GetFunction(flipped_call.function)); + exec_context->func_registry()->GetFunction(flipped_call.function_name)); auto descrs = GetDescriptors(flipped_call.arguments); ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); - return Expression(std::make_shared(std::move(flipped_call)), - expr.descr()); + return Expression(std::move(flipped_call)); } } @@ -847,7 +884,7 @@ Result DirectComparisonSimplification(Expression expr, // Ensure both calls 
are comparisons with equal LHS and scalar RHS auto cmp = Comparison::Get(expr); - auto cmp_guarantee = Comparison::Get(guarantee.function); + auto cmp_guarantee = Comparison::Get(guarantee.function_name); if (!cmp || !cmp_guarantee) return expr; if (call->arguments[0] != guarantee.arguments[0]) return expr; @@ -961,7 +998,7 @@ Result> Serialize(const Expression& expr) { } auto call = CallNotNull(expr); - metadata_->Append("call", call->function); + metadata_->Append("call", call->function_name); for (const auto& argument : call->arguments) { RETURN_NOT_OK(Visit(argument)); @@ -973,7 +1010,7 @@ Result> Serialize(const Expression& expr) { metadata_->Append("options", std::move(value)); } - metadata_->Append("end", call->function); + metadata_->Append("end", call->function_name); return Status::OK(); } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 0ab98327619..6137666abff 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -48,15 +48,15 @@ namespace dataset { class ARROW_DS_EXPORT Expression { public: struct Call { - std::string function; + std::string function_name; std::vector arguments; std::shared_ptr options; // post-Bind properties: const compute::Kernel* kernel = NULLPTR; - compute::Function::Kind - function_kind; // XXX give Kernel a non-owning pointer to its Function + std::shared_ptr function; std::shared_ptr kernel_state; + ValueDescr descr; }; std::string ToString() const; @@ -102,20 +102,23 @@ class ARROW_DS_EXPORT Expression { const Datum* literal() const; const FieldRef* field_ref() const; - const ValueDescr& descr() const { return descr_; } - - using Impl = util::Variant; + ValueDescr descr() const; + // XXX someday + // NullGeneralization::type nullable() const; - explicit Expression(std::shared_ptr impl, ValueDescr descr = {}) - : impl_(std::move(impl)), descr_(std::move(descr)) {} + struct Parameter { + FieldRef ref; + ValueDescr descr; + }; Expression() = default; 
+ explicit Expression(Call call); + explicit Expression(Datum literal); + explicit Expression(Parameter parameter); private: + using Impl = util::Variant; std::shared_ptr impl_; - ValueDescr descr_; - // XXX someday - // NullGeneralization::type evaluates_to_null_; ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r); @@ -127,15 +130,21 @@ inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equ // Factories -inline Expression call(std::string function, std::vector arguments, - std::shared_ptr options = NULLPTR) { - Expression::Call call; - call.function = std::move(function); - call.arguments = std::move(arguments); - call.options = std::move(options); - return Expression(std::make_shared(std::move(call))); +ARROW_DS_EXPORT +Expression literal(Datum lit); + +template +Expression literal(Arg&& arg) { + return literal(Datum(std::forward(arg))); } +ARROW_DS_EXPORT +Expression field_ref(FieldRef ref); + +ARROW_DS_EXPORT +Expression call(std::string function, std::vector arguments, + std::shared_ptr options = NULLPTR); + template ::value>::type> Expression call(std::string function, std::vector arguments, @@ -144,23 +153,6 @@ Expression call(std::string function, std::vector arguments, std::make_shared(std::move(options))); } -template -Expression field_ref(Args&&... 
args) { - return Expression( - std::make_shared(FieldRef(std::forward(args)...))); -} - -template -Expression literal(Arg&& arg) { - Datum lit(std::forward(arg)); - ValueDescr descr = lit.descr(); - return Expression(std::make_shared(std::move(lit)), std::move(descr)); -} - -inline Expression literal(std::shared_ptr scalar) { - return literal(Datum(std::move(scalar))); -} - ARROW_DS_EXPORT std::vector FieldsInExpression(const Expression&); diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index c91bbb4b62f..eba972dc18b 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -135,7 +135,7 @@ struct Comparison { static const type* Get(const Expression& expr) { if (auto call = expr.call()) { - return Comparison::Get(call->function); + return Comparison::Get(call->function_name); } return nullptr; } @@ -230,7 +230,7 @@ struct Comparison { }; inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) { - if (call.function != "cast") return nullptr; + if (call.function_name != "cast") return nullptr; return checked_cast(call.options.get()); } @@ -248,17 +248,17 @@ inline bool IsSameTypesBinary(const std::string& function) { inline const compute::SetLookupOptions* GetSetLookupOptions( const Expression::Call& call) { - if (!IsSetLookup(call.function)) return nullptr; + if (!IsSetLookup(call.function_name)) return nullptr; return checked_cast(call.options.get()); } inline const compute::StructOptions* GetStructOptions(const Expression::Call& call) { - if (call.function != "struct") return nullptr; + if (call.function_name != "struct") return nullptr; return checked_cast(call.options.get()); } inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call& call) { - if (call.function != "strptime") return nullptr; + if (call.function_name != "strptime") return nullptr; return checked_cast(call.options.get()); } @@ -321,7 +321,7 @@ 
inline Result> FunctionOptionsToStructScalar( {"value_set", "skip_nulls"}); } - if (call.function == "cast") { + if (call.function_name == "cast") { auto options = checked_cast(call.options.get()); return Finish( { @@ -344,7 +344,7 @@ inline Result> FunctionOptionsToStructScalar( }); } - return Status::NotImplemented("conversion of options for ", call.function); + return Status::NotImplemented("conversion of options for ", call.function_name); } inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, @@ -354,7 +354,7 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, return Status::OK(); } - if (IsSetLookup(call->function)) { + if (IsSetLookup(call->function_name)) { ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set")); ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls")); call->options = std::make_shared( @@ -363,7 +363,7 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, return Status::OK(); } - if (call->function == "cast") { + if (call->function_name == "cast") { auto options = std::make_shared(); ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder")); options->to_type = to_type_holder->type; @@ -384,7 +384,7 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, return Status::OK(); } - return Status::NotImplemented("conversion of options for ", call->function); + return Status::NotImplemented("conversion of options for ", call->function_name); } struct FlattenedAssociativeChain { @@ -399,7 +399,7 @@ struct FlattenedAssociativeChain { while (it != fringe.end()) { auto sub_call = it->call(); - if (!sub_call || sub_call->function != call->function) { + if (!sub_call || sub_call->function_name != call->function_name) { ++it; continue; } @@ -422,8 +422,8 @@ struct FlattenedAssociativeChain { inline Result> GetFunction( const Expression::Call& call, compute::ExecContext* exec_context) { - if (call.function != "cast") { - return 
exec_context->func_registry()->GetFunction(call.function); + if (call.function_name != "cast") { + return exec_context->func_registry()->GetFunction(call.function_name); } // XXX this special case is strange; why not make "cast" a ScalarFunction? const auto& to_type = checked_cast(*call.options).to_type; @@ -453,8 +453,7 @@ Result Modify(Expression expr, const PreVisit& pre, if (at_least_one_modified) { // reconstruct the call expression with the modified arguments - auto modified_expr = Expression( - std::make_shared(std::move(modified_call)), expr.descr()); + auto modified_expr = Expression(std::move(modified_call)); return post_call(std::move(modified_expr), &expr); } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 709e19791a8..3b286e378d0 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -282,8 +282,9 @@ TEST(Expression, BindFieldRef) { {field("alpha", int32()), field("alpha", float32())}))); // referencing nested fields is supported - ASSERT_OK_AND_ASSIGN(expr, field_ref("a", "b").Bind( - Schema({field("a", struct_({field("b", int32())}))}))); + ASSERT_OK_AND_ASSIGN(expr, + field_ref(FieldRef("a", "b")) + .Bind(Schema({field("a", struct_({field("b", int32())}))}))); EXPECT_TRUE(expr.IsBound()); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); } @@ -476,6 +477,13 @@ TEST(Expression, ExecuteDictionaryTransparent) { ])")); } +void ExpectIdenticalIfUnchanged(Expression modified, Expression original) { + if (modified == original) { + // no change -> must be identical + EXPECT_TRUE(Identical(modified, original)) << " " << original.ToString(); + } +} + struct { void operator()(Expression expr, Expression expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); @@ -484,11 +492,7 @@ struct { ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(expr)); EXPECT_EQ(folded, expected); - - if (folded == expr) { - // no change -> must be identical - 
EXPECT_TRUE(Identical(folded, expr)); - } + ExpectIdenticalIfUnchanged(folded, expr); } } ExpectFoldsTo; @@ -629,11 +633,7 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { ReplaceFieldsWithKnownValues(known_values, expr)); EXPECT_EQ(replaced, expected); - - if (replaced == expr) { - // no change -> must be identical - EXPECT_TRUE(Identical(replaced, expr)); - } + ExpectIdenticalIfUnchanged(replaced, expr); }; std::unordered_map i32_is_3{{"i32", Datum(3)}}; @@ -677,18 +677,14 @@ struct { ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); EXPECT_EQ(actual, expected); - - if (actual == expr) { - // no change -> must be identical - EXPECT_TRUE(Identical(actual, expr)); - } + ExpectIdenticalIfUnchanged(actual, bound); } } ExpectCanonicalizesTo; TEST(Expression, CanonicalizeTrivial) { ExpectCanonicalizesTo(literal(1), literal(1)); - ExpectCanonicalizesTo(field_ref("b"), field_ref("b")); + ExpectCanonicalizesTo(field_ref("i32"), field_ref("i32")); ExpectCanonicalizesTo(equal(field_ref("i32"), field_ref("i32_req")), equal(field_ref("i32"), field_ref("i32_req"))); @@ -752,9 +748,7 @@ struct Simplify { << " guarantee: " << guarantee.ToString() << "\n" << (simplified == bound ? 
" (no change)\n" : ""); - if (simplified == bound) { - EXPECT_TRUE(Identical(simplified, bound)); - } + ExpectIdenticalIfUnchanged(simplified, bound); } void ExpectUnchanged() { Expect(expr); } void Expect(bool constant) { Expect(literal(constant)); } diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 840e2f0303f..8610ff1e891 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -429,7 +429,7 @@ TEST_F(TestPartitioning, Set) { set.push_back(checked_cast(*s).value); } - subexpressions.push_back(call("is_in", {field_ref(matches[1])}, + subexpressions.push_back(call("is_in", {field_ref(std::string(matches[1]))}, compute::SetLookupOptions{ints(set)})); } return and_(std::move(subexpressions)); From dfad163b79f5a89cc31705bad13c308d23801ce6 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 16 Dec 2020 14:33:02 -0500 Subject: [PATCH 16/31] revert variant noexcept changes --- cpp/src/arrow/datum.h | 4 ++-- cpp/src/arrow/util/variant.h | 32 ++++++++++++------------------ cpp/src/arrow/util/variant_test.cc | 13 ------------ 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 3d1f545e5f7..8479f7ad366 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -120,8 +120,8 @@ struct ARROW_EXPORT Datum { Datum(const Datum& other) = default; Datum& operator=(const Datum& other) = default; - Datum(Datum&& other) noexcept = default; - Datum& operator=(Datum&& other) noexcept = default; + Datum(Datum&& other) = default; + Datum& operator=(Datum&& other) = default; Datum(std::shared_ptr value) // NOLINT implicit conversion : value(std::move(value)) {} diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h index 04ac5e2bb45..89f39ab8917 100644 --- a/cpp/src/arrow/util/variant.h +++ b/cpp/src/arrow/util/variant.h @@ -96,7 +96,7 @@ template struct all : conditional_t, std::false_type> {}; 
struct delete_copy_constructor { - template + template struct type { type() = default; type(const type& other) = delete; @@ -105,13 +105,11 @@ struct delete_copy_constructor { }; struct explicit_copy_constructor { - template + template struct type { type() = default; - type(const type& other) noexcept(NoexceptCopyable) { - static_cast(other).copy_to(this); - } - type& operator=(const type& other) noexcept(NoexceptCopyable) { + type(const type& other) { static_cast(other).copy_to(this); } + type& operator=(const type& other) { static_cast(this)->destroy(); static_cast(other).copy_to(this); return *this; @@ -119,18 +117,11 @@ struct explicit_copy_constructor { }; }; -template -using CopyConstructionImpl = - typename conditional_t::value && - std::is_copy_assignable::value)...>::value, - explicit_copy_constructor, delete_copy_constructor>:: - template type::value...>::value>; - template struct VariantStorage { VariantStorage() = default; - VariantStorage(const VariantStorage&) noexcept {} - VariantStorage& operator=(const VariantStorage&) noexcept { return *this; } + VariantStorage(const VariantStorage&) {} + VariantStorage& operator=(const VariantStorage&) { return *this; } VariantStorage(VariantStorage&&) noexcept {} VariantStorage& operator=(VariantStorage&&) noexcept { return *this; } ~VariantStorage() { @@ -201,8 +192,7 @@ struct VariantImpl, H, T...> : VariantImpl, T...> { // Templated to avoid instantiation in case H is not copy constructible template - void copy_to(Void* generic_target) const - noexcept(noexcept(H(std::declval()))) { + void copy_to(Void* generic_target) const { const auto target = static_cast(generic_target); try { if (this->index_ == kIndex) { @@ -253,7 +243,11 @@ struct VariantImpl, H, T...> : VariantImpl, T...> { template class Variant : detail::VariantImpl, T...>, - detail::CopyConstructionImpl, T...> { + detail::conditional_t< + detail::all<(std::is_copy_constructible::value && + std::is_copy_assignable::value)...>::value, + 
detail::explicit_copy_constructor, + detail::delete_copy_constructor>::template type> { template static constexpr uint8_t index_of() { return Impl::index_of(detail::type_constant{}); @@ -344,7 +338,7 @@ class Variant : detail::VariantImpl, T...>, this->index_ = 0; } - template + template friend struct detail::explicit_copy_constructor::type; template diff --git a/cpp/src/arrow/util/variant_test.cc b/cpp/src/arrow/util/variant_test.cc index 5d83e00185f..9e36f2eb9cf 100644 --- a/cpp/src/arrow/util/variant_test.cc +++ b/cpp/src/arrow/util/variant_test.cc @@ -125,19 +125,6 @@ TEST(Variant, CopyConstruction) { EXPECT_NO_THROW(AssertCopyConstruction(CopyAssignThrows{})); } -TEST(Variant, Noexcept) { - struct CopyThrows { - CopyThrows() = default; - CopyThrows(const CopyThrows&) { throw 42; } - CopyThrows& operator=(const CopyThrows&) { throw 42; } - }; - static_assert(!std::is_nothrow_copy_constructible::value, ""); - static_assert(std::is_nothrow_copy_constructible::value, ""); - static_assert(std::is_nothrow_copy_constructible>::value, ""); - static_assert( - !std::is_nothrow_copy_constructible>::value, ""); -} - TEST(Variant, Emplace) { using variant_type = Variant, int>; variant_type v; From b32f2540c0d4860be5e02cb6361fa7f094f09322 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 16 Dec 2020 16:09:33 -0800 Subject: [PATCH 17/31] Start adding support for arithmetic functions in R --- r/R/expression.R | 25 ++++++++++++ r/tests/testthat/test-compute-arith.R | 57 +++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 r/tests/testthat/test-compute-arith.R diff --git a/r/R/expression.R b/r/R/expression.R index f9e09c2fadd..046b4f11ee6 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -59,6 +59,20 @@ build_array_expression <- function(.Generic, e1, e2, ...) 
{ } else { e1 <- .wrap_arrow(e1, .Generic, e2$type) e2 <- .wrap_arrow(e2, .Generic, e1$type) + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (.Generic == "/") { + # TODO: cast needs to be an expression + # TODO: omg so many ways it's wrong to assume these types + e1 <- e1$cast(float64()) + e2 <- e2$cast(float64()) + } else if (.Generic == "%/%") { + e1 <- e1$cast(int32()) + e2 <- e2$cast(int32()) + } else if (.Generic == "%%") { + # TODO: not implemented yet; e1 %% e2 == e1 - e2 * (e1 %/% e2) + } expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) } expr @@ -91,9 +105,20 @@ build_array_expression <- function(.Generic, e1, e2, ...) { "<=" = "less_equal", "&" = "and_kleene", "|" = "or_kleene", + "+" = "add", + "-" = "subtract", + "*" = "multiply", + "/" = "divide", + "%/%" = "divide", "%in%" = "is_in_meta_binary" ) + +# ‘"^"’, +# ‘"%%"’, +# ‘"%/%"’ + + .array_function_map <- c(.unary_function_map, .binary_function_map) eval_array_expression <- function(x) { diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R new file mode 100644 index 00000000000..633cc57af1b --- /dev/null +++ b/r/tests/testthat/test-compute-arith.R @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# TODO: +# * Use _checked variations? See what R does +# * More tests for edge cases, esp. with division; add test helpers here? +# * Is there a better "autocasting" solution? See what rules C++ Datasets do +# * test-dplyr tests +# * then, dataset tests, special casing for division + +test_that("Addition", { + a <- Array$create(c(1:4, NA_integer_)) + expect_type_equal(a, int32()) + expect_type_equal(a + 4, int32()) + expect_equal(a + 4, Array$create(c(5:8, NA_integer_))) + expect_identical(as.vector(a + 4), c(5:8, NA_integer_)) + expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) + expect_vector(a + 4L, c(5:8, NA_integer_)) + expect_equal(a + NA_integer_, Array$create(rep(NA_integer_, 5))) + skip("autocasting should happen in compute kernels; R workaround fails on this") + expect_type_equal(a + 4.1, float64()) + expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) +}) + +test_that("Subtraction", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a - 3, Array$create(c(-2:1, NA_integer_))) +}) + +test_that("Multiplication", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a * 2, Array$create(c(1:4 * 2L, NA_integer_))) +}) + +test_that("Division", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a / 2, Array$create(c(1:4 / 2, NA_real_))) + expect_equal(a %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) + + b <- a$cast(float64()) + expect_equal(b / 2, Array$create(c(1:4 / 2, NA_real_))) + expect_equal(b %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) +}) From 6b3e796fcfc9b5c1e1f2584191bffbe504afcb92 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 16 Dec 2020 21:26:11 -0500 Subject: [PATCH 18/31] get python binding building --- python/pyarrow/_dataset.pyx | 149 ++++++++----------- python/pyarrow/compute.py | 2 +- python/pyarrow/includes/libarrow.pxd | 4 + 
python/pyarrow/includes/libarrow_dataset.pxd | 17 ++- python/pyarrow/tests/test_dataset.py | 13 -- 5 files changed, 77 insertions(+), 108 deletions(-) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index e235d50c74a..7b96e87633b 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -116,111 +116,79 @@ cdef class Expression(_Weakrefable): @staticmethod def _deserialize(Buffer buffer not None): - c_buffer = pyarrow_unwrap_buffer(buffer) - #c_expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) - return Expression.wrap(move(c_expr)) + return Expression.wrap(GetResultValue(CDeserializeExpression( + deref(pyarrow_unwrap_buffer(buffer))))) def __reduce__(self): - buffer = pyarrow_wrap_buffer(GetResultValue(self.expr.Serialize())) + buffer = pyarrow_wrap_buffer(GetResultValue( + CSerializeExpression(self.expr))) return Expression._deserialize, (buffer,) - def validate(self, Schema schema not None): - """Validate this expression for execution against a schema. - - This will check that all reference fields are present (fields not in - the schema will be replaced with null) and all subexpressions are - executable. Returns the type to which this expression will evaluate. - - Parameters - ---------- - schema : Schema - Schema to execute the expression on. 
+ @staticmethod + cdef Expression _expr_or_scalar(object expr): + if isinstance(expr, Expression): + return ( expr) + return ( Expression._scalar(expr)) - Returns - ------- - type : DataType - """ + @staticmethod + cdef Expression _call(str function_name, list arguments, + shared_ptr[CFunctionOptions] options=( + nullptr)): cdef: - shared_ptr[CSchema] sp_schema - CResult[shared_ptr[CDataType]] result - sp_schema = pyarrow_unwrap_schema(schema) - result = self.expr.Validate(deref(sp_schema)) - return pyarrow_wrap_data_type(GetResultValue(result)) + vector[CExpression] c_arguments - def assume(self, Expression given): - """Simplify to an equivalent Expression given assumed constraints.""" - return Expression.wrap(self.expr.Assume(given.unwrap())) + for argument in arguments: + c_arguments.push_back(( argument).expr) - def __invert__(self): - return Expression.wrap(CMakeNotExpression(self.unwrap())) - - @staticmethod - cdef CExpression _expr_or_scalar(object expr) except *: - if isinstance(expr, Expression): - return ( expr).unwrap() - return ( Expression._scalar(expr)).unwrap() + return Expression.wrap(CMakeCallExpression(tobytes(function_name), + move(c_arguments), options)) def __richcmp__(self, other, int op): - cdef: - CExpression c_expr - CExpression c_left - CExpression c_right - - c_left = self.unwrap() - c_right = Expression._expr_or_scalar(other) - - if op == Py_EQ: - c_expr = CMakeEqualExpression(move(c_left), move(c_right)) - elif op == Py_NE: - c_expr = CMakeNotEqualExpression(move(c_left), move(c_right)) - elif op == Py_GT: - c_expr = CMakeGreaterExpression(move(c_left), move(c_right)) - elif op == Py_GE: - c_expr = CMakeGreaterEqualExpression(move(c_left), move(c_right)) - elif op == Py_LT: - c_expr = CMakeLessExpression(move(c_left), move(c_right)) - elif op == Py_LE: - c_expr = CMakeLessEqualExpression(move(c_left), move(c_right)) - - return Expression.wrap(c_expr) + other = Expression._expr_or_scalar(other) + return Expression._call({ + Py_EQ: 
"equal", + Py_NE: "not_equal", + Py_GT: "greater", + Py_GE: "greater_equal", + Py_LT: "less", + Py_LE: "less_equal", + }[op], [self, other]) + + def __invert__(self): + return Expression._call("invert", [self]) def __and__(Expression self, other): - c_other = Expression._expr_or_scalar(other) - return Expression.wrap(CMakeAndExpression(self.wrapped, - move(c_other))) + other = Expression._expr_or_scalar(other) + return Expression._call("and_kleene", [self, other]) def __or__(Expression self, other): - c_other = Expression._expr_or_scalar(other) - return Expression.wrap(CMakeOrExpression(self.wrapped, - move(c_other))) + other = Expression._expr_or_scalar(other) + return Expression._call("or_kleene", [self, other]) def is_valid(self): """Checks whether the expression is not-null (valid)""" - return Expression.wrap(self.expr.IsValid().Copy()) + return Expression._call("is_valid", [self]) def cast(self, type, bint safe=True): """Explicitly change the expression's data type""" - cdef CCastOptions options - if safe: - options = CCastOptions.Safe() - else: - options = CCastOptions.Unsafe() - c_type = pyarrow_unwrap_data_type(ensure_type(type)) - return Expression.wrap(self.expr.CastTo(c_type, options).Copy()) + cdef shared_ptr[CCastOptions] c_options + c_options.reset(new CCastOptions(safe)) + c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type)) + return Expression._call("cast", [self], + c_options) def isin(self, values): """Checks whether the expression is contained in values""" cdef: - vector[CExpression] arguments - arguments.push_back(self.expr) + shared_ptr[CFunctionOptions] c_options + CDatum c_values if not isinstance(values, pa.Array): values = pa.array(values) - c_values = pyarrow_unwrap_array(values) - return Expression.wrap(CMakeCallExpression( - tobytes("is_in"), - move(arguments), - move(SetLookupOptions(values).set_lookup_options))) + + c_values = CDatum(pyarrow_unwrap_array(values)) + c_options.reset(new CSetLookupOptions(c_values, True)) 
+ return Expression._call("is_in", [self], c_options) @staticmethod def _field(str name not None): @@ -322,10 +290,11 @@ cdef class Dataset(_Weakrefable): CFragmentIterator c_iterator if filter is None: - c_fragments = self.dataset.GetFragments() + c_fragments = move(GetResultValue(self.dataset.GetFragments())) else: c_filter = _bind(filter, self.schema) - c_fragments = self.dataset.GetFragments(c_filter) + c_fragments = move(GetResultValue( + self.dataset.GetFragments(c_filter))) for maybe_fragment in c_fragments: yield Fragment.wrap(GetResultValue(move(maybe_fragment))) @@ -596,7 +565,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *: return _true.unwrap() return GetResultValue(filter.unwrap().Bind( - deref(pyarrow_unwrap_schema(schema).get()))) + deref(pyarrow_unwrap_schema(schema).get()))) cdef class FileWriteOptions(_Weakrefable): @@ -2264,7 +2233,8 @@ cdef class Scanner(_Weakrefable): def get_fragments(self): """Returns an iterator over the fragments in this scan. 
""" - cdef CFragmentIterator c_fragments = self.scanner.GetFragments() + cdef CFragmentIterator c_fragments = move(GetResultValue( + self.scanner.GetFragments())) for maybe_fragment in c_fragments: yield Fragment.wrap(GetResultValue(move(maybe_fragment))) @@ -2284,12 +2254,15 @@ def _get_partition_keys(Expression partition_expression): """ cdef: CExpression expr = partition_expression.unwrap() - pair[c_string, shared_ptr[CScalar]] name_val - - return { - frombytes(name_val.first): pyarrow_wrap_scalar(name_val.second).as_py() - for name_val in GetResultValue(CGetPartitionKeys(expr)) - } + pair[CFieldRef, CDatum] ref_val + + out = {} + for ref_val in GetResultValue(CExtractKnownFieldValues(expr)): + assert ref_val.first.name() != nullptr + assert ref_val.second.kind() == DatumType_SCALAR + val = pyarrow_wrap_scalar(ref_val.second.scalar()).as_py() + out[frombytes(deref(ref_val.first.name()))] = val + return out def _filesystemdataset_write( diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index ddfd8057db2..ddd77174d13 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -62,7 +62,7 @@ def _get_arg_names(func): arg_names = ["left", "right"] else: raise NotImplementedError( - "unsupported arity: {}".format(func.arity)) + f"unsupported arity: {func.arity} (function: {func.name})") return arg_names diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 45a39061919..aaee9fb5351 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -405,6 +405,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CFieldRef() CFieldRef(c_string name) CFieldRef(int index) + const c_string* name() const + + cdef cppclass CFieldRefHash" arrow::FieldRef::Hash": + pass cdef cppclass CStructType" arrow::StructType"(CDataType): CStructType(const vector[shared_ptr[CField]]& fields) diff --git a/python/pyarrow/includes/libarrow_dataset.pxd 
b/python/pyarrow/includes/libarrow_dataset.pxd index 53e48da1703..98e6c20bf23 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -50,6 +50,11 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: vector[CExpression] arguments, shared_ptr[CFunctionOptions] options) + cdef CResult[shared_ptr[CBuffer]] CSerializeExpression \ + "arrow::dataset::Serialize"(const CExpression&) + cdef CResult[CExpression] CDeserializeExpression \ + "arrow::dataset::Deserialize"(const CBuffer&) + cdef cppclass CRecordBatchProjector "arrow::dataset::RecordBatchProjector": pass @@ -95,7 +100,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: shared_ptr[CScanContext]) CResult[CScanTaskIterator] Scan() CResult[shared_ptr[CTable]] ToTable() - CFragmentIterator GetFragments() + CResult[CFragmentIterator] GetFragments() const shared_ptr[CScanOptions]& options() cdef cppclass CScannerBuilder "arrow::dataset::ScannerBuilder": @@ -115,8 +120,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CDataset "arrow::dataset::Dataset": const shared_ptr[CSchema] & schema() - CFragmentIterator GetFragments() - CFragmentIterator GetFragments(CExpression predicate) + CResult[CFragmentIterator] GetFragments() + CResult[CFragmentIterator] GetFragments(CExpression predicate) const CExpression & partition_expression() c_string type_name() @@ -301,9 +306,9 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: const CExpression& partition_expression, CRecordBatchProjector* projector) - cdef CResult[unordered_map[c_string, shared_ptr[CScalar]]] \ - CGetPartitionKeys "arrow::dataset::KeyValuePartitioning::GetKeys"( - const CExpression& partition_expression) + cdef CResult[unordered_map[CFieldRef, CDatum, CFieldRefHash]] \ + CExtractKnownFieldValues "arrow::dataset::ExtractKnownFieldValues"( + const CExpression& partition_expression) cdef cppclass 
CFileSystemFactoryOptions \ "arrow::dataset::FileSystemFactoryOptions": diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 3bda5692128..a783c3b4273 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -408,19 +408,6 @@ def test_expression_serialization(): f = ds.scalar({'a': 1}) g = ds.scalar(pa.scalar(1)) - condition = ds.field('i64') > 5 - schema = pa.schema([ - pa.field('i64', pa.int64()), - pa.field('f64', pa.float64()) - ]) - assert condition.validate(schema) == pa.bool_() - - assert condition.assume(ds.field('i64') == 5).equals( - ds.scalar(False)) - - assert condition.assume(ds.field('i64') == 7).equals( - ds.scalar(True)) - all_exprs = [a, b, c, d, e, f, g, a == b, a > b, a & b, a | b, ~c, d.is_valid(), a.cast(pa.int32(), safe=False), a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), From 849308a744d22de2e26bc41a9ea46a15bd4ce74f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 17 Dec 2020 13:00:57 -0500 Subject: [PATCH 19/31] fix doc generation for struct function --- cpp/src/arrow/compute/cast.cc | 2 +- cpp/src/arrow/compute/function.cc | 11 ++++++++--- python/pyarrow/compute.py | 10 ++++++---- python/pyarrow/tests/test_compute.py | 11 ++++++++--- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index b55dfb38c15..5c332aedf73 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -121,7 +121,7 @@ class CastMetaFunction : public MetaFunction { const FunctionDoc struct_doc{"Wrap Arrays into a StructArray", ("Names of the StructArray's fields are\n" "specified through StructOptions."), - {}, + {"*args"}, "StructOptions"}; Result StructResolve(KernelContext* ctx, diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 2d3e06e2fb2..28951b7dae1 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc 
@@ -150,10 +150,15 @@ Result Function::Execute(const std::vector& args, Status Function::Validate() const { if (!doc_->summary.empty()) { // Documentation given, check its contents - if (static_cast(doc_->arg_names.size()) != arity_.num_args) { - return Status::Invalid("In function '", name_, - "': ", "number of argument names != function arity"); + int arg_count = static_cast(doc_->arg_names.size()); + if (arg_count == arity_.num_args) { + return Status::OK(); } + if (arity_.is_varargs && arg_count == arity_.num_args + 1) { + return Status::OK(); + } + return Status::Invalid("In function '", name_, + "': ", "number of argument names != function arity"); } return Status::OK(); } diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index ddd77174d13..8650a38345b 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -116,7 +116,7 @@ def _decorate_compute_function(wrapper, exposed_name, func, option_class): doc_pieces.append("""\ options : pyarrow.compute.{0}, optional Parameters altering compute function semantics - **kwargs: optional + **kwargs : optional Parameters for {0} constructor. Either `options` or `**kwargs` can be passed, but not both at the same time. 
""".format(option_class.__name__)) @@ -160,14 +160,14 @@ def _handle_options(name, option_class, options, kwargs): _wrapper_template = dedent("""\ def make_wrapper(func, option_class): - def {func_name}({args_sig}, *, memory_pool=None): + def {func_name}({args_sig}{kwonly}, memory_pool=None): return func.call([{args_sig}], None, memory_pool) return {func_name} """) _wrapper_options_template = dedent("""\ def make_wrapper(func, option_class): - def {func_name}({args_sig}, *, options=None, memory_pool=None, + def {func_name}({args_sig}{kwonly}, options=None, memory_pool=None, **kwargs): options = _handle_options({func_name!r}, option_class, options, kwargs) @@ -180,6 +180,7 @@ def _wrap_function(name, func): option_class = _get_options_class(func) arg_names = _get_arg_names(func) args_sig = ', '.join(arg_names) + kwonly = '' if arg_names[-1].startswith('*') else ', *' # Generate templated wrapper, so that the signature matches # the documented argument names. @@ -188,7 +189,8 @@ def _wrap_function(name, func): template = _wrapper_options_template else: template = _wrapper_template - exec(template.format(func_name=name, args_sig=args_sig), globals(), ns) + exec(template.format(func_name=name, args_sig=args_sig, kwonly=kwonly), + globals(), ns) wrapper = ns['make_wrapper'](func, option_class) return _decorate_compute_function(wrapper, name, func, option_class) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 1afa0ef931c..b74dfcdec7b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -83,7 +83,11 @@ def test_exported_functions(): functions = exported_functions assert len(functions) >= 10 for func in functions: - args = [object()] * func.__arrow_compute_function__['arity'] + arity = func.__arrow_compute_function__['arity'] + if arity is Ellipsis: + args = [object()] * 3 + else: + args = [object()] * arity with pytest.raises(TypeError, match="Got unexpected argument type " " for 
compute function"): @@ -172,7 +176,8 @@ def test_function_attributes(): kernels = func.kernels assert func.num_kernels == len(kernels) assert all(isinstance(ker, pc.Kernel) for ker in kernels) - assert func.arity >= 1 # no varargs functions for now + if func.arity is not Ellipsis: + assert func.arity >= 1 repr(func) for ker in kernels: repr(ker) @@ -402,7 +407,7 @@ def test_generated_docstrings(): If not passed, will allocate memory from the default memory pool. options : pyarrow.compute.MinMaxOptions, optional Parameters altering compute function semantics - **kwargs: optional + **kwargs : optional Parameters for MinMaxOptions constructor. Either `options` or `**kwargs` can be passed, but not both at the same time. """) From 9b8b2ab03c50a10a140ca27ea99bc0622156b241 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 17 Dec 2020 21:37:55 -0500 Subject: [PATCH 20/31] review comments --- .../compute/kernels/scalar_cast_internal.cc | 17 +++---- .../compute/kernels/scalar_cast_temporal.cc | 6 ++- .../arrow/compute/kernels/util_internal.cc | 13 +++-- cpp/src/arrow/compute/type_fwd.h | 6 +++ cpp/src/arrow/dataset/expression.cc | 8 +++ cpp/src/arrow/dataset/expression.h | 9 ++-- cpp/src/arrow/dataset/expression_internal.h | 29 ++++++----- cpp/src/arrow/dataset/expression_test.cc | 4 +- cpp/src/arrow/dataset/partition.cc | 1 + cpp/src/arrow/dataset/partition_test.cc | 1 + cpp/src/arrow/dataset/scanner_test.cc | 36 +++++-------- cpp/src/arrow/datum.cc | 18 +++++-- cpp/src/arrow/datum.h | 5 ++ cpp/src/arrow/type.cc | 51 ------------------- cpp/src/arrow/type.h | 5 -- 15 files changed, 87 insertions(+), 122 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index 7abf1c618da..f8dde20e3aa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -149,16 +149,10 @@ void CastNumberToNumberUnsafe(Type::type in_type, 
Type::type out_type, const Dat // ---------------------------------------------------------------------- void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - auto Finish = [&](Result result) { - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - *out = *result; - }; - if (out->is_scalar()) { - return Finish(batch[0].scalar_as().GetEncodedValue()); + KERNEL_ASSIGN_OR_RAISE(*out, ctx, + batch[0].scalar_as().GetEncodedValue()); + return; } DictionaryArray dict_arr(batch[0].array()); @@ -172,8 +166,9 @@ void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return; } - return Finish(Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), - /*options=*/TakeOptions::Defaults(), ctx->exec_context())); + KERNEL_ASSIGN_OR_RAISE(*out, ctx, + Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), + TakeOptions::Defaults(), ctx->exec_context())); } void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index f1702b7fd51..99c1401d1b8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -284,9 +284,11 @@ struct CastFunctor { DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const auto& out_type = checked_cast(*out->type()); + + // date64 is ms since epoch auto conversion = util::GetTimestampConversion(TimeUnit::MILLI, out_type.unit()); - ShiftTime(ctx, util::MULTIPLY, conversion.second, *batch[0].array(), - out->mutable_array()); + ShiftTime(ctx, conversion.first, conversion.second, + *batch[0].array(), out->mutable_array()); } }; diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc index 21b6b706a50..3d21f5b1494 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ 
-68,15 +68,14 @@ ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec) { return; } - Datum array_in, array_out; - KERNEL_RETURN_IF_ERROR( - ctx, MakeArrayFromScalar(*batch[0].scalar(), 1).As().Value(&array_in)); - KERNEL_RETURN_IF_ERROR( - ctx, MakeArrayFromScalar(*out->scalar(), 1).As().Value(&array_out)); + KERNEL_ASSIGN_OR_RAISE(Datum array_in, ctx, + MakeArrayFromScalar(*batch[0].scalar(), 1)); + + KERNEL_ASSIGN_OR_RAISE(Datum array_out, ctx, MakeArrayFromScalar(*out->scalar(), 1)); exec(ctx, ExecBatch{{std::move(array_in)}, 1}, &array_out); - KERNEL_RETURN_IF_ERROR(ctx, - array_out.make_array()->GetScalar(0).As().Value(out)); + + KERNEL_ASSIGN_OR_RAISE(*out, ctx, array_out.make_array()->GetScalar(0)); }; } diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 4b05ceccb5a..9888e610aa7 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -20,9 +20,13 @@ namespace arrow { struct Datum; +struct ValueDescr; namespace compute { +class Function; +struct FunctionOptions; + struct CastOptions; class ExecContext; @@ -33,5 +37,7 @@ struct ScalarKernel; struct ScalarAggregateKernel; struct VectorKernel; +struct KernelState; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 4c5b693ffb9..637010cfed2 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -26,6 +26,7 @@ #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" +#include "arrow/util/atomic_shared_ptr.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" @@ -279,11 +280,18 @@ size_t Expression::hash() const { } auto call = CallNotNull(*this); + if (call->hash != nullptr) { + return call->hash->load(); + } size_t out = std::hash{}(call->function_name); for (const auto& arg : call->arguments) { out ^= arg.hash(); } + + 
std::shared_ptr> expected = nullptr; + internal::atomic_compare_exchange_strong(&const_cast(call)->hash, &expected, + std::make_shared>(out)); return out; } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 6137666abff..f625ca694e1 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -19,21 +19,17 @@ #pragma once -#include +#include #include #include #include #include #include -#include "arrow/chunked_array.h" -#include "arrow/compute/api_scalar.h" -#include "arrow/compute/cast.h" +#include "arrow/compute/type_fwd.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/scalar.h" #include "arrow/type_fwd.h" #include "arrow/util/variant.h" @@ -51,6 +47,7 @@ class ARROW_DS_EXPORT Expression { std::string function_name; std::vector arguments; std::shared_ptr options; + std::shared_ptr> hash; // post-Bind properties: const compute::Kernel* kernel = NULLPTR; diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index eba972dc18b..77e555aa7a6 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -21,7 +21,8 @@ #include #include -#include "arrow/compute/api_vector.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" #include "arrow/compute/registry.h" #include "arrow/record_batch.h" #include "arrow/table.h" @@ -92,11 +93,11 @@ inline Result GetDatumField(const FieldRef& ref, const Datum& input) { FieldPath path; if (auto type = input.type()) { - ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.type())); - } else if (input.kind() == Datum::RECORD_BATCH) { - ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.record_batch()->schema())); - } else if (input.kind() == Datum::TABLE) { - ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.table()->schema())); + 
ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*type)); + } else if (auto schema = input.schema()) { + ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*schema)); + } else { + return Status::NotImplemented("retrieving fields from datum ", input.ToString()); } if (path) { @@ -123,14 +124,14 @@ struct Comparison { }; static const type* Get(const std::string& function) { - static std::unordered_map flipped_comparisons{ + static std::unordered_map map{ {"equal", EQUAL}, {"not_equal", NOT_EQUAL}, {"less", LESS}, {"less_equal", LESS_EQUAL}, {"greater", GREATER}, {"greater_equal", GREATER_EQUAL}, }; - auto it = flipped_comparisons.find(function); - return it != flipped_comparisons.end() ? &it->second : nullptr; + auto it = map.find(function); + return it != map.end() ? &it->second : nullptr; } static const type* Get(const Expression& expr) { @@ -262,13 +263,12 @@ inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call return checked_cast(call.options.get()); } -inline const std::shared_ptr& GetDictionaryValueType( +inline std::shared_ptr GetDictionaryValueType( const std::shared_ptr& type) { if (type && type->id() == Type::DICTIONARY) { return checked_cast(*type).value_type(); } - static std::shared_ptr null; - return null; + return nullptr; } inline Status EnsureNotDictionary(ValueDescr* descr) { @@ -410,7 +410,10 @@ struct FlattenedAssociativeChain { exprs.push_back(std::move(*it)); it = fringe.erase(it); - it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end()); + + auto index = it - fringe.begin(); + fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end()); + it = fringe.begin() + index; // NB: no increment so we hit sub_call's first argument next iteration } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 3b286e378d0..175e045918d 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -558,7 +558,7 @@ 
TEST(Expression, FoldConstantsBoolean) { // test and_kleene/or_kleene-specific optimizations auto one = literal(1); auto two = literal(2); - auto whatever = call("equal", {call("add", {one, field_ref("i32")}), two}); + auto whatever = equal(call("add", {one, field_ref("i32")}), two); auto true_ = literal(true); auto false_ = literal(false); @@ -696,7 +696,7 @@ TEST(Expression, CanonicalizeAnd) { auto null_ = literal(std::make_shared()); auto b = field_ref("bool"); - auto c = call("equal", {literal(1), literal(2)}); + auto c = equal(literal(1), literal(2)); // no change possible: ExpectCanonicalizesTo(and_(b, c), and_(b, c)); diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index ba4f0c46b8f..2822c4a15b6 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -27,6 +27,7 @@ #include "arrow/array/builder_dict.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" +#include "arrow/compute/cast.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/filesystem/path_util.h" #include "arrow/scalar.h" diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 8610ff1e891..2260eb219da 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -26,6 +26,7 @@ #include #include +#include "arrow/compute/api_scalar.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/filesystem/path_util.h" diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 7261bdb8037..f8f959e3a28 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -184,25 +184,17 @@ TEST_F(TestScannerBuilder, TestFilter) { ScannerBuilder builder(dataset_, ctx_); ASSERT_OK(builder.Filter(literal(true))); - ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); - 
ASSERT_OK(builder.Filter( - call("or_kleene", { - call("equal", {field_ref("i64"), literal(10)}), - call("equal", {field_ref("b"), literal(true)}), - }))); - - ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); - - ASSERT_RAISES( - Invalid, builder.Filter(call("equal", {field_ref("not_a_column"), literal(true)}))); - - ASSERT_RAISES( - Invalid, - builder.Filter( - call("or_kleene", { - call("equal", {field_ref("i64"), literal(10)}), - call("equal", {field_ref("not_a_column"), literal(true)}), - }))); + ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal(10)))); + ASSERT_OK(builder.Filter(or_(equal(field_ref("i64"), literal(10)), + equal(field_ref("b"), literal(true))))); + + ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal(10)))); + + ASSERT_RAISES(Invalid, builder.Filter(equal(field_ref("not_a_column"), literal(true)))); + + ASSERT_RAISES(Invalid, + builder.Filter(or_(equal(field_ref("i64"), literal(10)), + equal(field_ref("not_a_column"), literal(true))))); } using testing::ElementsAre; @@ -215,7 +207,7 @@ TEST(ScanOptions, TestMaterializedFields) { auto opts = ScanOptions::Make(schema({})); EXPECT_THAT(opts->MaterializedFields(), IsEmpty()); - opts->filter2 = call("equal", {field_ref("i32"), literal(10)}); + opts->filter2 = equal(field_ref("i32"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); opts = ScanOptions::Make(schema({i32, i64})); @@ -224,10 +216,10 @@ TEST(ScanOptions, TestMaterializedFields) { opts = opts->ReplaceSchema(schema({i32})); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); - opts->filter2 = call("equal", {field_ref("i32"), literal(10)}); + opts->filter2 = equal(field_ref("i32"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32")); - opts->filter2 = call("equal", {field_ref("i64"), literal(10)}); + opts->filter2 = equal(field_ref("i64"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64")); } diff --git 
a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index 0f2fa8ebe90..786110996dc 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -89,12 +89,24 @@ std::shared_ptr Datum::make_array() const { std::shared_ptr Datum::type() const { if (this->kind() == Datum::ARRAY) { return util::get>(this->value)->type; - } else if (this->kind() == Datum::CHUNKED_ARRAY) { + } + if (this->kind() == Datum::CHUNKED_ARRAY) { return util::get>(this->value)->type(); - } else if (this->kind() == Datum::SCALAR) { + } + if (this->kind() == Datum::SCALAR) { return util::get>(this->value)->type; } - return NULLPTR; + return nullptr; +} + +std::shared_ptr Datum::schema() const { + if (this->kind() == Datum::RECORD_BATCH) { + return util::get>(this->value)->schema(); + } + if (this->kind() == Datum::TABLE) { + return util::get>(this->value)->schema(); + } + return nullptr; } int64_t Datum::length() const { diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 8479f7ad366..fb783ea5261 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -252,6 +252,11 @@ struct ARROW_EXPORT Datum { /// \return nullptr if no type std::shared_ptr type() const; + /// \brief The schema of the variant, if any + /// + /// \return nullptr if no schema + std::shared_ptr schema() const; + /// \brief The value length of the variant, if any /// /// \return kUnknownLength if no type diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 20dc0f14ca0..ad889b3eb24 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -30,12 +30,10 @@ #include #include "arrow/array.h" -#include "arrow/chunked_array.h" #include "arrow/compare.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" -#include "arrow/table.h" #include "arrow/util/checked_cast.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" @@ -900,8 +898,6 @@ std::string FieldPath::ToString() const { struct FieldPathGetImpl { static const DataType& GetType(const 
ArrayData& data) { return *data.type; } - static const DataType& GetType(const ChunkedArray& array) { return *array.type(); } - static void Summarize(const FieldVector& fields, std::stringstream* ss) { *ss << "{ "; for (const auto& field : fields) { @@ -996,32 +992,6 @@ struct FieldPathGetImpl { path, &child_data, [](const std::shared_ptr& data) { return &data->child_data; }); } - - static Result> Get( - const FieldPath* path, const ChunkedArrayVector& columns_arg) { - ChunkedArrayVector columns = columns_arg; - - return FieldPathGetImpl::Get( - path, &columns, [&](const std::shared_ptr& a) { - columns.clear(); - - for (int i = 0; i < a->type()->num_fields(); ++i) { - ArrayVector child_chunks; - - for (const auto& chunk : a->chunks()) { - auto child_chunk = MakeArray(chunk->data()->child_data[i]); - child_chunks.push_back(std::move(child_chunk)); - } - - auto child_column = std::make_shared( - std::move(child_chunks), a->type()->field(i)->type()); - - columns.emplace_back(std::move(child_column)); - } - - return &columns; - }); - } }; Result> FieldPath::Get(const Schema& schema) const { @@ -1045,10 +1015,6 @@ Result> FieldPath::Get(const RecordBatch& batch) const { return MakeArray(std::move(data)); } -Result> FieldPath::Get(const Table& table) const { - return FieldPathGetImpl::Get(this, table.columns()); -} - Result> FieldPath::Get(const Array& array) const { ARROW_ASSIGN_OR_RAISE(auto data, Get(*array.data())); return MakeArray(std::move(data)); @@ -1058,15 +1024,6 @@ Result> FieldPath::Get(const ArrayData& data) const { return FieldPathGetImpl::Get(this, data.child_data); } -Result> FieldPath::Get(const ChunkedArray& array) const { - FieldPath prefixed_with_0 = *this; - prefixed_with_0.indices_.insert(prefixed_with_0.indices_.begin(), 0); - - ChunkedArrayVector vec; - vec.emplace_back(const_cast(&array), [](...) 
{}); - return FieldPathGetImpl::Get(&prefixed_with_0, vec); -} - FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) { DCHECK_GT(util::get(impl_).indices().size(), 0); } @@ -1320,18 +1277,10 @@ std::vector FieldRef::FindAll(const Array& array) const { return FindAll(*array.type()); } -std::vector FieldRef::FindAll(const ChunkedArray& array) const { - return FindAll(*array.type()); -} - std::vector FieldRef::FindAll(const RecordBatch& batch) const { return FindAll(*batch.schema()); } -std::vector FieldRef::FindAll(const Table& table) const { - return FindAll(*table.schema()); -} - void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index cadb819b9f2..f0fa04f40be 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1424,15 +1424,10 @@ class ARROW_EXPORT FieldPath { /// \brief Retrieve the referenced column from a RecordBatch or Table Result> Get(const RecordBatch& batch) const; - Result> Get(const Table& table) const; /// \brief Retrieve the referenced child from an Array, ArrayData, or ChunkedArray Result> Get(const Array& array) const; Result> Get(const ArrayData& data) const; - Result> Get(const ChunkedArray& array) const; - - /// \brief Retrieve the reference child from a Datum - Result Get(const Datum& datum) const; private: std::vector indices_; From 5206104be0ab74e00101a1689bdc5320b2935c8b Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 18 Dec 2020 10:33:49 -0500 Subject: [PATCH 21/31] repair linkage on Expression friends --- cpp/src/arrow/dataset/expression.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index f625ca694e1..4b414d4d142 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -117,9 +117,9 @@ class ARROW_DS_EXPORT 
Expression { using Impl = util::Variant; std::shared_ptr impl_; - ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r); + ARROW_DS_EXPORT friend bool Identical(const Expression& l, const Expression& r); - ARROW_EXPORT friend void PrintTo(const Expression&, std::ostream*); + ARROW_DS_EXPORT friend void PrintTo(const Expression&, std::ostream*); }; inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); } From 7ea0664d769d18b39f19966ba6109956fe93efaf Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 18 Dec 2020 11:32:45 -0500 Subject: [PATCH 22/31] centos-7 doesn't recognize this initializer list --- cpp/src/arrow/dataset/expression.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 637010cfed2..848055ac5aa 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -389,9 +389,10 @@ Result> InitKernelState( if (!call.kernel->init) return nullptr; compute::KernelContext kernel_context(exec_context); - auto kernel_state = call.kernel->init( - &kernel_context, {call.kernel, GetDescriptors(call.arguments), call.options.get()}); + compute::KernelInitArgs kernel_init_args{call.kernel, GetDescriptors(call.arguments), + call.options.get()}; + auto kernel_state = call.kernel->init(&kernel_context, kernel_init_args); RETURN_NOT_OK(kernel_context.status()); return std::move(kernel_state); } From 1a7d5727b3321893fc226c064307496b678ced2a Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 18 Dec 2020 11:33:15 -0500 Subject: [PATCH 23/31] get test_{dataset,parquet}.py passing --- cpp/src/arrow/compute/kernels/scalar_cast_string.cc | 12 ++++++++---- python/pyarrow/_dataset.pyx | 6 +++--- python/pyarrow/includes/libarrow.pxd | 13 +++++++------ python/pyarrow/tests/test_dataset.py | 4 +++- python/pyarrow/tests/test_parquet.py | 2 +- 5 files changed, 22 insertions(+), 15 
deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index 9d3812e5967..ca9c33a8f1b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -49,6 +49,7 @@ struct NumericToStringCastFunctor { using FormatterType = StringFormatter; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK(out->is_array()); const ArrayData& input = *batch[0].array(); ArrayData* output = out->mutable_array(); ctx->SetStatus(Convert(ctx, input, output)); @@ -160,6 +161,7 @@ struct BinaryToBinaryCastFunctor { using output_offset_type = typename O::offset_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK(out->is_array()); const CastOptions& options = checked_cast(*ctx->state()).options; const ArrayData& input = *batch[0].array(); @@ -194,7 +196,8 @@ void AddNumberToStringCasts(CastFunction* func) { auto out_ty = TypeTraits::type_singleton(); DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, - NumericToStringCastFunctor::Exec, + TrivialScalarUnaryAsArraysExec( + NumericToStringCastFunctor::Exec), NullHandling::COMPUTED_NO_PREALLOCATE)); for (const std::shared_ptr& in_ty : NumericTypes()) { @@ -210,9 +213,10 @@ void AddBinaryToBinaryCast(CastFunction* func) { auto in_ty = TypeTraits::type_singleton(); auto out_ty = TypeTraits::type_singleton(); - DCHECK_OK(func->AddKernel(OutType::type_id, {in_ty}, out_ty, - BinaryToBinaryCastFunctor::Exec, - NullHandling::COMPUTED_NO_PREALLOCATE)); + DCHECK_OK(func->AddKernel( + OutType::type_id, {in_ty}, out_ty, + TrivialScalarUnaryAsArraysExec(BinaryToBinaryCastFunctor::Exec), + NullHandling::COMPUTED_NO_PREALLOCATE)); } } // namespace diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 7b96e87633b..710e3c9c2c4 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -2254,14 +2254,14 @@ 
def _get_partition_keys(Expression partition_expression): """ cdef: CExpression expr = partition_expression.unwrap() - pair[CFieldRef, CDatum] name_val + pair[CFieldRef, CDatum] ref_val out = {} for ref_val in GetResultValue(CExtractKnownFieldValues(expr)): assert ref_val.first.name() != nullptr assert ref_val.second.kind() == DatumType_SCALAR - val = pyarrow_wrap_scalar(name_val.second.scalar()).as_py() - out[deref(ref_val.first.name())] = val + val = pyarrow_wrap_scalar(ref_val.second.scalar()) + out[frombytes(deref(ref_val.first.name()))] = val.as_py() return out diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index aaee9fb5351..fcdf8ed4179 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1830,13 +1830,14 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CDatum(const shared_ptr[CRecordBatch]& value) CDatum(const shared_ptr[CTable]& value) - DatumType kind() + DatumType kind() const + c_string ToString() const - shared_ptr[CArrayData] array() - shared_ptr[CChunkedArray] chunked_array() - shared_ptr[CRecordBatch] record_batch() - shared_ptr[CTable] table() - shared_ptr[CScalar] scalar() + const shared_ptr[CArrayData]& array() const + const shared_ptr[CChunkedArray]& chunked_array() const + const shared_ptr[CRecordBatch]& record_batch() const + const shared_ptr[CTable]& table() const + const shared_ptr[CScalar]& scalar() const cdef cppclass CSetLookupOptions \ "arrow::compute::SetLookupOptions"(CFunctionOptions): diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index a783c3b4273..0ab9d95398d 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -768,7 +768,9 @@ def assert_yields_projected(fragment, row_slice, # Fragments don't contain the partition's columns if not provided to the # `to_table(schema=...)` method. 
- with pytest.raises(ValueError, match="Field named 'part' not found"): + pattern = (r'No match for FieldRef.Name\(part\) in ' + + fragment.physical_schema.to_string(False, False, False)) + with pytest.raises(ValueError, match=pattern): new_fragment = parquet_format.make_fragment( fragment.path, fragment.filesystem, partition_expression=fragment.partition_expression) diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py index 9422594caff..6fa001cc758 100644 --- a/python/pyarrow/tests/test_parquet.py +++ b/python/pyarrow/tests/test_parquet.py @@ -2144,7 +2144,7 @@ def test_filters_invalid_column(tempdir, use_legacy_dataset): _generate_partition_directories(fs, base_path, partition_spec, df) - msg = "Field named 'non_existent_column' not found" + msg = r"No match for FieldRef.Name\(non_existent_column\)" with pytest.raises(ValueError, match=msg): pq.ParquetDataset(base_path, filesystem=fs, filters=[('non_existent_column', '<', 3), ], From 5de9db0e9abb6e6fc23be90ee0713727de82985c Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 18 Dec 2020 13:05:20 -0500 Subject: [PATCH 24/31] move Identical() impl to .cc --- cpp/src/arrow/dataset/expression.cc | 2 ++ cpp/src/arrow/dataset/expression_internal.h | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 848055ac5aa..7f08788de29 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -267,6 +267,8 @@ bool Expression::Equals(const Expression& other) const { return false; } +bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; } + size_t Expression::hash() const { if (auto lit = literal()) { if (lit->is_scalar()) { diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 77e555aa7a6..c3fd49dc347 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ 
b/cpp/src/arrow/dataset/expression_internal.h @@ -34,8 +34,6 @@ using internal::checked_cast; namespace dataset { -bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; } - const Expression::Call* CallNotNull(const Expression& expr) { auto call = expr.call(); DCHECK_NE(call, nullptr); From 97182554c2ba376a2d8b4d586242f7f1add124a7 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 18 Dec 2020 13:52:07 -0500 Subject: [PATCH 25/31] export ExecuteScalarExpression --- cpp/src/arrow/dataset/expression.h | 1 + cpp/src/arrow/dataset/expression_test.cc | 3 +++ 2 files changed, 4 insertions(+) diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 4b414d4d142..e597ffb7fcd 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -193,6 +193,7 @@ Result SimplifyWithGuarantee(Expression, /// Execute a scalar expression against the provided state and input Datum. This /// expression must be bound. 
+ARROW_DS_EXPORT Result ExecuteScalarExpression(const Expression&, const Datum& input, compute::ExecContext* = NULLPTR); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 175e045918d..e245b0d7093 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -1006,6 +1006,9 @@ TEST(Expression, SerializationRoundTrips) { ExpectRoundTrips(not_(field_ref("alpha"))); + ExpectRoundTrips(call("is_in", {literal(1)}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1, 2, 3]")})); + ExpectRoundTrips( call("is_in", {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))}, From 2197c881d99471620d9c1398ccc0f5287531b07c Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Wed, 23 Dec 2020 11:23:10 -0600 Subject: [PATCH 26/31] Adjust some of the casting, datasets not (yet) working --- r/R/expression.R | 36 +++++++++++++++------- r/src/compute.cpp | 12 ++++++++ r/tests/testthat/test-compute-arith.R | 20 +++++++++++-- r/tests/testthat/test-dataset.R | 43 +++++++++++++++++++++++++++ r/tests/testthat/test-dplyr.R | 42 ++++++++++++++++++++++++++ 5 files changed, 140 insertions(+), 13 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 046b4f11ee6..817941a1caa 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -63,15 +63,15 @@ build_array_expression <- function(.Generic, e1, e2, ...) 
{ # In Arrow, "divide" is one function, which does integer division on # integer inputs and floating-point division on floats if (.Generic == "/") { - # TODO: cast needs to be an expression # TODO: omg so many ways it's wrong to assume these types - e1 <- e1$cast(float64()) - e2 <- e2$cast(float64()) + e1 <- array_expression("cast", e1, options = list(to_type = float64())) + e2 <- array_expression("cast", e2, options = list(to_type = float64())) } else if (.Generic == "%/%") { - e1 <- e1$cast(int32()) - e2 <- e2$cast(int32()) + e1 <- array_expression("cast", e1, options = list(to_type = float64())) + e2 <- array_expression("cast", e2, options = list(to_type = float64())) + return(array_expression("cast", array_expression(.binary_function_map[[.Generic]], e1, e2, ...), options = list(to_type = int32()))) } else if (.Generic == "%%") { - # e1 - e1 %/% e2 + # e1 - e2 * ( e1 %/% e2 ) } expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) } @@ -105,11 +105,11 @@ build_array_expression <- function(.Generic, e1, e2, ...) { "<=" = "less_equal", "&" = "and_kleene", "|" = "or_kleene", - "+" = "add", - "-" = "subtract", - "*" = "multiply", - "/" = "divide", - "%/%" = "divide", + "+" = "add_checked", + "-" = "subtract_checked", + "*" = "multiply_checked", + "/" = "divide_checked", + "%/%" = "divide_checked", "%in%" = "is_in_meta_binary" ) @@ -221,6 +221,20 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) 
{ if (!inherits(e2, "Expression")) { e2 <- Expression$scalar(e2) } + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (.Generic == "/") { + # TODO: omg so many ways it's wrong to assume these types + e1 <- array_expression("cast", e1, options = list(to_type = float64())) + e2 <- array_expression("cast", e2, options = list(to_type = float64())) + } else if (.Generic == "%/%") { + e1 <- array_expression("cast", e1, options = list(to_type = int32())) + e2 <- array_expression("cast", e2, options = list(to_type = int32())) + } else if (.Generic == "%%") { + # e1 - e2 * ( e1 %/% e2 ) + } + expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...) } expr diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 8a75a251d8e..c44d1674153 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -185,6 +185,18 @@ std::shared_ptr make_compute_options( cpp11::as_cpp(options["skip_nulls"])); } + // hacky attempt to pass through to_type + if (func_name == "cast") { + using Options = arrow::compute::CastOptions; + auto out = std::make_shared(false); + SEXP to_type = options["to_type"]; + if (!Rf_isNull(to_type) && cpp11::as_cpp>(to_type)) { + out->to_type = cpp11::as_cpp>(to_type); + } + return out; + + } + return nullptr; } diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index 633cc57af1b..27d78dd090d 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -16,10 +16,10 @@ # under the License. # TODO: -# * Use _checked variations? See what R does # * More tests for edge cases, esp. with division; add test helpers here? # * Is there a better "autocasting" solution? See what rules C++ Datasets do -# * test-dplyr tests +# * test-dplyr tests (Added one addition, and one summarize, but check to see if +# we can make summarize route through arrow need more?) 
# * then, dataset tests, special casing for division test_that("Addition", { @@ -31,6 +31,12 @@ test_that("Addition", { expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) expect_vector(a + 4L, c(5:8, NA_integer_)) expect_equal(a + NA_integer_, Array$create(rep(NA_integer_, 5))) + + # overflow errors — this is slightly different from R's `NA` coercion when + # overflowing, but better than the alternative of silently restarting + casted <- a$cast(int8()) + expect_error(casted + 257) + skip("autocasting should happen in compute kernels; R workaround fails on this") expect_type_equal(a + 4.1, float64()) expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) @@ -50,8 +56,18 @@ test_that("Division", { a <- Array$create(c(1:4, NA_integer_)) expect_equal(a / 2, Array$create(c(1:4 / 2, NA_real_))) expect_equal(a %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) + expect_equal(a / 2 / 2, Array$create(c(1:4 / 2 / 2, NA_real_))) + expect_equal(a %/% 2 %/% 2, Array$create(c(0L, 0L, 0L, 1L, NA_integer_))) b <- a$cast(float64()) expect_equal(b / 2, Array$create(c(1:4 / 2, NA_real_))) expect_equal(b %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) + + # the behavior of %/% matches R's (i.e. 
the integer of the quotient, not + # simply dividing two integers) + expect_equal(b / 2.2, Array$create(c(1:4 / 2.2, NA_real_))) + # c(1:4) %/% 2.2 != c(1:4) %/% as.integer(2.2) + # c(1:4) %/% 2.2 == c(0L, 0L, 1L, 1L) + # c(1:4) %/% as.integer(2.2) == c(0L, 1L, 1L, 2L) + expect_equal(b %/% 2.2, Array$create(c(0L, 0L, 1L, 1L, NA_integer_))) }) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index a2858115622..2228dbc681c 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -494,6 +494,49 @@ test_that("filter() on date32 columns", { ) }) +test_that("filter() with expressions", { + ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) + expect_is(ds$format, "ParquetFileFormat") + expect_is(ds$filesystem, "LocalFileSystem") + expect_is(ds, "Dataset") + expect_equivalent( + ds %>% + select(chr, dbl) %>% + filter(dbl * 2 > 14 & dbl - 50 < 3L) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl")], + df2[1:2, c("chr", "dbl")] + ) + ) + + # check division's special casing. 
+ expect_equivalent( + ds %>% + select(chr, dbl) %>% + filter(dbl / 2 > 3.5 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl")], + df2[1:2, c("chr", "dbl")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl) %>% + filter(dbl / 2L > 3.5 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl")], + df2[1:2, c("chr", "dbl")] + ) + ) +}) + test_that("filter scalar validation doesn't crash (ARROW-7772)", { expect_error( ds %>% diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 2145adcc0ee..f088d833f24 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -133,6 +133,40 @@ test_that("filtering with expression", { ) }) +test_that("filtering with arithmetic", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl %/% 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + test_that("More complex select/filter", { expect_dplyr_equal( input %>% @@ -240,6 +274,14 @@ test_that("summarize", { summarize(min_int = min(int)), tbl ) + + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + summarize(min_int = min(int) / 2), + tbl + ) }) test_that("mutate", { From 0ac1858f5c3386ee9bb245467612591da94710d4 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Wed, 23 Dec 2020 15:43:03 -0600 Subject: [PATCH 27/31] expresion-based casting for division --- r/R/expression.R | 57 +++++++++++++++++++++------ r/src/compute.cpp | 23 +++++++++-- r/tests/testthat/test-compute-arith.R | 4 ++ r/tests/testthat/test-dataset.R | 12 ------ r/tests/testthat/test-dplyr.R | 4 +- 
r/tests/testthat/test-expression.R | 19 +++++++++ 6 files changed, 91 insertions(+), 28 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 817941a1caa..d5ed4c845e9 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -69,9 +69,26 @@ build_array_expression <- function(.Generic, e1, e2, ...) { } else if (.Generic == "%/%") { e1 <- array_expression("cast", e1, options = list(to_type = float64())) e2 <- array_expression("cast", e2, options = list(to_type = float64())) - return(array_expression("cast", array_expression(.binary_function_map[[.Generic]], e1, e2, ...), options = list(to_type = int32()))) + return(array_expression("cast", array_expression(.binary_function_map[[.Generic]], e1, e2, ...), options = list(to_type = int32(), allow_float_truncate = TRUE))) } else if (.Generic == "%%") { - # e1 - e2 * ( e1 %/% e2 ) + # {e1 - e2 * ( e1 %/% e2 )} + # TODO: there has to be a way to use the form ^^^ instead of this. + out <- array_expression( + "subtract_checked", e1, array_expression( + "multiply_checked", e2, array_expression( + # this outer cast is to ensure that the result of this and the + # result of multiply are the same + "cast", + array_expression( + "cast", + array_expression(.binary_function_map[[.Generic]], e1, e2, ...), + options = list(to_type = int32(), allow_float_truncate = TRUE) + ), + options = list(to_type = e2$type, allow_float_truncate = TRUE) + ) + ) + ) + return(out) } expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) } @@ -110,13 +127,13 @@ build_array_expression <- function(.Generic, e1, e2, ...) { "*" = "multiply_checked", "/" = "divide_checked", "%/%" = "divide_checked", - "%in%" = "is_in_meta_binary" + "%in%" = "is_in_meta_binary", + "%%" = "divide_checked" ) -# ‘"^"’, -# ‘"%%"’, -# ‘"%/%"’ +# ‘"^"’ + .array_function_map <- c(.unary_function_map, .binary_function_map) @@ -226,13 +243,31 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) 
{ # integer inputs and floating-point division on floats if (.Generic == "/") { # TODO: omg so many ways it's wrong to assume these types - e1 <- array_expression("cast", e1, options = list(to_type = float64())) - e2 <- array_expression("cast", e2, options = list(to_type = float64())) + e1 <- Expression$create("cast", e1, options = list(to_type = float64())) + e2 <- Expression$create("cast", e2, options = list(to_type = float64())) } else if (.Generic == "%/%") { - e1 <- array_expression("cast", e1, options = list(to_type = int32())) - e2 <- array_expression("cast", e2, options = list(to_type = int32())) + e1 <- Expression$create("cast", e1, options = list(to_type = float64())) + e2 <- Expression$create("cast", e2, options = list(to_type = float64())) + return(Expression$create("cast", Expression$create(.binary_function_map[[.Generic]], e1, e2, ...), options = list(to_type = int32(), allow_float_truncate = TRUE))) } else if (.Generic == "%%") { - # e1 - e2 * ( e1 %/% e2 ) + # {e1 - e2 * ( e1 %/% e2 )} + # TODO: there has to be a way to use the form ^^^ instead of this. + out <- Expression$create( + "subtract_checked", e1, Expression$create( + "multiply_checked", e2, Expression$create( + # this outer cast is to ensure that the result of this and the + # result of multiply are the same + "cast", + Expression$create( + "cast", + Expression$create(.binary_function_map[[.Generic]], e1, e2, ...), + options = list(to_type = int32(), allow_float_truncate = TRUE) + ), + options = list(to_type = e2$type, allow_float_truncate = TRUE) + ) + ) + ) + return(out) } expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...) 
diff --git a/r/src/compute.cpp b/r/src/compute.cpp index c44d1674153..4497f5b59a3 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -185,16 +185,31 @@ std::shared_ptr make_compute_options( cpp11::as_cpp(options["skip_nulls"])); } - // hacky attempt to pass through to_type + // hacky attempt to pass through to_type and other options if (func_name == "cast") { using Options = arrow::compute::CastOptions; - auto out = std::make_shared(false); + auto out = std::make_shared(true); SEXP to_type = options["to_type"]; if (!Rf_isNull(to_type) && cpp11::as_cpp>(to_type)) { out->to_type = cpp11::as_cpp>(to_type); - } - return out; + } + SEXP allow_float_truncate = options["allow_float_truncate"]; + if (!Rf_isNull(allow_float_truncate) && cpp11::as_cpp(allow_float_truncate)) { + out->allow_float_truncate = cpp11::as_cpp(allow_float_truncate); + } + + SEXP allow_time_truncate = options["allow_time_truncate"]; + if (!Rf_isNull(allow_time_truncate) && cpp11::as_cpp(allow_time_truncate)) { + out->allow_time_truncate = cpp11::as_cpp(allow_time_truncate); + } + + SEXP allow_int_overflow = options["allow_int_overflow"]; + if (!Rf_isNull(allow_int_overflow) && cpp11::as_cpp(allow_int_overflow)) { + out->allow_int_overflow = cpp11::as_cpp(allow_int_overflow); + } + + return out; } return nullptr; diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index 27d78dd090d..829eaae5796 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -70,4 +70,8 @@ test_that("Division", { # c(1:4) %/% 2.2 == c(0L, 0L, 1L, 1L) # c(1:4) %/% as.integer(2.2) == c(0L, 1L, 1L, 2L) expect_equal(b %/% 2.2, Array$create(c(0L, 0L, 1L, 1L, NA_integer_))) + + expect_equal(a %% 2, Array$create(c(1L, 0L, 1L, 0L, NA_integer_))) + + expect_equal(b %% 2, Array$create(c(1:4 %% 2, NA_real_))) }) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 2228dbc681c..156fa286f81 100644 --- 
a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -523,18 +523,6 @@ test_that("filter() with expressions", { df2[1:2, c("chr", "dbl")] ) ) - - expect_equivalent( - ds %>% - select(chr, dbl) %>% - filter(dbl / 2L > 3.5 & dbl < 53) %>% - collect() %>% - arrange(dbl), - rbind( - df1[8:10, c("chr", "dbl")], - df2[1:2, c("chr", "dbl")] - ) - ) }) test_that("filter scalar validation doesn't crash (ARROW-7772)", { diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index f088d833f24..87163a51410 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -163,7 +163,9 @@ test_that("filtering with arithmetic", { filter(dbl %/% 2 > 3) %>% select(string = chr, int, dbl) %>% collect(), - tbl + tbl, + # TODO: why are record batched versions problematic? + skip_record_batch = "record batches aren't (auto?) casting correctly" ) }) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 0c5ef4c12da..1251cf0e7c5 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -66,3 +66,22 @@ test_that("C++ expressions", { # Interprets that as a list type expect_is(f == c(1L, 2L), "Expression") }) + +test_that("Can create an expression", { + a <- Array$create(as.numeric(1:5)) + expr <- array_expression("cast", a, options = list(to_type = int32())) + expect_is(expr, "array_expression") + expect_equal(eval_array_expression(expr), Array$create(1:5)) + + b <- Array$create(0.5:4.5) + bad_expr <- array_expression("cast", b, options = list(to_type = int32())) + expect_is(bad_expr, "array_expression") + expect_error( + eval_array_expression(bad_expr), + "Invalid: Float value .* was truncated converting" + ) + expr <- array_expression("cast", b, options = list(to_type = int32(), allow_float_truncate = TRUE)) + expect_is(expr, "array_expression") + expect_equal(eval_array_expression(expr), Array$create(0:4)) +}) + From c0e7725aba0f94eabf6172fcfd76890e8a53dabe 
Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Wed, 30 Dec 2020 12:27:18 -0600 Subject: [PATCH 28/31] PR comments --- r/R/expression.R | 63 ++++++++++++--------------- r/tests/testthat/test-compute-arith.R | 24 ++++++---- r/tests/testthat/test-dataset.R | 61 ++++++++++++++++++++++++++ r/tests/testthat/test-dplyr.R | 50 ++++++++++++++++++--- 4 files changed, 150 insertions(+), 48 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index d5ed4c845e9..673225ea15a 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -64,15 +64,15 @@ build_array_expression <- function(.Generic, e1, e2, ...) { # integer inputs and floating-point division on floats if (.Generic == "/") { # TODO: omg so many ways it's wrong to assume these types - e1 <- array_expression("cast", e1, options = list(to_type = float64())) - e2 <- array_expression("cast", e2, options = list(to_type = float64())) + e1 <- e1$cast(float64()) + e2 <- e2$cast(float64()) } else if (.Generic == "%/%") { - e1 <- array_expression("cast", e1, options = list(to_type = float64())) - e2 <- array_expression("cast", e2, options = list(to_type = float64())) return(array_expression("cast", array_expression(.binary_function_map[[.Generic]], e1, e2, ...), options = list(to_type = int32(), allow_float_truncate = TRUE))) } else if (.Generic == "%%") { # {e1 - e2 * ( e1 %/% e2 )} # TODO: there has to be a way to use the form ^^^ instead of this. + # with return(e1 - e2 * (e1 %/% e2)) we get: + # "cannot add bindings to a locked environment" out <- array_expression( "subtract_checked", e1, array_expression( "multiply_checked", e2, array_expression( @@ -90,6 +90,13 @@ build_array_expression <- function(.Generic, e1, e2, ...) 
{ ) return(out) } + + # hack to use subtract instead of subtract_checked for timestamps + if (inherits(e1$type, "Timestamp") && inherits(e2$type, "Timestamp") && .Generic == "-"){ + # don't use the checked variant for timestamp + return(array_expression("subtract", e1, e2, ...)) + } + expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) } expr @@ -127,15 +134,13 @@ build_array_expression <- function(.Generic, e1, e2, ...) { "*" = "multiply_checked", "/" = "divide_checked", "%/%" = "divide_checked", - "%in%" = "is_in_meta_binary", - "%%" = "divide_checked" + # we don't actually use divide_checked with `%%`, rather it is rewritten to + # use %/% above. + "%%" = "divide_checked", + # TODO: "^" (ARROW-11070) + "%in%" = "is_in_meta_binary" ) - -# ‘"^"’ - - - .array_function_map <- c(.unary_function_map, .binary_function_map) eval_array_expression <- function(x) { @@ -202,7 +207,10 @@ print.array_expression <- function(x, ...) { #' @export Expression <- R6Class("Expression", inherit = ArrowObject, public = list( - ToString = function() dataset___expr__ToString(self) + ToString = function() dataset___expr__ToString(self), + cast = function(to_type, ...) { + Expression$create("cast", self, options = list(to_type = to_type, ...)) + } ) ) Expression$create <- function(function_name, @@ -243,31 +251,16 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) 
{ # integer inputs and floating-point division on floats if (.Generic == "/") { # TODO: omg so many ways it's wrong to assume these types - e1 <- Expression$create("cast", e1, options = list(to_type = float64())) - e2 <- Expression$create("cast", e2, options = list(to_type = float64())) + e1 <- e1$cast(float64()) + e2 <- e2$cast(float64()) } else if (.Generic == "%/%") { - e1 <- Expression$create("cast", e1, options = list(to_type = float64())) - e2 <- Expression$create("cast", e2, options = list(to_type = float64())) - return(Expression$create("cast", Expression$create(.binary_function_map[[.Generic]], e1, e2, ...), options = list(to_type = int32(), allow_float_truncate = TRUE))) + # In R, integer division works like floor(float division) + out <- build_dataset_expression("/", e1, e2) + return(out$cast(int32(), allow_float_truncate = TRUE)) } else if (.Generic == "%%") { - # {e1 - e2 * ( e1 %/% e2 )} - # TODO: there has to be a way to use the form ^^^ instead of this. - out <- Expression$create( - "subtract_checked", e1, Expression$create( - "multiply_checked", e2, Expression$create( - # this outer cast is to ensure that the result of this and the - # result of multiply are the same - "cast", - Expression$create( - "cast", - Expression$create(.binary_function_map[[.Generic]], e1, e2, ...), - options = list(to_type = int32(), allow_float_truncate = TRUE) - ), - options = list(to_type = e2$type, allow_float_truncate = TRUE) - ) - ) - ) - return(out) + # TODO: need to do something with types to ensure that e2 is compatible + # with e1 %/% e2 and e1. + return(e1 - e2 * ( e1 %/% e2 )) } expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...) diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index 829eaae5796..7ac5bd9016c 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -15,13 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-# TODO: -# * More tests for edge cases, esp. with division; add test helpers here? -# * Is there a better "autocasting" solution? See what rules C++ Datasets do -# * test-dplyr tests (Added one addition, and one summarize, but check to see if -# we can make summarize route through arrow need more?) -# * then, dataset tests, special casing for division - test_that("Addition", { a <- Array$create(c(1:4, NA_integer_)) expect_type_equal(a, int32()) @@ -37,7 +30,7 @@ test_that("Addition", { casted <- a$cast(int8()) expect_error(casted + 257) - skip("autocasting should happen in compute kernels; R workaround fails on this") + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") expect_type_equal(a + 4.1, float64()) expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) }) @@ -75,3 +68,18 @@ test_that("Division", { expect_equal(b %% 2, Array$create(c(1:4 %% 2, NA_real_))) }) + +test_that("Dates casting", { + a <- Array$create(c(Sys.Date() + 1:4, NA_integer_)) + + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") + expect_equal(a + 2, Array$create(c((Sys.Date() + 1:4 ) + 2), NA_integer_)) +}) + +test_that("Datetimes", { + a <- Array$create(c(Sys.time() + 1:4, NA_integer_)) + b <- Scalar$create(Sys.time()) + result <- a - b + expect_is(result$type, "DataType") + expect_identical(result$type$ToString(), "duration[us]") +}) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 156fa286f81..9fc86d32616 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -523,6 +523,67 @@ test_that("filter() with expressions", { df2[1:2, c("chr", "dbl")] ) ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %/% 2L > 3 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% 
+ filter(int %/% 2 > 3 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2L > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2 > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(dbl + int > 15 & dbl < 53L) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) }) test_that("filter scalar validation doesn't crash (ARROW-7772)", { diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 87163a51410..63459253f32 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -27,6 +27,8 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star expr <- rlang::enquo(expr) expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) + skip_msg <- NULL + if (is.null(skip_record_batch)) { via_batch <- rlang::eval_tidy( expr, @@ -34,7 +36,7 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star ) expect_equivalent(via_batch, expected, ...) } else { - skip(skip_record_batch) + skip_msg <- c(skip_msg, skip_record_batch) } if (is.null(skip_table)) { @@ -44,7 +46,11 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star ) expect_equivalent(via_table, expected, ...) 
} else { - skip(skip_table) + skip_msg <- c(skip_msg, skip_table) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) } } @@ -158,14 +164,48 @@ test_that("filtering with arithmetic", { tbl ) + expect_dplyr_equal( + input %>% + filter(int / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") expect_dplyr_equal( input %>% filter(dbl %/% 2 > 3) %>% select(string = chr, int, dbl) %>% collect(), - tbl, - # TODO: why are record batched versions problematic? - skip_record_batch = "record batches aren't (auto?) casting correctly" + tbl + ) +}) + +test_that("filtering with expression + autocasting", { + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl ) }) From 843ff2a39d8a4e1c92247fb672567c0b85b4f45a Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 30 Dec 2020 13:47:00 -0800 Subject: [PATCH 29/31] Move array/scalar autocasting to eval time --- r/R/arrowExports.R | 4 -- r/R/expression.R | 76 ++++++++++++++++----------- r/R/scalar.R | 10 +++- r/src/arrowExports.cpp | 17 ------ r/src/scalar.cpp | 6 --- r/tests/testthat/test-compute-arith.R | 3 +- r/tests/testthat/test-dataset.R | 17 +++++- r/tests/testthat/test-dplyr.R | 4 +- r/tests/testthat/test-expression.R | 3 +- 9 files changed, 73 insertions(+), 67 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 5bd9c0bcf7f..7407e7c23de 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1408,10 +1408,6 @@ 
Scalar__ToString <- function(s){ .Call(`_arrow_Scalar__ToString` , s) } -Scalar__CastTo <- function(s, t){ - .Call(`_arrow_Scalar__CastTo` , s, t) -} - StructScalar__field <- function(s, i){ .Call(`_arrow_StructScalar__field` , s, i) } diff --git a/r/R/expression.R b/r/R/expression.R index 673225ea15a..bd944a27a59 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -57,38 +57,28 @@ build_array_expression <- function(.Generic, e1, e2, ...) { if (.Generic %in% names(.unary_function_map)) { expr <- array_expression(.unary_function_map[[.Generic]], e1) } else { - e1 <- .wrap_arrow(e1, .Generic, e2$type) - e2 <- .wrap_arrow(e2, .Generic, e1$type) + e1 <- .wrap_arrow(e1, .Generic) + e2 <- .wrap_arrow(e2, .Generic) # In Arrow, "divide" is one function, which does integer division on # integer inputs and floating-point division on floats if (.Generic == "/") { # TODO: omg so many ways it's wrong to assume these types - e1 <- e1$cast(float64()) - e2 <- e2$cast(float64()) + e1 <- cast_array_expression(e1, float64()) + e2 <- cast_array_expression(e2, float64()) } else if (.Generic == "%/%") { - return(array_expression("cast", array_expression(.binary_function_map[[.Generic]], e1, e2, ...), options = list(to_type = int32(), allow_float_truncate = TRUE))) + # In R, integer division works like floor(float division) + out <- build_array_expression("/", e1, e2) + return(cast_array_expression(out, int32(), allow_float_truncate = TRUE)) } else if (.Generic == "%%") { # {e1 - e2 * ( e1 %/% e2 )} - # TODO: there has to be a way to use the form ^^^ instead of this. 
- # with return(e1 - e2 * (e1 %/% e2)) we get: - # "cannot add bindings to a locked environment" - out <- array_expression( - "subtract_checked", e1, array_expression( - "multiply_checked", e2, array_expression( - # this outer cast is to ensure that the result of this and the - # result of multiply are the same - "cast", - array_expression( - "cast", - array_expression(.binary_function_map[[.Generic]], e1, e2, ...), - options = list(to_type = int32(), allow_float_truncate = TRUE) - ), - options = list(to_type = e2$type, allow_float_truncate = TRUE) - ) - ) - ) - return(out) + # ^^^ form doesn't work because Ops.Array evaluates eagerly, + # but we can build that up + quotient <- build_array_expression("%/%", e1, e2) + # this cast is to ensure that the result of this and e1 are the same + # (autocasting only applies to scalars) + base <- cast_array_expression(quotient * e2, e1$type) + return(build_array_expression("-", e1, base)) } # hack to use subtract instead of subtract_checked for timestamps @@ -102,14 +92,24 @@ build_array_expression <- function(.Generic, e1, e2, ...) { expr } -.wrap_arrow <- function(arg, fun, type) { +cast_array_expression <- function(x, to_type, safe = TRUE, ...) { + opts <- list( + to_type = to_type, + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + array_expression("cast", x, options = modifyList(opts, list(...))) +} + +.wrap_arrow <- function(arg, fun) { if (!inherits(arg, c("ArrowObject", "array_expression"))) { # TODO: Array$create if lengths are equal? # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. 
float) if (fun == "%in%") { - arg <- Array$create(arg, type = type) + arg <- Array$create(arg) } else { - arg <- Scalar$create(arg, type = type) + arg <- Scalar$create(arg) } } arg @@ -151,6 +151,16 @@ eval_array_expression <- function(x) { a } }) + if (length(x$args) == 2L) { + # Insert implicit casts + if (inherits(x$args[[1]], "Scalar")) { + x$args[[1]] <- x$args[[1]]$cast(x$args[[2]]$type) + } else if (inherits(x$args[[2]], "Scalar")) { + x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) + } else if (x$fun == "is_in_meta_binary" && inherits(x$args[[2]], "Array")) { + x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) + } + } call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } @@ -208,8 +218,14 @@ print.array_expression <- function(x, ...) { Expression <- R6Class("Expression", inherit = ArrowObject, public = list( ToString = function() dataset___expr__ToString(self), - cast = function(to_type, ...) { - Expression$create("cast", self, options = list(to_type = to_type, ...)) + cast = function(to_type, safe = TRUE, ...) { + opts <- list( + to_type = to_type, + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + Expression$create("cast", self, options = modifyList(opts, list(...))) } ) ) @@ -258,8 +274,6 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) { out <- build_dataset_expression("/", e1, e2) return(out$cast(int32(), allow_float_truncate = TRUE)) } else if (.Generic == "%%") { - # TODO: need to do something with types to ensure that e2 is compatible - # with e1 %/% e2 and e1. 
return(e1 - e2 * ( e1 %/% e2 )) } diff --git a/r/R/scalar.R b/r/R/scalar.R index 12f29990e0a..774fe571145 100644 --- a/r/R/scalar.R +++ b/r/R/scalar.R @@ -32,8 +32,14 @@ Scalar <- R6Class("Scalar", # TODO: document the methods public = list( ToString = function() Scalar__ToString(self), - cast = function(target_type) { - Scalar__CastTo(self, as_type(target_type)) + cast = function(target_type, safe = TRUE, ...) { + opts <- list( + to_type = as_type(target_type), + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + call_function("cast", self, options = modifyList(opts, list(...))) }, as_vector = function() Scalar__as_vector(self) ), diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 26bd57e1e28..0a73b8681c4 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -5528,22 +5528,6 @@ extern "C" SEXP _arrow_Scalar__ToString(SEXP s_sexp){ } #endif -// scalar.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, const std::shared_ptr& t); -extern "C" SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type s(s_sexp); - arrow::r::Input&>::type t(t_sexp); - return cpp11::as_sexp(Scalar__CastTo(s, t)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ - Rf_error("Cannot call Scalar__CastTo(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - // scalar.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr StructScalar__field(const std::shared_ptr& s, int i); @@ -6601,7 +6585,6 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 4}, { "_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, { "_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, - { "_arrow_Scalar__CastTo", (DL_FUNC) &_arrow_Scalar__CastTo, 2}, { "_arrow_StructScalar__field", (DL_FUNC) &_arrow_StructScalar__field, 2}, { "_arrow_StructScalar__GetFieldByName", (DL_FUNC) &_arrow_StructScalar__GetFieldByName, 2}, { "_arrow_Scalar__as_vector", (DL_FUNC) &_arrow_Scalar__as_vector, 1}, diff --git a/r/src/scalar.cpp b/r/src/scalar.cpp index 2c2d291b5bf..c0cc396b02d 100644 --- a/r/src/scalar.cpp +++ b/r/src/scalar.cpp @@ -47,12 +47,6 @@ std::string Scalar__ToString(const std::shared_ptr& s) { return s->ToString(); } -// [[arrow::export]] -std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, - const std::shared_ptr& t) { - return ValueOrStop(s->CastTo(t)); -} - // [[arrow::export]] std::shared_ptr StructScalar__field( const std::shared_ptr& s, int i) { diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index 7ac5bd9016c..393fdf45455 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -28,7 +28,8 @@ test_that("Addition", { # overflow errors — this is slightly different from R's `NA` coercion when # overflowing, but better than the alternative of silently restarting casted <- a$cast(int8()) - expect_error(casted + 257) + expect_error(casted + 127) + expect_error(casted + 200) skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") expect_type_equal(a + 4.1, float64()) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 9fc86d32616..c9929efd414 100644 --- 
a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -560,7 +560,20 @@ test_that("filter() with expressions", { ) ) - skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2L > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + skip("Implicit casts aren't being inserted everywhere they need to be") + # Error: NotImplemented: Function multiply_checked has no kernel matching input types (scalar[double], array[int32]) expect_equivalent( ds %>% select(chr, dbl, int) %>% @@ -573,6 +586,8 @@ test_that("filter() with expressions", { ) ) + skip("Implicit casts are only inserted for scalars") + # Error: NotImplemented: Function add_checked has no kernel matching input types (array[double], array[int32]) expect_equivalent( ds %>% select(chr, dbl, int) %>% diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 63459253f32..19c0665c807 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -180,7 +180,6 @@ test_that("filtering with arithmetic", { tbl ) - skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") expect_dplyr_equal( input %>% filter(dbl %/% 2 > 3) %>% @@ -191,7 +190,6 @@ test_that("filtering with arithmetic", { }) test_that("filtering with expression + autocasting", { - skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") expect_dplyr_equal( input %>% filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L @@ -243,7 +241,7 @@ test_that("Print method", { int: int32 chr: string -* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5L)) +* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5)) See $.data for the source Arrow object', fixed = TRUE ) diff --git 
a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 1251cf0e7c5..3c100812ff1 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -29,7 +29,7 @@ test_that("array_expression print method", { expect_output( print(build_array_expression(">", Array$create(1:5), 4)), # Not ideal but it is informative - "greater(, 4L)", + "greater(, 4)", fixed = TRUE ) }) @@ -84,4 +84,3 @@ test_that("Can create an expression", { expect_is(expr, "array_expression") expect_equal(eval_array_expression(expr), Array$create(0:4)) }) - From 391061ec24a3ecdc8c5c65a11b0a80478ad7ab60 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 30 Dec 2020 14:16:42 -0800 Subject: [PATCH 30/31] Add jira to skips --- r/tests/testthat/test-dataset.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index c9929efd414..62b437e439c 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -572,7 +572,7 @@ test_that("filter() with expressions", { ) ) - skip("Implicit casts aren't being inserted everywhere they need to be") + skip("Implicit casts aren't being inserted everywhere they need to be (ARROW-11080)") # Error: NotImplemented: Function multiply_checked has no kernel matching input types (scalar[double], array[int32]) expect_equivalent( ds %>% @@ -586,7 +586,7 @@ test_that("filter() with expressions", { ) ) - skip("Implicit casts are only inserted for scalars") + skip("Implicit casts are only inserted for scalars (ARROW-11080)") # Error: NotImplemented: Function add_checked has no kernel matching input types (array[double], array[int32]) expect_equivalent( ds %>% From e3d3a6a94042af346359d9120346089fe3275379 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Thu, 31 Dec 2020 10:00:57 -0600 Subject: [PATCH 31/31] Remove datetime hackery --- r/R/expression.R | 6 ------ r/tests/testthat/test-compute-arith.R | 8 -------- 2 
files changed, 14 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index bd944a27a59..9a5e575183d 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -81,12 +81,6 @@ build_array_expression <- function(.Generic, e1, e2, ...) { return(build_array_expression("-", e1, base)) } - # hack to use subtract instead of subtract_checked for timestamps - if (inherits(e1$type, "Timestamp") && inherits(e2$type, "Timestamp") && .Generic == "-"){ - # don't use the checked variant for timestamp - return(array_expression("subtract", e1, e2, ...)) - } - expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) } expr diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index 393fdf45455..ffde12c4d9b 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -76,11 +76,3 @@ test_that("Dates casting", { skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") expect_equal(a + 2, Array$create(c((Sys.Date() + 1:4 ) + 2), NA_integer_)) }) - -test_that("Datetimes", { - a <- Array$create(c(Sys.time() + 1:4, NA_integer_)) - b <- Scalar$create(Sys.time()) - result <- a - b - expect_is(result$type, "DataType") - expect_identical(result$type$ToString(), "duration[us]") -})