From f8e258b33e6cf436c6dac61a647396d18d7fe945 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 26 Oct 2020 12:46:07 -0400 Subject: [PATCH 01/38] ARROW-10322: [C++][Dataset] Minimize dataset::Expression --- cpp/src/arrow/array/array_struct_test.cc | 36 + cpp/src/arrow/array/util.cc | 17 +- cpp/src/arrow/chunked_array.h | 3 + cpp/src/arrow/compute/cast.cc | 80 +- cpp/src/arrow/compute/cast.h | 14 +- cpp/src/arrow/compute/exec.h | 2 +- cpp/src/arrow/compute/exec_internal.h | 4 +- cpp/src/arrow/compute/function.h | 6 +- .../arrow/compute/kernels/scalar_cast_test.cc | 43 + cpp/src/arrow/dataset/CMakeLists.txt | 2 + cpp/src/arrow/dataset/expression.cc | 931 ++++++++++++++++++ cpp/src/arrow/dataset/expression.h | 217 ++++ cpp/src/arrow/dataset/expression_internal.h | 110 +++ cpp/src/arrow/dataset/expression_test.cc | 729 ++++++++++++++ cpp/src/arrow/dataset/partition.cc | 11 - cpp/src/arrow/dataset/partition_test.cc | 4 - cpp/src/arrow/dataset/test_util.h | 1 + cpp/src/arrow/dataset/type_fwd.h | 3 + cpp/src/arrow/datum.cc | 15 + cpp/src/arrow/datum.h | 32 +- cpp/src/arrow/result.h | 22 + cpp/src/arrow/type.cc | 33 +- cpp/src/arrow/type.h | 25 +- cpp/src/arrow/type_fwd.h | 2 + 24 files changed, 2296 insertions(+), 46 deletions(-) create mode 100644 cpp/src/arrow/dataset/expression.cc create mode 100644 cpp/src/arrow/dataset/expression.h create mode 100644 cpp/src/arrow/dataset/expression_internal.h create mode 100644 cpp/src/arrow/dataset/expression_test.cc diff --git a/cpp/src/arrow/array/array_struct_test.cc b/cpp/src/arrow/array/array_struct_test.cc index f54b43465e9..aef0076d0d3 100644 --- a/cpp/src/arrow/array/array_struct_test.cc +++ b/cpp/src/arrow/array/array_struct_test.cc @@ -24,6 +24,7 @@ #include "arrow/array.h" #include "arrow/array/builder_nested.h" +#include "arrow/chunked_array.h" #include "arrow/status.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" @@ -582,4 +583,39 @@ TEST_F(TestStructBuilder, TestSlice) { ASSERT_EQ(list_field->null_count(), 1); } +TEST(TestFieldRef, GetChildren) { + auto struct_array = ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])"); + + ASSERT_OK_AND_ASSIGN(auto a, FieldRef("a").GetOne(*struct_array)); + auto expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); + AssertArraysEqual(*a, *expected_a); + + auto ToChunked = [struct_array](int64_t midpoint) { + return ChunkedArray( + ArrayVector{ + struct_array->Slice(0, midpoint), + struct_array->Slice(midpoint), + }, + struct_array->type()); + }; + AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); + + // more nested: + struct_array = + ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ + {"a": {"a": 6.125}}, + {"a": {"a": 0.0}}, + {"a": {"a": -1}} + ])"); + + ASSERT_OK_AND_ASSIGN(a, FieldRef("a", "a").GetOne(*struct_array)); + expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); + AssertArraysEqual(*a, *expected_a); + AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index 70167954e5a..0d498931d42 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -297,8 +297,9 @@ class RepeatedArrayFactory { return out_; } - Status Visit(const DataType& type) { - return Status::NotImplemented("construction from scalar of type ", *scalar_.type); + Status Visit(const NullType& type) { + DCHECK(false); // already forwarded to MakeArrayOfNull + return Status::OK(); } Status 
Visit(const BooleanType&) {
@@ -403,6 +404,18 @@ class RepeatedArrayFactory {
     return Status::OK();
   }
 
+  Status Visit(const ExtensionType& type) {
+    return Status::NotImplemented("construction from scalar of type ", *scalar_.type);
+  }
+
+  Status Visit(const DenseUnionType& type) {
+    return Status::NotImplemented("construction from scalar of type ", *scalar_.type);
+  }
+
+  Status Visit(const SparseUnionType& type) {
+    return Status::NotImplemented("construction from scalar of type ", *scalar_.type);
+  }
+
   template <typename OffsetType>
   Status CreateOffsetsBuffer(OffsetType value_length, std::shared_ptr<Buffer>* out) {
     TypedBufferBuilder<OffsetType> builder(pool_);
diff --git a/cpp/src/arrow/chunked_array.h b/cpp/src/arrow/chunked_array.h
index b1f8d6cd6f5..5c0dda91850 100644
--- a/cpp/src/arrow/chunked_array.h
+++ b/cpp/src/arrow/chunked_array.h
@@ -72,6 +72,9 @@ class ARROW_EXPORT ChunkedArray {
   /// data type.
   explicit ChunkedArray(ArrayVector chunks);
 
+  ChunkedArray(ChunkedArray&&) = default;
+  ChunkedArray& operator=(ChunkedArray&&) = default;
+
   /// \brief Construct a chunked array from a single Array
   explicit ChunkedArray(std::shared_ptr<Array> chunk)
       : ChunkedArray(ArrayVector{std::move(chunk)}) {}
diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
index 29a80f73241..b55dfb38c15 100644
--- a/cpp/src/arrow/compute/cast.cc
+++ b/cpp/src/arrow/compute/cast.cc
@@ -118,8 +118,86 @@ class CastMetaFunction : public MetaFunction {
 
 }  // namespace
 
+const FunctionDoc struct_doc{"Wrap Arrays into a StructArray",
+                             ("Names of the StructArray's fields are\n"
+                              "specified through StructOptions."),
+                             {},
+                             "StructOptions"};
+
+Result<ValueDescr> StructResolve(KernelContext* ctx,
+                                 const std::vector<ValueDescr>& descrs) {
+  const auto& names = OptionsWrapper<StructOptions>::Get(ctx).field_names;
+  if (names.size() != descrs.size()) {
+    return Status::Invalid("Struct() was passed ", names.size(), " field names but ",
+                           descrs.size(), " arguments");
+  }
+
+  size_t i = 0;
+  FieldVector fields(descrs.size());
+
+  ValueDescr::Shape shape = ValueDescr::SCALAR;
+  for (const ValueDescr& descr : descrs) {
+    if (descr.shape != ValueDescr::SCALAR) {
+      shape = ValueDescr::ARRAY;
+    } else {
+      switch (descr.type->id()) {
+        case Type::EXTENSION:
+        case Type::DENSE_UNION:
+        case Type::SPARSE_UNION:
+          return Status::NotImplemented("Broadcasting scalars of type ", *descr.type);
+        default:
+          break;
+      }
+    }
+
+    fields[i] = field(names[i], descr.type);
+    ++i;
+  }
+
+  return ValueDescr{struct_(std::move(fields)), shape};
+}
+
+void StructExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, StructResolve(ctx, batch.GetDescriptors()));
+
+  if (descr.shape == ValueDescr::SCALAR) {
+    ScalarVector scalars(batch.num_values());
+    for (int i = 0; i < batch.num_values(); ++i) {
+      scalars[i] = batch[i].scalar();
+    }
+
+    *out =
+        Datum(std::make_shared<StructScalar>(std::move(scalars), std::move(descr.type)));
+    return;
+  }
+
+  ArrayVector arrays(batch.num_values());
+  for (int i = 0; i < batch.num_values(); ++i) {
+    if (batch[i].is_array()) {
+      arrays[i] = batch[i].make_array();
+      continue;
+    }
+
+    KERNEL_ASSIGN_OR_RAISE(
+        arrays[i], ctx,
+        MakeArrayFromScalar(*batch[i].scalar(), batch.length, ctx->memory_pool()));
+  }
+
+  *out = std::make_shared<StructArray>(descr.type, batch.length, std::move(arrays));
+}
+
 void RegisterScalarCast(FunctionRegistry* registry) {
   DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>()));
+
+  auto struct_function =
+      std::make_shared<ScalarFunction>("struct", Arity::VarArgs(), &struct_doc);
+  ScalarKernel kernel{KernelSignature::Make({InputType{}},
+                                            OutputType{StructResolve},
+                                            /*is_varargs=*/true),
+                      StructExec, OptionsWrapper<StructOptions>::Init};
+  kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
+  kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+  DCHECK_OK(struct_function->AddKernel(std::move(kernel)));
+  DCHECK_OK(registry->AddFunction(std::move(struct_function)));
 }
 
 }  // namespace internal
@@ -135,7 +213,7 @@ CastFunction::CastFunction(std::string name, Type::type out_type)
   impl_->out_type = out_type;
 }
 
-CastFunction::~CastFunction() {}
+CastFunction::~CastFunction() = default;
 
 Type::type CastFunction::out_type_id() const { return impl_->out_type; }
 
diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h
index 43392ce99bf..6808d1d86f3 100644
--- a/cpp/src/arrow/compute/cast.h
+++ b/cpp/src/arrow/compute/cast.h
@@ -83,7 +83,7 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions {
 class CastFunction : public ScalarFunction {
  public:
   CastFunction(std::string name, Type::type out_type);
-  ~CastFunction();
+  ~CastFunction() override;
 
   Type::type out_type_id() const;
 
@@ -157,5 +157,17 @@ Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
                    const CastOptions& options = CastOptions::Safe(),
                    ExecContext* ctx = NULLPTR);
 
+/// \addtogroup compute-concrete-options
+/// @{
+
+struct ARROW_EXPORT StructOptions : public FunctionOptions {
+  explicit StructOptions(std::vector<std::string> n) : field_names(std::move(n)) {}
+
+  /// Names for wrapped columns
+  std::vector<std::string> field_names;
+};
+
+/// @}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
index 142e149bccc..f491489ed8a 100644
--- a/cpp/src/arrow/compute/exec.h
+++ b/cpp/src/arrow/compute/exec.h
@@ -165,7 +165,7 @@ class ARROW_EXPORT SelectionVector {
 /// than is desirable for this class. Microbenchmarks would help determine for
 /// sure. See ARROW-8928.
 struct ExecBatch {
-  ExecBatch() {}
+  ExecBatch() = default;
 
   ExecBatch(std::vector<Datum> values, int64_t length)
       : values(std::move(values)), length(length) {}
diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h
index 8bad135e40d..a74e5c8d8fa 100644
--- a/cpp/src/arrow/compute/exec_internal.h
+++ b/cpp/src/arrow/compute/exec_internal.h
@@ -84,14 +84,14 @@ class ARROW_EXPORT ExecListener {
 
 class DatumAccumulator : public ExecListener {
  public:
-  DatumAccumulator() {}
+  DatumAccumulator() = default;
 
   Status OnResult(Datum value) override {
     values_.emplace_back(value);
     return Status::OK();
   }
 
-  std::vector<Datum> values() const { return values_; }
+  std::vector<Datum> values() { return std::move(values_); }
 
  private:
   std::vector<Datum> values_;
diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h
index a71dbe40292..e8e732027c9 100644
--- a/cpp/src/arrow/compute/function.h
+++ b/cpp/src/arrow/compute/function.h
@@ -41,7 +41,9 @@ namespace compute {
 
 /// \brief Base class for specifying options configuring a function's behavior,
 /// such as error handling.
-struct ARROW_EXPORT FunctionOptions {};
+struct ARROW_EXPORT FunctionOptions {
+  virtual ~FunctionOptions() = default;
+};
 
 /// \brief Contains the number of required arguments for the function.
 ///
@@ -96,7 +98,7 @@ struct ARROW_EXPORT FunctionDoc {
   /// \brief Name of the options class, if any.
   std::string options_class;
 
-  FunctionDoc() {}
+  FunctionDoc() = default;
 
   FunctionDoc(std::string summary, std::string description,
               std::vector<std::string> arg_names, std::string options_class = "")
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
index ec612e497f4..566a89f4b47 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -1892,5 +1892,48 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) {
   ASSERT_OK(UnregisterExtensionType("smallint"));
 }
 
+class TestStruct : public TestBase {
+ public:
+  Result<Datum> Struct(std::vector<Datum> args) {
+    StructOptions opts{field_names};
+    return CallFunction("struct", args, &opts);
+  }
+
+  std::vector<std::string> field_names;
+};
+
+TEST_F(TestStruct, Scalar) {
+  std::shared_ptr<Scalar> expected(new StructScalar{{}, struct_({})});
+  ASSERT_OK_AND_EQ(Datum(expected), Struct({}));
+
+  auto i32 = MakeScalar(1);
+  auto f64 = MakeScalar(2.5);
+  auto str = MakeScalar("yo");
+
+  expected.reset(new StructScalar{
+      {i32, f64, str},
+      struct_({field("i", i32->type), field("f", f64->type), field("s", str->type)})});
+  field_names = {"i", "f", "s"};
+  ASSERT_OK_AND_EQ(Datum(expected), Struct({i32, f64, str}));
+
+  // Three field names but one input value
+  ASSERT_RAISES(Invalid, Struct({str}));
+}
+
+TEST_F(TestStruct, Array) {
+  field_names = {"i", "s"};
+  auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]");
+  auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
+  ASSERT_OK_AND_ASSIGN(Datum expected, StructArray::Make({i32, str}, field_names));
+
+  ASSERT_OK_AND_EQ(expected, Struct({i32, str}));
+
+  // Scalars are broadcast to the length of the arrays
+  ASSERT_OK_AND_EQ(expected, Struct({i32, MakeScalar("aa")}));
+
+  // Array length mismatch
+  ASSERT_RAISES(Invalid, Struct({i32->Slice(1), str}));
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt
index b693c48049c..ae6bb3656dc 100644
--- a/cpp/src/arrow/dataset/CMakeLists.txt
+++ b/cpp/src/arrow/dataset/CMakeLists.txt
@@ -22,6 +22,7 @@ arrow_install_all_headers("arrow/dataset")
 set(ARROW_DATASET_SRCS
     dataset.cc
     discovery.cc
+    expression.cc
    file_base.cc
     file_ipc.cc
     filter.cc
@@ -106,6 +107,7 @@ endfunction()
 
 add_arrow_dataset_test(dataset_test)
 add_arrow_dataset_test(discovery_test)
+add_arrow_dataset_test(expression_test)
 add_arrow_dataset_test(file_ipc_test)
 add_arrow_dataset_test(file_test)
 add_arrow_dataset_test(filter_test)
diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc
new file mode 100644
index 00000000000..e0a44a34ab8
--- /dev/null
+++ b/cpp/src/arrow/dataset/expression.cc
@@ -0,0 +1,931 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/expression.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "arrow/chunked_array.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/dataset/expression_internal.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/string.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+const Expression2::Call* Expression2::call() const {
+  return util::get_if<Call>(impl_.get());
+}
+
+const Datum* Expression2::literal() const { return util::get_if<Datum>(impl_.get()); }
+
+const FieldRef* Expression2::field_ref() const {
+  return util::get_if<FieldRef>(impl_.get());
+}
+
+std::string Expression2::ToString() const {
+  if (auto lit = literal()) {
+    if (lit->is_scalar()) {
+      return lit->scalar()->ToString();
+    }
+    return lit->ToString();
+  }
+
+  if (auto ref = field_ref()) {
+    if (auto name = ref->name()) {
+      return "FieldRef(" + *name + ")";
+    }
+    if (auto path = ref->field_path()) {
+      return path->ToString();
+    }
+    return ref->ToString();
+  }
+
+  auto call = CallNotNull(*this);
+
+  // FIXME represent FunctionOptions
+  std::string out = call->function + "(";
+  for (const auto& arg : call->arguments) {
+    out += arg.ToString() + ",";
+  }
+  out.back() = ')';
+  return out;
+}
+
+void PrintTo(const Expression2& expr, std::ostream* os) { *os << expr.ToString(); }
+
+bool Expression2::Equals(const Expression2& other) const {
+  if (Identical(*this, other)) return true;
+
+  if (impl_->index() != other.impl_->index()) {
+    return false;
+  }
+
+  if (auto lit = literal()) {
+    return lit->Equals(*other.literal());
+  }
+
+  if (auto ref = field_ref()) {
+    return ref->Equals(*other.field_ref());
+  }
+
+  auto call = CallNotNull(*this);
+  auto other_call = CallNotNull(other);
+
+  if (call->function != other_call->function || call->kernel != other_call->kernel) {
+    return false;
+  }
+
+  // FIXME compare FunctionOptions for equality
+  for (size_t i = 0; i < call->arguments.size(); ++i) {
+    if (!call->arguments[i].Equals(other_call->arguments[i])) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+size_t Expression2::hash() const {
+  if (auto lit = literal()) {
+    if (lit->is_scalar()) {
+      return Scalar::Hash::hash(*lit->scalar());
+    }
+    return 0;
+  }
+
+  if (auto ref = field_ref()) {
+    return ref->hash();
+  }
+
+  auto call = CallNotNull(*this);
+
+  size_t out = std::hash<std::string>{}(call->function);
+  for (const auto& arg : call->arguments) {
+    out ^= arg.hash();
+  }
+  return out;
+}
+
+bool Expression2::IsBound() const {
+  if (descr_.type == nullptr) return false;
+
+  if (auto lit = literal()) return true;
+
+  if (auto ref = field_ref()) return true;
+
+  auto call = CallNotNull(*this);
+
+  for (const Expression2& arg : call->arguments) {
+    if (!arg.IsBound()) return false;
+  }
+
+  return call->kernel != nullptr;
+}
+
+bool Expression2::IsScalarExpression() const {
+  if (auto lit = literal()) {
+    return lit->is_scalar();
+  }
+
+  // FIXME handle case where a list's item field is referenced
+  if (auto ref = field_ref()) return true;
+
+  auto call = CallNotNull(*this);
+
+  for (const Expression2& arg : call->arguments) {
+    if (!arg.IsScalarExpression()) return false;
+  }
+
+  if (call->kernel == nullptr) {
+    // this expression is not bound; make a best guess based on
+    // the default function registry
+    if (auto function = compute::GetFunctionRegistry()
+                            ->GetFunction(call->function)
+                            .ValueOr(nullptr)) {
+      return function->kind() == compute::Function::SCALAR;
+    }
+
+    // unknown function or other error; conservatively return false
+    return false;
+  }
+
+  return call->function_kind == compute::Function::SCALAR;
+}
+
+bool Expression2::IsNullLiteral() const {
+  if (auto lit = literal()) {
+    if (lit->null_count() == lit->length()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool Expression2::IsSatisfiable() const {
+  if (auto lit = literal()) {
+    if (lit->null_count() == lit->length()) {
+      return false;
+    }
+    if (lit->is_scalar() && lit->type()->id() == Type::BOOL) {
+      return lit->scalar_as<BooleanScalar>().value;
+    }
+  }
+  return true;
+}
+
+bool KernelStateIsImmutable(const std::string& function) {
+  // XXX maybe just add Kernel::state_is_immutable or so?
+
+  // known functions with non-null but nevertheless immutable KernelState
+  static std::unordered_set<std::string> immutable_state = {
+      "is_in", "index_in", "cast", "struct", "strptime",
+  };
+
+  return immutable_state.find(function) != immutable_state.end();
+}
+
+Result<std::unique_ptr<compute::KernelState>> InitKernelState(
+    const Expression2::Call& call, compute::ExecContext* exec_context) {
+  if (!call.kernel->init) return nullptr;
+
+  compute::KernelContext kernel_context(exec_context);
+  auto kernel_state = call.kernel->init(
+      &kernel_context, {call.kernel, GetDescriptors(call.arguments), call.options.get()});
+
+  RETURN_NOT_OK(kernel_context.status());
+  return std::move(kernel_state);
+}
+
+Result<std::shared_ptr<ExpressionState>> ExpressionState::Clone(
+    const std::shared_ptr<ExpressionState>& state, const Expression2& expr,
+    compute::ExecContext* exec_context) {
+  if (!expr.IsBound()) {
+    return Status::Invalid("Cannot clone State against an unbound expression.");
+  }
+
+  if (state == nullptr) return nullptr;
+
+  auto call = CallNotNull(expr);
+  auto call_state = checked_cast<CallState*>(state.get());
+
+  CallState clone;
+  clone.argument_states.resize(call_state->argument_states.size());
+
+  bool recursively_share = true;
+  for (size_t i = 0; i < call->arguments.size(); ++i) {
+    ARROW_ASSIGN_OR_RAISE(
+        clone.argument_states[i],
+        Clone(call_state->argument_states[i], call->arguments[i], exec_context));
+
+    if (clone.argument_states[i] != call_state->argument_states[i]) {
+      recursively_share = false;
+    }
+  }
+
+  if (call_state->kernel_state == nullptr || KernelStateIsImmutable(call->function)) {
+    // The kernel's state is immutable so it's safe to just
+    // share a pointer between threads
+    if (recursively_share) {
+      return state;
+    }
+    clone.kernel_state = call_state->kernel_state;
+  } else {
+    // The kernel's state must be re-initialized.
+    ARROW_ASSIGN_OR_RAISE(clone.kernel_state, InitKernelState(*call, exec_context));
+  }
+
+  return std::make_shared<CallState>(std::move(clone));
+}
+
+Result<Expression2::StateAndBound> Expression2::Bind(
+    ValueDescr in, compute::ExecContext* exec_context) const {
+  if (exec_context == nullptr) {
+    compute::ExecContext exec_context;
+    return Bind(std::move(in), &exec_context);
+  }
+
+  if (literal()) return StateAndBound{*this, nullptr};
+
+  if (auto ref = field_ref()) {
+    ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type));
+    auto bound = *this;
+    bound.descr_ =
+        field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null());
+    return StateAndBound{std::move(bound), nullptr};
+  }
+
+  auto bound_call = *CallNotNull(*this);
+
+  std::shared_ptr<compute::Function> function;
+  if (bound_call.function == "cast") {
+    // XXX this special case is strange; why not make "cast" a ScalarFunction?
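+    // ("cast" is registered as a MetaFunction (CastMetaFunction in
+    // compute/cast.cc), so it carries no kernels of its own; GetCastFunction
+    // retrieves the concrete CastFunction for the target type instead.)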
+    const auto& to_type =
+        checked_cast<const compute::CastOptions&>(*bound_call.options).to_type;
+    ARROW_ASSIGN_OR_RAISE(function, compute::GetCastFunction(to_type));
+  } else {
+    ARROW_ASSIGN_OR_RAISE(
+        function, exec_context->func_registry()->GetFunction(bound_call.function));
+  }
+  bound_call.function_kind = function->kind();
+
+  std::vector<std::shared_ptr<ExpressionState>> argument_states(
+      bound_call.arguments.size());
+  for (size_t i = 0; i < argument_states.size(); ++i) {
+    ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_states[i]),
+                          bound_call.arguments[i].Bind(in, exec_context));
+  }
+
+  auto descrs = GetDescriptors(bound_call.arguments);
+  ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs));
+
+  auto call_state = std::make_shared<CallState>();
+  ARROW_ASSIGN_OR_RAISE(call_state->kernel_state,
+                        InitKernelState(bound_call, exec_context));
+  call_state->argument_states = std::move(argument_states);
+
+  compute::KernelContext kernel_context(exec_context);
+  kernel_context.SetState(call_state->kernel_state.get());
+
+  auto bound = Expression2(std::make_shared<Impl>(std::move(bound_call)));
+  ARROW_ASSIGN_OR_RAISE(bound.descr_, bound_call.kernel->signature->out_type().Resolve(
+                                          &kernel_context, descrs));
+  return StateAndBound{std::move(bound), std::move(call_state)};
+}
+
+Result<Expression2::StateAndBound> Expression2::Bind(
+    const Schema& in_schema, compute::ExecContext* exec_context) const {
+  return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context);
+}
+
+Result<Datum> ExecuteScalarExpression(const Expression2& expr, ExpressionState* state,
+                                      const Datum& input,
+                                      compute::ExecContext* exec_context) {
+  if (exec_context == nullptr) {
+    compute::ExecContext exec_context;
+    return ExecuteScalarExpression(expr, state, input, &exec_context);
+  }
+
+  if (!expr.IsBound()) {
+    return Status::Invalid("Cannot Execute unbound expression.");
+  }
+
+  if (!expr.IsScalarExpression()) {
+    return Status::Invalid(
+        "ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString());
+  }
+
+  if (auto lit = expr.literal()) return *lit;
+
+  if (auto ref = expr.field_ref()) {
+    ARROW_ASSIGN_OR_RAISE(Datum field, GetDatumField(*ref, input));
+
+    if (field.descr() != expr.descr()) {
+      // Referenced field was present but didn't have the expected type.
+      // Should we just error here? For now, pay dispatch cost and just cast.
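+      // (A mismatch can arise when the Datum being executed against does not
+      // match the ValueDescr this expression was bound against.)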
+      ARROW_ASSIGN_OR_RAISE(
+          field, compute::Cast(field, expr.descr().type, compute::CastOptions::Safe(),
+                               exec_context));
+    }
+
+    return field;
+  }
+
+  auto call = CallNotNull(expr);
+  auto call_state = checked_cast<CallState*>(state);
+
+  std::vector<Datum> arguments(call->arguments.size());
+  for (size_t i = 0; i < arguments.size(); ++i) {
+    auto argument_state = call_state->argument_states[i].get();
+    ARROW_ASSIGN_OR_RAISE(
+        arguments[i],
+        ExecuteScalarExpression(call->arguments[i], argument_state, input, exec_context));
+  }
+
+  auto executor = compute::detail::KernelExecutor::MakeScalar();
+
+  compute::KernelContext kernel_context(exec_context);
+  kernel_context.SetState(call_state->kernel_state.get());
+
+  auto kernel = call->kernel;
+  auto inputs = GetDescriptors(call->arguments);
+  auto options = call->options.get();
+  RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, inputs, options}));
+
+  auto listener = std::make_shared<compute::detail::DatumAccumulator>();
+  RETURN_NOT_OK(executor->Execute(arguments, listener.get()));
+  return executor->WrapResults(arguments, listener->values());
+}
+
+std::array<std::pair<Expression2, Expression2>, 2> ArgumentsAndFlippedArguments(
+    const Expression2::Call& call) {
+  DCHECK_EQ(call.arguments.size(), 2);
+  return {std::pair<Expression2, Expression2>{call.arguments[0], call.arguments[1]},
+          std::pair<Expression2, Expression2>{call.arguments[1], call.arguments[0]}};
+}
+
+util::optional<compute::NullHandling::type> GetNullHandling(
+    const Expression2::Call& call) {
+  if (call.function_kind == compute::Function::SCALAR) {
+    return static_cast<const compute::ScalarKernel*>(call.kernel)->null_handling;
+  }
+  return util::nullopt;
+}
+
+bool DefinitelyNotNull(const Expression2& expr) {
+  DCHECK(expr.IsBound());
+
+  if (expr.literal()) {
+    return !expr.IsNullLiteral();
+  }
+
+  if (expr.field_ref()) return false;
+
+  auto call = CallNotNull(expr);
+  if (auto null_handling = GetNullHandling(*call)) {
+    if (null_handling == compute::NullHandling::OUTPUT_NOT_NULL) {
+      return true;
+    }
+    if (null_handling == compute::NullHandling::INTERSECTION) {
+      return std::all_of(call->arguments.begin(), call->arguments.end(),
+                         DefinitelyNotNull);
+    }
+  }
+
+  return false;
+}
+
+template <typename PreVisit, typename PostVisitCall>
+Result<Expression2> Modify(Expression2 expr, const PreVisit& pre,
+                           const PostVisitCall& post_call) {
+  ARROW_ASSIGN_OR_RAISE(expr, Result<Expression2>(pre(std::move(expr))));
+
+  auto call = expr.call();
+  if (!call) return expr;
+
+  bool at_least_one_modified = false;
+  auto modified_call = *call;
+  auto modified_argument = modified_call.arguments.begin();
+
+  for (const auto& argument : call->arguments) {
+    ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call));
+
+    if (!Identical(*modified_argument, argument)) {
+      at_least_one_modified = true;
+    }
+    ++modified_argument;
+  }
+
+  if (at_least_one_modified) {
+    // reconstruct the call expression with the modified arguments
+    auto modified_expr = Expression2(
+        std::make_shared<Expression2::Impl>(std::move(modified_call)), expr.descr());
+    return post_call(std::move(modified_expr), &expr);
+  }
+
+  return post_call(std::move(expr), nullptr);
+}
+
+Result<Expression2> FoldConstants(Expression2 expr, ExpressionState* state) {
+  DCHECK(expr.IsBound());
+
+  struct StateAndIndex {
+    CallState* state;
+    int index;
+  };
+  std::vector<StateAndIndex> stack;
+
+  auto root_state = checked_cast<CallState*>(state);
+  return Modify(
+      std::move(expr),
+      [&stack, root_state](Expression2 expr) {
+        auto call = expr.call();
+
+        if (stack.empty()) {
+          stack = {{root_state, 0}};
+          return expr;
+        }
+
+        int i = stack.back().index++;
+
+        if (!call) return expr;
+        auto next_state = stack.back().state->argument_states[i].get();
+        stack.push_back({checked_cast<CallState*>(next_state), 0});
+
+        return expr;
+      },
+      [&stack](Expression2 expr, ...) -> Result<Expression2> {
+        auto state = stack.back().state;
+        stack.pop_back();
+
+        auto call = CallNotNull(expr);
+        if (std::all_of(call->arguments.begin(), call->arguments.end(),
+                        [](const Expression2& argument) { return argument.literal(); })) {
+          // all arguments are literal; we can evaluate this subexpression *now*
+          static const Datum ignored_input;
+          ARROW_ASSIGN_OR_RAISE(Datum constant,
+                                ExecuteScalarExpression(expr, state, ignored_input));
+          return literal(std::move(constant));
+        }
+
+        // XXX the following should probably be in a registry of passes instead of inline
+
+        if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
+          // kernels which always produce intersected validity can be resolved to null
+          // *now* if any of their inputs is a null literal
+          for (const auto& argument : call->arguments) {
+            if (argument.IsNullLiteral()) return argument;
+          }
+        }
+
+        if (call->function == "and_kleene") {
+          for (auto args : ArgumentsAndFlippedArguments(*call)) {
+            // true and x == x
+            if (args.first == literal(true)) return args.second;
+
+            // false and x == false
+            if (args.first == literal(false)) return args.first;
+          }
+          return expr;
+        }
+
+        if (call->function == "or_kleene") {
+          for (auto args : ArgumentsAndFlippedArguments(*call)) {
+            // false or x == x
+            if (args.first == literal(false)) return args.second;
+
+            // true or x == true
+            if (args.first == literal(true)) return args.first;
+          }
+          return expr;
+        }
+
+        return expr;
+      });
+}
+
+struct FlattenedAssociativeChain {
+  bool was_left_folded = true;
+  std::vector<Expression2> exprs, fringe;
+
+  explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} {
+    auto call = CallNotNull(exprs.back());
+    fringe = call->arguments;
+
+    auto it = fringe.begin();
+
+    while (it != fringe.end()) {
+      auto sub_call = it->call();
+      if (!sub_call || sub_call->function != call->function) {
+        ++it;
+        continue;
+      }
+
+      if (it != fringe.begin()) {
+        was_left_folded = false;
+      }
+
+      exprs.push_back(std::move(*it));
+      it = fringe.erase(it);
+      it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end());
+      // NB: no increment so we hit sub_call's first argument next iteration
+    }
+
+    DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) {
+      return CallNotNull(expr)->options == nullptr;
+    }));
+  }
+};
+
+inline std::vector<Expression2> GuaranteeConjunctionMembers(
+    const Expression2& guaranteed_true_predicate) {
+  auto guarantee = guaranteed_true_predicate.call();
+  if (!guarantee || guarantee->function != "and_kleene") {
+    return {guaranteed_true_predicate};
+  }
+  return FlattenedAssociativeChain(guaranteed_true_predicate).fringe;
+}
+
+// Conjunction members which are represented in known_values are erased from
+// conjunction_members
+Status ExtractKnownFieldValuesImpl(
+    std::vector<Expression2>* conjunction_members,
+    std::unordered_map<FieldRef, Datum, FieldRef::Hash>* known_values) {
+  for (auto&& member : *conjunction_members) {
+    ARROW_ASSIGN_OR_RAISE(member, Canonicalize(std::move(member)));
+  }
+
+  auto unconsumed_end =
+      std::partition(conjunction_members->begin(), conjunction_members->end(),
+                     [](const Expression2& expr) {
+                       // search for an equality condition between a field and a literal
+                       auto call = expr.call();
+                       if (!call) return true;
+
+                       if (call->function == "equal") {
+                         auto ref = call->arguments[0].field_ref();
+                         auto lit = call->arguments[1].literal();
+                         return !(ref && lit);
+                       }
+
+                       return true;
+                     });
+
+  for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) {
+    auto call = CallNotNull(*it);
+
+    auto ref = call->arguments[0].field_ref();
+    auto lit = call->arguments[1].literal();
+
+    auto it_success = known_values->emplace(*ref, *lit);
+    if (it_success.second) continue;
+
+    // A value was already known for ref; check it
+    auto ref_lit = it_success.first;
+    if (*lit != ref_lit->second) {
+      return Status::Invalid("Conflicting guarantees: (", ref->ToString(),
+                             " == ", lit->ToString(), ") vs (", ref->ToString(),
+                             " == ", ref_lit->second.ToString(), ")");
+    }
+  }
+
+  conjunction_members->erase(unconsumed_end, conjunction_members->end());
+
+  return Status::OK();
+}
+
+Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldValues(
+    const Expression2& guaranteed_true_predicate) {
+  auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
+  std::unordered_map<FieldRef, Datum, FieldRef::Hash> known_values;
+  RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values));
+  return known_values;
+}
+
+Result<Expression2> ReplaceFieldsWithKnownValues(
+    const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values,
+    Expression2 expr) {
+  return Modify(
+      std::move(expr),
+      [&known_values](Expression2 expr) {
+        if (auto ref = expr.field_ref()) {
+          auto it = known_values.find(*ref);
+          if (it != known_values.end()) {
+            return literal(it->second);
+          }
+        }
+        return expr;
+      },
+      [](Expression2 expr, ...) { return expr; });
+}
+
+template <typename It, typename BinOp,
+          typename Out = typename std::iterator_traits<It>::value_type>
+util::optional<Out> FoldLeft(It begin, It end, const BinOp& bin_op) {
+  if (begin == end) return util::nullopt;
+
+  Out folded = std::move(*begin++);
+  while (begin != end) {
+    folded = bin_op(std::move(folded), std::move(*begin++));
+  }
+  return folded;
+}
+
+inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) {
+  static std::unordered_set<std::string> binary_associative_commutative{
+      "and",      "or",  "and_kleene",       "or_kleene",  "xor",
+      "multiply", "add", "multiply_checked", "add_checked"};
+
+  auto it = binary_associative_commutative.find(call.function);
+  return it != binary_associative_commutative.end();
+}
+
+struct Comparison {
+  enum type {
+    NA = 0,
+    EQUAL = 1,
+    LESS = 2,
+    GREATER = 4,
+    NOT_EQUAL = LESS | GREATER,
+    LESS_EQUAL = LESS | EQUAL,
+    GREATER_EQUAL = GREATER | EQUAL,
+  };
+
+  static const type* Get(const std::string& function) {
+    static std::unordered_map<std::string, type> flipped_comparisons{
+        {"equal", EQUAL},     {"not_equal", NOT_EQUAL},
+        {"less", LESS},       {"less_equal", LESS_EQUAL},
+        {"greater", GREATER}, {"greater_equal", GREATER_EQUAL},
+    };
+
+    auto it = flipped_comparisons.find(function);
+    return it != flipped_comparisons.end() ? &it->second : nullptr;
+  }
+
+  static const type* Get(const Expression2& expr) {
+    if (auto call = expr.call()) {
+      return Comparison::Get(call->function);
+    }
+    return nullptr;
+  }
+
+  static Result<type> Execute(Datum l, Datum r) {
+    if (!l.is_scalar() || !r.is_scalar()) {
+      return Status::Invalid("Cannot Execute Comparison on non-scalars");
+    }
+    std::vector<Datum> arguments{std::move(l), std::move(r)};
+
+    ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments));
+
+    if (!equal.scalar()->is_valid) return NA;
+    if (equal.scalar_as<BooleanScalar>().value) return EQUAL;
+
+    ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments));
+
+    if (!less.scalar()->is_valid) return NA;
+    return less.scalar_as<BooleanScalar>().value ? LESS : GREATER;
+  }
+
+  static type GetFlipped(type op) {
+    switch (op) {
+      case NA:
+        return NA;
+      case EQUAL:
+        return EQUAL;
+      case LESS:
+        return GREATER;
+      case GREATER:
+        return LESS;
+      case NOT_EQUAL:
+        return NOT_EQUAL;
+      case LESS_EQUAL:
+        return GREATER_EQUAL;
+      case GREATER_EQUAL:
+        return LESS_EQUAL;
+    }
+    DCHECK(false);
+    return NA;
+  }
+
+  static std::string GetName(type op) {
+    switch (op) {
+      case NA:
+        DCHECK(false) << "unreachable";
+        break;
+      case EQUAL:
+        return "equal";
+      case LESS:
+        return "less";
+      case GREATER:
+        return "greater";
+      case NOT_EQUAL:
+        return "not_equal";
+      case LESS_EQUAL:
+        return "less_equal";
+      case GREATER_EQUAL:
+        return "greater_equal";
+    }
+    DCHECK(false);
+    return "na";
+  }
+};
+
+Result<Expression2> Canonicalize(Expression2 expr) {
+  // When reconstructing deeper than a call's immediate arguments (for example,
+  // when reorganizing an associative chain), add the new expressions to this
+  // set so they are not canonicalized again
+  struct {
+    std::unordered_set<Expression2, Expression2::Hash> set_;
+
+    bool operator()(const Expression2& expr) const {
+      return set_.find(expr) != set_.end();
+    }
+
+    void Add(std::vector<Expression2> exprs) {
+      std::move(exprs.begin(), exprs.end(), std::inserter(set_, set_.end()));
+    }
+  } AlreadyCanonicalized;
+
+  return Modify(
+      std::move(expr),
+      [&AlreadyCanonicalized](Expression2 expr) -> Result<Expression2> {
+        auto call = expr.call();
+        if (!call) return expr;
+
+        if (AlreadyCanonicalized(expr)) return expr;
+
+        if (IsBinaryAssociativeCommutative(*call)) {
+          struct {
+            int Priority(const Expression2& operand) const {
+              // order literals first, starting with nulls
+              if (operand.IsNullLiteral()) return 0;
+              if (operand.literal()) return 1;
+              return 2;
+            }
+            bool operator()(const Expression2& l, const Expression2& r) const {
+              return Priority(l) < Priority(r);
+            }
+          } CanonicalOrdering;
+
+          FlattenedAssociativeChain chain(expr);
+          if (chain.was_left_folded &&
+              std::is_sorted(chain.fringe.begin(), chain.fringe.end(),
+                             CanonicalOrdering)) {
+            AlreadyCanonicalized.Add(std::move(chain.exprs));
+            return expr;
+          }
+
+          std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering);
+
+          // fold the chain back up
+          const auto& descr = expr.descr();
+          auto folded = FoldLeft(
+              chain.fringe.begin(), chain.fringe.end(),
+              [call, &descr, &AlreadyCanonicalized](Expression2 l, Expression2 r) {
+                auto ret = *call;
+                ret.arguments = {std::move(l), std::move(r)};
+                Expression2 expr(std::make_shared<Expression2::Impl>(std::move(ret)),
+                                 descr);
+                AlreadyCanonicalized.Add({expr});
+                return expr;
+              });
+          return std::move(*folded);
+        }
+
+        if (auto cmp = Comparison::Get(call->function)) {
+          if (call->arguments[0].literal() && !call->arguments[1].literal()) {
+            // ensure that literals are on comparisons' RHS
+            auto flipped_call = *call;
+            for (auto&& argument : flipped_call.arguments) {
+              ARROW_ASSIGN_OR_RAISE(argument, Canonicalize(std::move(argument)));
+            }
+            flipped_call.function = Comparison::GetName(Comparison::GetFlipped(*cmp));
+            std::swap(flipped_call.arguments[0], flipped_call.arguments[1]);
+            return Expression2(
+                std::make_shared<Expression2::Impl>(std::move(flipped_call)));
+          }
+        }
+
+        return expr;
+      },
+      [](Expression2 expr, ...) { return expr; });
+}
+
+Result<Expression2> DirectComparisonSimplification(Expression2 expr,
+                                                   const Expression2::Call& guarantee) {
+  return Modify(
+      std::move(expr), [](Expression2 expr) { return expr; },
+      [&guarantee](Expression2 expr, ...) -> Result<Expression2> {
+        auto call = expr.call();
+        if (!call) return expr;
+
+        // Ensure both calls are comparisons with equal LHS and scalar RHS
+        auto cmp = Comparison::Get(expr);
+        auto cmp_guarantee = Comparison::Get(guarantee.function);
+        if (!cmp || !cmp_guarantee) return expr;
+
+        if (call->arguments[0] != guarantee.arguments[0]) return expr;
+
+        auto rhs = call->arguments[1].literal();
+        auto guarantee_rhs = guarantee.arguments[1].literal();
+        if (!rhs || !guarantee_rhs) return expr;
+
+        if (!rhs->is_scalar() || !guarantee_rhs->is_scalar()) {
+          return expr;
+        }
+
+        ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
+                              Comparison::Execute(*rhs, *guarantee_rhs));
+        DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);
+
+        if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
+          // RHS of filter is equal to RHS of guarantee
+
+          if ((*cmp_guarantee & *cmp) == *cmp_guarantee) {
+            // guarantee is a subset of filter, so all data will be included
+            return literal(true);
+          }
+
+          if ((*cmp_guarantee & *cmp) == 0) {
+            // guarantee disjoint with filter, so all data will be excluded
+            return literal(false);
+          }
+
+          return expr;
+        }
+
+        if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
+          // unusable guarantee
+          return expr;
+        }
+
+        if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
+          // x > 1, x >= 1, x != 1 guaranteed by x >= 3
+          return literal(true);
+        } else {
+          // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
+          return literal(false);
+        }
+      });
+}
+
+Result<Expression2> SimplifyWithGuarantee(Expression2 expr, ExpressionState* state,
+                                          const Expression2& guaranteed_true_predicate) {
+  auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
+
+  std::unordered_map<FieldRef, Datum, FieldRef::Hash> known_values;
+  RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values));
+
+  ARROW_ASSIGN_OR_RAISE(expr,
+                        ReplaceFieldsWithKnownValues(known_values, std::move(expr)));
+
+  auto CanonicalizeAndFoldConstants = [&expr, state] {
+    ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr)));
+    ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr), state));
+    return Status::OK();
+  };
+  RETURN_NOT_OK(CanonicalizeAndFoldConstants());
+
+  for (const auto& guarantee : conjunction_members) {
+    if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) {
+      ARROW_ASSIGN_OR_RAISE(
+          auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee)));
+
+      if (Identical(simplified, expr)) continue;
+
+      expr = std::move(simplified);
+      RETURN_NOT_OK(CanonicalizeAndFoldConstants());
+    }
+  }
+
+  return expr;
+}
+
+}  // namespace dataset
+}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h
new file mode 100644
index 00000000000..4da41b8b9c4
--- /dev/null
+++ b/cpp/src/arrow/dataset/expression.h
@@ -0,0 +1,217 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
+#include "arrow/dataset/type_fwd.h"
+#include "arrow/dataset/visibility.h"
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/scalar.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/variant.h"
+
+namespace arrow {
+namespace dataset {
+
+/// An unbound expression which maps a single Datum to another Datum.
+/// An expression is one of
+/// - A literal Datum.
+/// - A reference to a single (potentially nested) field of the input Datum.
+/// - A call to a compute function, with arguments specified by other Expressions.
+/// - Optionally, an explicitly selected Kernel of the function. If provided,
+///   execution will skip function lookup and kernel dispatch and use that Kernel
+///   directly.
+class ARROW_DS_EXPORT Expression2 {
+ public:
+  struct Call {
+    std::string function;
+    std::vector<Expression2> arguments;
+    std::shared_ptr<compute::FunctionOptions> options;
+
+    // post-Bind properties:
+    const compute::Kernel* kernel = NULLPTR;
+    compute::Function::Kind function_kind;
+  };
+
+  std::string ToString() const;
+  bool Equals(const Expression2& other) const;
+  size_t hash() const;
+  struct Hash {
+    size_t operator()(const Expression2& expr) const { return expr.hash(); }
+  };
+
+  /// Bind this expression to the given input type, looking up Kernels and field types.
+  /// Some expression simplification may be performed and implicit casts will be inserted.
+  /// Any state necessary for execution will be initialized and returned.
+  using StateAndBound = std::pair<Expression2, std::shared_ptr<ExpressionState>>;
+  Result<StateAndBound> Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const;
+  Result<StateAndBound> Bind(const Schema& in_schema,
+                             compute::ExecContext* = NULLPTR) const;
+
+  /// Return true if all of this expression's field references have an explicit
+  /// ValueDescr and all of its functions' kernels are looked up.
+  bool IsBound() const;
+
+  /// Return true if this expression is composed only of Scalar literals, field
+  /// references, and calls to ScalarFunctions.
+  bool IsScalarExpression() const;
+
+  /// Return true if this expression is literal and entirely null.
+  bool IsNullLiteral() const;
+
+  /// Return true if this expression could evaluate to true.
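+  /// (For example, a null literal or the literal `false` can never evaluate
+  /// to true; SimplifyWithGuarantee may fold a filter down to such a literal.)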
+  bool IsSatisfiable() const;
+
+  // XXX someday
+  // Result GetPipelines();
+
+  const Call* call() const;
+  const Datum* literal() const;
+  const FieldRef* field_ref() const;
+
+  const ValueDescr& descr() const { return descr_; }
+
+  using Impl = util::variant<Datum, FieldRef, Call>;
+
+  explicit Expression2(std::shared_ptr<Impl> impl, ValueDescr descr = {})
+      : impl_(std::move(impl)), descr_(std::move(descr)) {}
+
+ private:
+  std::shared_ptr<Impl> impl_;
+  ValueDescr descr_;
+  // XXX someday
+  // NullGeneralization::type evaluates_to_null_;
+
+  friend bool Identical(const Expression2& l, const Expression2& r);
+
+  ARROW_EXPORT friend void PrintTo(const Expression2&, std::ostream*);
+};
+
+struct ExpressionState {
+  virtual ~ExpressionState() = default;
+
+  // Produce another instance of ExpressionState which may be safely used from a
+  // different thread
+  static Result<std::shared_ptr<ExpressionState>> Clone(
+      const std::shared_ptr<ExpressionState>&, const Expression2&,
+      compute::ExecContext*);
+};
+
+inline bool operator==(const Expression2& l, const Expression2& r) { return l.Equals(r); }
+inline bool operator!=(const Expression2& l, const Expression2& r) {
+  return !l.Equals(r);
+}
+
+// Factories
+
+inline Expression2 call(std::string function, std::vector<Expression2> arguments,
+                        std::shared_ptr<compute::FunctionOptions> options = NULLPTR) {
+  Expression2::Call call;
+  call.function = std::move(function);
+  call.arguments = std::move(arguments);
+  call.options = std::move(options);
+  return Expression2(std::make_shared<Expression2::Impl>(std::move(call)));
+}
+
+template <typename Options,
+          typename = typename std::enable_if<
+              std::is_base_of<compute::FunctionOptions, Options>::value>::type>
+Expression2 call(std::string function, std::vector<Expression2> arguments,
+                 Options options) {
+  return call(std::move(function), std::move(arguments),
+              std::make_shared<Options>(std::move(options)));
+}
+
+inline Expression2 project(std::vector<std::string> names,
+                           std::vector<Expression2> values) {
+  return call("struct", std::move(values), compute::StructOptions{std::move(names)});
+}
+
+template <typename... Args>
+Expression2 field_ref(Args&&... args) {
+  return Expression2(
+      std::make_shared<Expression2::Impl>(FieldRef(std::forward<Args>(args)...)));
+}
+
+template <typename Arg>
+Expression2 literal(Arg&& arg) {
+  Datum lit(std::forward<Arg>(arg));
+  ValueDescr descr = lit.descr();
+  return Expression2(std::make_shared<Expression2::Impl>(std::move(lit)),
+                     std::move(descr));
+}
+
+// Simplification passes
+
+/// Weak canonicalization which establishes guarantees for subsequent passes. Even
+/// equivalent Expressions may result in different canonicalized expressions.
+/// TODO this could be a strong canonicalization
+ARROW_DS_EXPORT
+Result<Expression2> Canonicalize(Expression2 expr);
+
+/// Simplify Expressions based on literal arguments (for example, add(null, x) will always
+/// be null so replace the call with a null literal). Includes early evaluation of all
+/// calls whose arguments are entirely literal.
+ARROW_DS_EXPORT
+Result<Expression2> FoldConstants(Expression2 expr, ExpressionState*);
+
+ARROW_DS_EXPORT
+Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldValues(
+    const Expression2& guaranteed_true_predicate);
+
+ARROW_DS_EXPORT
+Result<Expression2> ReplaceFieldsWithKnownValues(
+    const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values,
+    Expression2 expr);
+
+/// Simplify an expression by replacing subexpressions based on a guarantee:
+/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
+/// used to remove redundant function calls from a filter expression or to replace a
+/// reference to a constant-value field with a literal.
+ARROW_DS_EXPORT
+Result<Expression2> SimplifyWithGuarantee(Expression2 expr, ExpressionState*,
+                                          const Expression2& guaranteed_true_predicate);
+
+// Execution
+
+/// Execute a scalar expression against the provided state and input Datum.
+/// This expression must be bound.
+Result<Datum> ExecuteScalarExpression(const Expression2&, ExpressionState*,
+                                      const Datum& input,
+                                      compute::ExecContext* exec_context = NULLPTR);
+
+// Serialization
+
+ARROW_DS_EXPORT
+Result<std::shared_ptr<Buffer>> Serialize(const Expression2&);
+
+ARROW_DS_EXPORT
+Result<Expression2> Deserialize(const Buffer&);
+
+}  // namespace dataset
+}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h
new file mode 100644
index 00000000000..1998a565ec3
--- /dev/null
+++ b/cpp/src/arrow/dataset/expression_internal.h
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/expression.h"
+
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace dataset {
+
+bool Identical(const Expression2& l, const Expression2& r) { return l.impl_ == r.impl_; }
+
+const Expression2::Call* CallNotNull(const Expression2& expr) {
+  auto call = expr.call();
+  DCHECK_NE(call, nullptr);
+  return call;
+}
+
+inline void GetAllFieldRefs(const Expression2& expr,
+                            std::unordered_set<FieldRef, FieldRef::Hash>* refs) {
+  if (auto lit = expr.literal()) return;
+
+  if (auto ref = expr.field_ref()) {
+    refs->emplace(*ref);
+    return;
+  }
+
+  for (const Expression2& arg : CallNotNull(expr)->arguments) {
+    GetAllFieldRefs(arg, refs);
+  }
+}
+
+inline std::vector<ValueDescr> GetDescriptors(const std::vector<Expression2>& exprs) {
+  std::vector<ValueDescr> descrs(exprs.size());
+  for (size_t i = 0; i < exprs.size(); ++i) {
+    DCHECK(exprs[i].IsBound());
+    descrs[i] = exprs[i].descr();
+  }
+  return descrs;
+}
+
+struct CallState : ExpressionState {
+  std::vector<std::shared_ptr<ExpressionState>> argument_states;
+  std::shared_ptr<compute::KernelState> kernel_state;
+};
+
+struct FieldPathGetDatumImpl {
+  template <typename T, typename = decltype(FieldPath{}.Get(std::declval<const T&>()))>
+  Result<Datum> operator()(const std::shared_ptr<T>& ptr) {
+    return path_.Get(*ptr).template As<Datum>();
+  }
+
+  template <typename T>
+  Result<Datum> operator()(const T&) {
+    return Status::NotImplemented("FieldPath::Get() into Datum ", datum_.ToString());
+  }
+
+  const Datum& datum_;
+  const FieldPath& path_;
+};
+
+inline Result<Datum> GetDatumField(const FieldRef& ref, const Datum& input) {
+  Datum field;
+
+  FieldPath path;
+  if (auto type = input.type()) {
+    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.type()));
+  } else if (input.kind() == Datum::RECORD_BATCH) {
+    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.record_batch()->schema()));
+  } else if (input.kind() == Datum::TABLE) {
+    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.table()->schema()));
+  }
+
+  if (path) {
+    ARROW_ASSIGN_OR_RAISE(field,
+                          util::visit(FieldPathGetDatumImpl{input, path}, input.value));
+  }
+
+  if (field == Datum{}) {
+    field = Datum(std::make_shared<NullScalar>());
+  }
+
+  return field;
+}
+
+}  // namespace dataset
+}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc
new file mode 100644
index 00000000000..7cc98a1e151
--- /dev/null
+++ b/cpp/src/arrow/dataset/expression_test.cc
@@ -0,0 +1,729 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/expression.h"
+
+#include <memory>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/registry.h"
+#include "arrow/dataset/expression_internal.h"
+#include "arrow/testing/gtest_util.h"
+
+using testing::UnorderedElementsAreArray;
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+const Schema kBoringSchema{{
+    field("i32", int32()),
+    field("f32", float32()),
+}};
+
+#define EXPECT_OK ARROW_EXPECT_OK
+
+TEST(Expression2, ToString) {
+  EXPECT_EQ(field_ref("alpha").ToString(), "FieldRef(alpha)");
+
+  EXPECT_EQ(literal(3).ToString(), "3");
+
+  EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(),
+            "add(3,FieldRef(beta))");
+
+  EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}).ToString(),
+            "add(3,index_in(FieldRef(beta)))");
+}
+
+TEST(Expression2, Equality) {
+  EXPECT_EQ(literal(1), literal(1));
+
+  EXPECT_EQ(field_ref("a"), field_ref("a"));
+
+  EXPECT_EQ(call("add", {literal(3), field_ref("a")}),
+            call("add", {literal(3), field_ref("a")}));
+
+  EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}),
+            call("add", {literal(3), call("index_in", {field_ref("beta")})}));
+}
+
+TEST(Expression2, Hash) {
+  std::unordered_set<Expression2, Expression2::Hash> set;
+
+  EXPECT_TRUE(set.emplace(field_ref("alpha")).second);
+  EXPECT_TRUE(set.emplace(field_ref("beta")).second);
+  EXPECT_FALSE(set.emplace(field_ref("beta")).second) << "already inserted";
+  EXPECT_TRUE(set.emplace(literal(1)).second);
+  EXPECT_FALSE(set.emplace(literal(1)).second) << "already inserted";
+  EXPECT_TRUE(set.emplace(literal(3)).second);
+
+  // NB: no validation on construction; we couldn't execute
+  // add with zero arguments
+  EXPECT_TRUE(set.emplace(call("add", {})).second);
+  EXPECT_FALSE(set.emplace(call("add", {})).second) << "already inserted";
+
+  // NB: unbound expressions don't check for availability in any registry
+  EXPECT_TRUE(set.emplace(call("widgetify", {})).second);
+
+  EXPECT_EQ(set.size(), 6);
+}
+
+TEST(Expression2, BindLiteral) {
+  for (Datum dat : {
+           Datum(3),
+           Datum(3.5),
+           Datum(ArrayFromJSON(int32(), "[1,2,3]")),
+       }) {
+    // literals are always considered bound
+    auto expr = literal(dat);
+    EXPECT_EQ(expr.descr(), dat.descr());
+    EXPECT_TRUE(expr.IsBound());
+  }
+}
+
+TEST(Expression2, BindFieldRef) {
+  // an unbound field_ref does not have the output ValueDescr set
+  {
field_ref("alpha"); + EXPECT_EQ(expr.descr(), ValueDescr{}); + EXPECT_FALSE(expr.IsBound()); + } + + { + auto expr = field_ref("alpha"); + // binding a field_ref looks up that field's type in the input Schema + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("alpha", int32())}))); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); + } + + { + // if the field is not found, a null scalar will be emitted + auto expr = field_ref("alpha"); + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(Schema({}))); + EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); + EXPECT_TRUE(expr.IsBound()); + } + + { + // referencing a field by name is not supported if that name is not unique + // in the input schema + auto expr = field_ref("alpha"); + ASSERT_RAISES( + Invalid, expr.Bind(Schema({field("alpha", int32()), field("alpha", float32())}))); + } + + { + // referencing nested fields is supported + auto expr = field_ref("a", "b"); + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", struct_({field("b", int32())}))}))); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); + } +} + +TEST(Expression2, BindCall) { + auto expr = call("add", {field_ref("a"), field_ref("b")}); + EXPECT_FALSE(expr.IsBound()); + + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", int32()), field("b", int32())}))); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); + + expr = call("add", {field_ref("a"), literal(3.5)}); + ASSERT_RAISES(NotImplemented, + expr.Bind(Schema({field("a", int32()), field("b", int32())}))); +} + +TEST(Expression2, BindNestedCall) { + auto expr = + call("add", {field_ref("a"), + call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}), + field_ref("d")})}); + EXPECT_FALSE(expr.IsBound()); + + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", int32()), field("b", int32()), + field("c", int32()), field("d", int32())}))); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.IsBound()); +} + +TEST(Expression2, ExecuteFieldRef) { + auto AssertRefIs = [](FieldRef ref, Datum in, Datum expected) { + auto expr = field_ref(ref); + + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(Datum actual, + ExecuteScalarExpression(expr, /*state=*/nullptr, in)); + + AssertDatumsEqual(actual, expected, /*verbose=*/true); + }; + + AssertRefIs("a", ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])"), + ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); + + // more nested: + AssertRefIs(FieldRef{"a", "a"}, + ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ + {"a": {"a": 6.125}}, + {"a": {"a": 0.0}}, + {"a": {"a": -1}} + ])"), + ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); + + // absent fields are resolved as a null scalar: + AssertRefIs(FieldRef{"b"}, ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])"), + MakeNullScalar(float64())); +} + +Result NaiveExecuteScalarExpression(const Expression2& expr, + ExpressionState* state, const Datum& input) { + auto call = expr.call(); + if (call == nullptr) { + // already tested execution of field_ref, execution of literal is trivial + return ExecuteScalarExpression(expr, state, input); + } + + auto call_state = checked_cast(state); + + std::vector 
+  std::vector<Datum> arguments(call->arguments.size());
+  for (size_t i = 0; i < arguments.size(); ++i) {
+    auto argument_state = call_state->argument_states[i].get();
+    ARROW_ASSIGN_OR_RAISE(arguments[i], NaiveExecuteScalarExpression(
+                                            call->arguments[i], argument_state, input));
+
+    EXPECT_EQ(call->arguments[i].descr(), arguments[i].descr());
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto function,
+                        compute::GetFunctionRegistry()->GetFunction(call->function));
+  ARROW_ASSIGN_OR_RAISE(auto expected_kernel,
+                        function->DispatchExact(GetDescriptors(call->arguments)));
+
+  EXPECT_EQ(call->kernel, expected_kernel);
+
+  compute::ExecContext exec_context;
+  return function->Execute(arguments, call->options.get(), &exec_context);
+}
+
+void AssertExecute(Expression2 expr, Datum in) {
+  std::shared_ptr<ExpressionState> state;
+  ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(in.descr()));
+
+  ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, state.get(), in));
+
+  ASSERT_OK_AND_ASSIGN(Datum expected,
+                       NaiveExecuteScalarExpression(expr, state.get(), in));
+
+  AssertDatumsEqual(actual, expected, /*verbose=*/true);
+}
+
+TEST(Expression2, ExecuteCall) {
+  AssertExecute(call("add", {field_ref("a"), literal(3.5)}),
+                ArrayFromJSON(struct_({field("a", float64())}), R"([
+    {"a": 6.125},
+    {"a": 0.0},
+    {"a": -1}
+  ])"));
+
+  AssertExecute(
+      call("add", {field_ref("a"), call("subtract", {literal(3.5), field_ref("b")})}),
+      ArrayFromJSON(struct_({field("a", float64()), field("b", float64())}), R"([
+    {"a": 6.125, "b": 3.375},
+    {"a": 0.0, "b": 1},
+    {"a": -1, "b": 4.75}
+  ])"));
+
+  AssertExecute(call("strptime", {field_ref("a")},
+                     compute::StrptimeOptions("%m/%d/%Y", TimeUnit::MICRO)),
+                ArrayFromJSON(struct_({field("a", utf8())}), R"([
+    {"a": "5/1/2020"},
+    {"a": null},
+    {"a": "12/11/1900"}
+  ])"));
+
+  AssertExecute(project({"a + 3.5"}, {call("add", {field_ref("a"), literal(3.5)})}),
+                ArrayFromJSON(struct_({field("a", float64())}), R"([
+    {"a": 6.125},
+    {"a": 0.0},
+    {"a": -1}
+  ])"));
+}
+
+struct {
+  void operator()(Expression2 expr, Expression2 expected) {
+    std::shared_ptr<ExpressionState> state;
+    ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(kBoringSchema));
+    ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema));
+    ASSERT_OK_AND_ASSIGN(auto actual, FoldConstants(expr, state.get()));
+    EXPECT_EQ(actual, expected);
+    if (actual == expr) {
+      // no change -> must be identical
+      EXPECT_TRUE(Identical(actual, expr));
+    }
+  }
+} ExpectFoldsTo;
+
+TEST(Expression2, FoldConstants) {
+  // literals are unchanged
+  ExpectFoldsTo(literal(3), literal(3));
+
+  // field_refs are unchanged
+  ExpectFoldsTo(field_ref("i32"), field_ref("i32"));
+
+  // call against literals (3 + 2 == 5)
+  ExpectFoldsTo(call("add", {literal(3), literal(2)}), literal(5));
+
+  // call against literal and field_ref
+  ExpectFoldsTo(call("add", {literal(3), field_ref("i32")}),
+                call("add", {literal(3), field_ref("i32")}));
+
+  // nested call against literals ((8 - (2 * 3)) + 2 == 4)
+  ExpectFoldsTo(call("add",
+                     {
+                         call("subtract",
+                              {
+                                  literal(8),
+                                  call("multiply", {literal(2), literal(3)}),
+                              }),
+                         literal(2),
+                     }),
+                literal(4));
+
+  // nested call against literals with one field_ref
+  // (i32 - (2 * 3)) + 2 == (i32 - 6) + 2
+  // NB this could be improved further by using associativity of addition; another pass
+  ExpectFoldsTo(call("add",
+                     {
+                         call("subtract",
+                              {
+                                  field_ref("i32"),
+                                  call("multiply", {literal(2), literal(3)}),
+                              }),
+                         literal(2),
+                     }),
+                call("add", {
+                                call("subtract",
+                                     {
+                                         field_ref("i32"),
+                                         literal(6),
+                                     }),
+                                literal(2),
+                            }));
+}
+
+TEST(Expression2, FoldConstantsBoolean) {
+  // test and_kleene/or_kleene-specific optimizations
+  auto one = literal(1);
+  auto two = literal(2);
+  auto whatever = call("equal", {call("add", {one, field_ref("i32")}), two});
+
+  auto true_ = literal(true);
+  auto false_ = literal(false);
+
+  ExpectFoldsTo(call("and_kleene", {false_, whatever}), false_);
+  ExpectFoldsTo(call("and_kleene", {true_, whatever}), whatever);
+
+  ExpectFoldsTo(call("or_kleene", {true_, whatever}), true_);
+  ExpectFoldsTo(call("or_kleene", {false_, whatever}), whatever);
+}
+
+TEST(Expression2, ExtractKnownFieldValues) {
+  struct {
+    void operator()(Expression2 guarantee,
+                    std::unordered_map<FieldRef, Datum, FieldRef::Hash> expected) {
+      ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee));
+      EXPECT_THAT(actual, UnorderedElementsAreArray(expected));
+    }
+  } ExpectKnown;
+
+  ExpectKnown(call("equal", {field_ref("a"), literal(3)}), {{"a", Datum(3)}});
+
+  ExpectKnown(call("greater", {field_ref("a"), literal(3)}), {});
+
+  // FIXME known null should be expressed with is_null rather than equality
+  auto null_int32 = std::make_shared<Int32Scalar>();
+  ExpectKnown(call("equal", {field_ref("a"), literal(null_int32)}),
+              {{"a", Datum(null_int32)}});
+
+  ExpectKnown(call("and_kleene", {call("equal", {field_ref("a"), literal(3)}),
+                                  call("equal", {literal(1), field_ref("b")})}),
+              {{"a", Datum(3)}, {"b", Datum(1)}});
+
+  ExpectKnown(call("and_kleene",
+                   {call("equal", {field_ref("a"), literal(3)}),
+                    call("and_kleene", {call("equal", {field_ref("b"), literal(2)}),
+                                        call("equal", {literal(1), field_ref("c")})})}),
+              {{"a", Datum(3)}, {"b", Datum(2)}, {"c", Datum(1)}});
+
+  ExpectKnown(call("and_kleene",
+                   {call("or_kleene", {call("equal", {field_ref("a"), literal(3)}),
+                                       call("equal", {field_ref("a"), literal(4)})}),
+                    call("equal", {literal(1), field_ref("b")})}),
+              {{"b", Datum(1)}});
+
+  ExpectKnown(
+      call("and_kleene",
+           {call("equal", {field_ref("a"), literal(3)}),
+            call("and_kleene", {call("and_kleene", {field_ref("b"), field_ref("d")}),
+                                call("equal", {literal(1), field_ref("c")})})}),
+      {{"a", Datum(3)}, {"c", Datum(1)}});
+}
+
+TEST(Expression2, ReplaceFieldsWithKnownValues) {
+  auto ExpectSimplifiesTo =
+      [](Expression2 expr,
+         std::unordered_map<FieldRef, Datum, FieldRef::Hash> known_values,
+         Expression2 expected) {
+        ASSERT_OK_AND_ASSIGN(auto actual,
+                             ReplaceFieldsWithKnownValues(known_values, expr));
+        EXPECT_EQ(actual, expected);
+
+        if (actual == expr) {
+          // no change -> must be identical
+          EXPECT_TRUE(Identical(actual, expr));
+        }
+      };
+
+  std::unordered_map<FieldRef, Datum, FieldRef::Hash> a_is_3{{"a", Datum(3)}};
+
+  ExpectSimplifiesTo(literal(1), a_is_3, literal(1));
+
+  ExpectSimplifiesTo(field_ref("a"), a_is_3, literal(3));
+
+  ExpectSimplifiesTo(field_ref("b"), a_is_3, field_ref("b"));
+
+  ExpectSimplifiesTo(call("equal", {field_ref("a"), literal(1)}), a_is_3,
+                     call("equal", {literal(3), literal(1)}));
+
+  ExpectSimplifiesTo(call("add",
+                          {
+                              call("subtract",
+                                   {
+                                       field_ref("a"),
+                                       call("multiply", {literal(2), literal(3)}),
+                                   }),
+                              literal(2),
+                          }),
+                     a_is_3,
+                     call("add", {
+                                     call("subtract",
+                                          {
+                                              literal(3),
+                                              call("multiply", {literal(2), literal(3)}),
+                                          }),
+                                     literal(2),
+                                 }));
+}
+
+struct {
+  void operator()(Expression2 expr, Expression2 expected) const {
+    ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(expr));
+    EXPECT_EQ(actual, expected);
+
+    if (expr.IsBound()) {
+      EXPECT_TRUE(actual.IsBound());
+    }
+
+    if (actual == expr) {
+      // no change -> must be identical
+      EXPECT_TRUE(Identical(actual, expr));
+    }
+  }
+} ExpectCanonicalizesTo;
+
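+// NB: Canonicalize rewrites an expression into a stable normal form so that
+// equivalent expressions can be matched by simple structural equality. As a
+// rough sketch of the rewrites exercised below (the assertions that follow
+// are authoritative):
+//
+//   call("less", {literal(1), field_ref("a")})
+//       -> call("greater", {field_ref("a"), literal(1)})  // flip comparisons
+//   call("and_kleene", {field_ref("a"), literal(true)})
+//       -> call("and_kleene", {literal(true), field_ref("a")})  // literals innermost
+//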
+TEST(Expression2, CanonicalizeTrivial) {
+  ExpectCanonicalizesTo(literal(1), literal(1));
+
+  ExpectCanonicalizesTo(field_ref("b"), field_ref("b"));
+
+  ExpectCanonicalizesTo(call("equal", {field_ref("a"), field_ref("b")}),
+                        call("equal", {field_ref("a"), field_ref("b")}));
+}
+
+TEST(Expression2, CanonicalizeAnd) {
+  // some aliases for brevity:
+  auto true_ = literal(true);
+  auto null_ = literal(std::make_shared<BooleanScalar>());
+
+  auto a = field_ref("a");
+  auto c = call("equal", {literal(1), literal(2)});
+
+  auto and_ = [](Expression2 l, Expression2 r) { return call("and_kleene", {l, r}); };
+
+  // no change possible:
+  ExpectCanonicalizesTo(and_(a, c), and_(a, c));
+
+  // literals are placed innermost
+  ExpectCanonicalizesTo(and_(a, true_), and_(true_, a));
+  ExpectCanonicalizesTo(and_(true_, a), and_(true_, a));
+
+  ExpectCanonicalizesTo(and_(a, and_(true_, c)), and_(and_(true_, a), c));
+  ExpectCanonicalizesTo(and_(a, and_(and_(true_, true_), c)),
+                        and_(and_(and_(true_, true_), a), c));
+  ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), c)),
+                        and_(and_(and_(null_, true_), a), c));
+  ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), and_(c, null_))),
+                        and_(and_(and_(and_(null_, null_), true_), a), c));
+
+  // catches and_kleene even when it's a subexpression
+  ExpectCanonicalizesTo(call("is_valid", {and_(a, true_)}),
+                        call("is_valid", {and_(true_, a)}));
+}
+
+TEST(Expression2, CanonicalizeComparison) {
+  // some aliases for brevity:
+  auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); };
+  auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); };
+  auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); };
+
+  ExpectCanonicalizesTo(equal(literal(1), field_ref("a")),
+                        equal(field_ref("a"), literal(1)));
+
+  ExpectCanonicalizesTo(equal(field_ref("a"), literal(1)),
+                        equal(field_ref("a"), literal(1)));
+
+  ExpectCanonicalizesTo(less(literal(1), field_ref("a")),
+                        greater(field_ref("a"), literal(1)));
+
+  ExpectCanonicalizesTo(less(field_ref("a"), literal(1)),
+                        less(field_ref("a"), literal(1)));
+}
+
+struct Simplify {
+  Expression2 filter;
+
+  struct Expectable {
+    Expression2 filter, guarantee;
+
+    void Expect(Expression2 expected) {
+      std::shared_ptr<ExpressionState> state;
+      ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema));
+
+      ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema));
+      ASSERT_OK_AND_ASSIGN(auto simplified,
+                           SimplifyWithGuarantee(filter, state.get(), guarantee));
+      EXPECT_EQ(simplified, expected) << " original: " << filter.ToString() << "\n"
+                                      << " guarantee: " << guarantee.ToString() << "\n"
" (no change)\n" : ""); + + if (simplified == filter) { + EXPECT_TRUE(Identical(simplified, filter)); + } + } + void ExpectUnchanged() { Expect(filter); } + void Expect(bool constant) { Expect(literal(constant)); } + }; + + Expectable WithGuarantee(Expression2 guarantee) { return {filter, guarantee}; } +}; + +TEST(Expression2, SingleComparisonGuarantees) { + // some aliases for brevity: + auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); }; + auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); }; + auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); }; + auto not_equal = [](Expression2 l, Expression2 r) { return call("not_equal", {l, r}); }; + auto less_equal = [](Expression2 l, Expression2 r) { + return call("less_equal", {l, r}); + }; + auto greater_equal = [](Expression2 l, Expression2 r) { + return call("greater_equal", {l, r}); + }; + auto i32 = field_ref("i32"); + + // i32 is guaranteed equal to 3, so the projection can just materialize that constant + // and need not incur IO + Simplify{project({"i32 + 1"}, {call("add", {i32, literal(1)})})} + .WithGuarantee(equal(i32, literal(3))) + .Expect(literal( + std::make_shared(ScalarVector{std::make_shared(4)}, + struct_({field("i32 + 1", int32())})))); + + // i32 is guaranteed equal to 5 everywhere, so filtering i32==5 is redundant and the + // filter can be simplified to true (== select everything) + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + less_equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + less(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(3))) + .Expect(true); + + Simplify{ + greater_equal(i32, literal(5)), + } + .WithGuarantee(greater(i32, literal(5))) + .Expect(true); + + // i32 is guaranteed less than 3 everywhere, so filtering i32==5 is redundant and the + // filter can be simplified to false (== select nothing) + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(less(i32, literal(3))) + .Expect(false); + + Simplify{ + less(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(false); + + Simplify{ + less_equal(i32, literal(3)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(false); + + // no simplification possible: + Simplify{ + not_equal(i32, literal(3)), + } + .WithGuarantee(less(i32, literal(5))) + .ExpectUnchanged(); + + // exhaustive coverage of all single comparison simplifications + for (std::string filter_op : + {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { + for (auto filter_rhs : {literal(5), literal(3), literal(7)}) { + auto filter = call(filter_op, {i32, filter_rhs}); + for (std::string guarantee_op : + {"equal", "less", "less_equal", "greater", "greater_equal"}) { + auto guarantee = call(guarantee_op, {i32, literal(5)}); + + // generate data which satisfies the guarantee + static std::unordered_map satisfying_i32{ + {"equal", "[5]"}, + {"less", "[4, 3, 2, 1]"}, + {"less_equal", "[5, 4, 3, 2, 1]"}, + {"greater", "[6, 7, 8, 9]"}, + {"greater_equal", "[5, 6, 7, 8, 9]"}, + }; + + ASSERT_OK_AND_ASSIGN( + Datum input, + StructArray::Make({ArrayFromJSON(int32(), satisfying_i32[guarantee_op])}, + {"i32"})); + + std::shared_ptr state; + ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum 
+        ASSERT_OK_AND_ASSIGN(Datum evaluated,
+                             ExecuteScalarExpression(filter, state.get(), input));
+
+        // ensure that the simplified filter is as simplified as it could be
+        // (this is always possible for single comparisons)
+        bool all = true, none = true;
+        for (int64_t i = 0; i < input.length(); ++i) {
+          if (evaluated.array_as<BooleanArray>()->Value(i)) {
+            none = false;
+          } else {
+            all = false;
+          }
+        }
+        Simplify{filter}.WithGuarantee(guarantee).Expect(
+            all ? literal(true) : none ? literal(false) : filter);
+      }
+    }
+  }
+}
+
+TEST(Expression2, SimplifyWithGuarantee) {
+  // drop both members of a conjunctive filter
+  Simplify{call("and_kleene",
+                {
+                    call("equal", {field_ref("i32"), literal(2)}),
+                    call("equal", {field_ref("f32"), literal(3.5F)}),
+                })}
+      .WithGuarantee(call("and_kleene",
+                          {
+                              call("greater_equal", {field_ref("i32"), literal(0)}),
+                              call("less_equal", {field_ref("i32"), literal(1)}),
+                          }))
+      .Expect(false);
+
+  // drop one member of a conjunctive filter
+  Simplify{call("and_kleene",
+                {
+                    call("equal", {field_ref("i32"), literal(0)}),
+                    call("equal", {field_ref("f32"), literal(3.5F)}),
+                })}
+      .WithGuarantee(call("equal", {field_ref("i32"), literal(0)}))
+      .Expect(call("equal", {field_ref("f32"), literal(3.5F)}));
+
+  // drop both members of a disjunctive filter
+  Simplify{call("or_kleene",
+                {
+                    call("equal", {field_ref("i32"), literal(0)}),
+                    call("equal", {field_ref("f32"), literal(3.5F)}),
+                })}
+      .WithGuarantee(call("equal", {field_ref("i32"), literal(0)}))
+      .Expect(true);
+
+  // drop one member of a disjunctive filter
+  Simplify{call("or_kleene",
+                {
+                    call("equal", {field_ref("i32"), literal(0)}),
+                    call("equal", {field_ref("i32"), literal(3)}),
+                })}
+      .WithGuarantee(call("and_kleene",
+                          {
+                              call("greater_equal", {field_ref("i32"), literal(0)}),
+                              call("less_equal", {field_ref("i32"), literal(1)}),
+                          }))
+      .Expect(call("equal", {field_ref("i32"), literal(0)}));
+}
+
+} // namespace dataset
+} // namespace arrow
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index 159e0ac0331..a39a8adaae0 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -18,31 +18,20 @@
 #include "arrow/dataset/partition.h"
 
 #include
-#include
 #include
-#include
-#include
 #include
 #include
 
 #include "arrow/array/array_base.h"
 #include "arrow/array/array_nested.h"
-#include "arrow/array/builder_binary.h"
 #include "arrow/array/builder_dict.h"
 #include "arrow/compute/api_scalar.h"
 #include "arrow/dataset/dataset_internal.h"
-#include "arrow/dataset/file_base.h"
 #include "arrow/dataset/filter.h"
-#include "arrow/dataset/scanner.h"
-#include "arrow/dataset/scanner_internal.h"
-#include "arrow/filesystem/filesystem.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/scalar.h"
-#include "arrow/util/iterator.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/make_unique.h"
-#include "arrow/util/range.h"
-#include "arrow/util/sort.h"
 #include "arrow/util/string_view.h"
 
 namespace arrow {
diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc
index f49103a585a..b47f5ba9926 100644
--- a/cpp/src/arrow/dataset/partition_test.cc
+++ b/cpp/src/arrow/dataset/partition_test.cc
@@ -21,20 +21,16 @@
 
 #include
 #include
-#include
 #include
 #include
 #include
 #include
 
-#include "arrow/dataset/file_base.h"
 #include "arrow/dataset/scanner_internal.h"
 #include "arrow/dataset/test_util.h"
-#include "arrow/filesystem/localfs.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/status.h"
 #include "arrow/testing/gtest_util.h"
"arrow/util/io_util.h" namespace arrow { using internal::checked_pointer_cast; diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index f504305a996..9cfba41a624 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 0ff77de0102..522def49818 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -76,6 +76,9 @@ class ExpressionEvaluator; ARROW_DS_EXPORT const std::shared_ptr& scalar(bool); +struct ExpressionState; +class Expression2; + class Partitioning; class PartitioningFactory; class PartitioningOrFactory; diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index 5feed556207..c1666adc4c5 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -68,6 +68,8 @@ Datum::Datum(int64_t value) : value(std::make_shared(value)) {} Datum::Datum(uint64_t value) : value(std::make_shared(value)) {} Datum::Datum(float value) : value(std::make_shared(value)) {} Datum::Datum(double value) : value(std::make_shared(value)) {} +Datum::Datum(std::string value) + : value(std::make_shared(std::move(value))) {} Datum::Datum(const ChunkedArray& value) : value(std::make_shared(value.chunks(), value.type())) {} @@ -238,4 +240,17 @@ ValueDescr::Shape GetBroadcastShape(const std::vector& args) { return ValueDescr::SCALAR; } +void PrintTo(const Datum& datum, std::ostream* os) { + switch (datum.kind()) { + case Datum::SCALAR: + *os << datum.scalar()->ToString(); + break; + case Datum::ARRAY: + *os << datum.make_array()->ToString(); + break; + default: + *os << datum.ToString(); + } +} + } // namespace arrow diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 09dc870f687..5f80551f337 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -81,7 +81,9 @@ struct ARROW_EXPORT ValueDescr { } bool operator==(const ValueDescr& other) const { - return this->shape == other.shape && this->type->Equals(*other.type); + if (shape != other.shape) return false; + if (type == other.type) return true; + return type && type->Equals(other.type); } bool operator!=(const ValueDescr& other) const { return !(*this == other); } @@ -114,6 +116,11 @@ struct ARROW_EXPORT Datum { /// \brief Empty datum, to be populated elsewhere Datum() = default; + Datum(const Datum& other) noexcept = default; + Datum& operator=(const Datum& other) noexcept = default; + Datum(Datum&& other) noexcept = default; + Datum& operator=(Datum&& other) noexcept = default; + Datum(std::shared_ptr value) // NOLINT implicit conversion : value(std::move(value)) {} @@ -153,21 +160,7 @@ struct ARROW_EXPORT Datum { explicit Datum(uint64_t value); explicit Datum(float value); explicit Datum(double value); - - Datum(const Datum& other) noexcept { this->value = other.value; } - - Datum& operator=(const Datum& other) noexcept { - value = other.value; - return *this; - } - - // Define move constructor and move assignment, for better performance - Datum(Datum&& other) noexcept : value(std::move(other.value)) {} - - Datum& operator=(Datum&& other) noexcept { - value = std::move(other.value); - return *this; - } + explicit Datum(std::string value); Datum::Kind kind() const { switch (this->value.index()) { @@ -218,6 +211,11 @@ struct ARROW_EXPORT Datum { return util::get>(this->value); } + template + std::shared_ptr array_as() const { + return 
+    return internal::checked_pointer_cast<ExactType>(this->make_array());
+  }
+
   template <typename ExactType>
   const ExactType& scalar_as() const {
     return internal::checked_cast<const ExactType&>(*this->scalar());
@@ -267,6 +265,8 @@ struct ARROW_EXPORT Datum {
   bool operator!=(const Datum& other) const { return !Equals(other); }
 
   std::string ToString() const;
+
+  ARROW_EXPORT friend void PrintTo(const Datum&, std::ostream*);
 };
 
 } // namespace arrow
diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h
index eac08919286..09dfd59c8d2 100644
--- a/cpp/src/arrow/result.h
+++ b/cpp/src/arrow/result.h
@@ -401,6 +401,28 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable<Result<T>> {
     return std::forward<M>(m)(ValueUnsafe());
   }
 
+  /// Cast the internally stored value to produce a new result or propagate the stored
+  /// error.
+  template <typename U, typename E = typename std::enable_if<
+                            std::is_constructible<U, T>::value>::type>
+  Result<U> As() && {
+    if (!ok()) {
+      return status();
+    }
+    return U(MoveValueUnsafe());
+  }
+
+  /// Cast the internally stored value to produce a new result or propagate the stored
+  /// error.
+  template <typename U, typename E = typename std::enable_if<
+                            std::is_constructible<U, T>::value>::type>
+  Result<U> As() const& {
+    if (!ok()) {
+      return status();
+    }
+    return U(ValueUnsafe());
+  }
+
   const T& ValueUnsafe() const& {
     return *internal::launder(reinterpret_cast<const T*>(&data_));
   }
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 5e6fec82cb3..5604c5e2717 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -885,12 +885,15 @@ size_t FieldPath::hash() const {
 }
 
 std::string FieldPath::ToString() const {
+  if (this->indices().empty()) {
+    return "FieldPath(empty)";
+  }
+
   std::string repr = "FieldPath(";
   for (auto index : this->indices()) {
     repr += std::to_string(index) + " ";
   }
-  repr.resize(repr.size() - 1);
-  repr += ")";
+  repr.back() = ')';
   return repr;
 }
 
@@ -1039,13 +1042,31 @@ Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields) const {
 
 Result<std::shared_ptr<Array>> FieldPath::Get(const RecordBatch& batch) const {
   ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data()));
-  return MakeArray(data);
+  return MakeArray(std::move(data));
 }
 
 Result<std::shared_ptr<ChunkedArray>> FieldPath::Get(const Table& table) const {
   return FieldPathGetImpl::Get(this, table.columns());
 }
 
+Result<std::shared_ptr<Array>> FieldPath::Get(const Array& array) const {
+  ARROW_ASSIGN_OR_RAISE(auto data, Get(*array.data()));
+  return MakeArray(std::move(data));
+}
+
+Result<std::shared_ptr<ArrayData>> FieldPath::Get(const ArrayData& data) const {
+  return FieldPathGetImpl::Get(this, data.child_data);
+}
+
+Result<std::shared_ptr<ChunkedArray>> FieldPath::Get(const ChunkedArray& array) const {
+  FieldPath prefixed_with_0 = *this;
+  prefixed_with_0.indices_.insert(prefixed_with_0.indices_.begin(), 0);
+
+  ChunkedArrayVector vec;
+  vec.emplace_back(const_cast<ChunkedArray*>(&array), [](...)
{}); + return FieldPathGetImpl::Get(&prefixed_with_0, vec); +} + FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) { DCHECK_GT(util::get(impl_).indices().size(), 0); } @@ -1291,6 +1312,10 @@ std::vector FieldRef::FindAll(const FieldVector& fields) const { return util::visit(Visitor{fields}, impl_); } +std::vector FieldRef::FindAll(const ArrayData& array) const { + return FindAll(*array.type); +} + std::vector FieldRef::FindAll(const Array& array) const { return FindAll(*array.type()); } @@ -1333,7 +1358,7 @@ Schema::Schema(std::vector> fields, Schema::Schema(const Schema& schema) : detail::Fingerprintable(), impl_(new Impl(*schema.impl_)) {} -Schema::~Schema() {} +Schema::~Schema() = default; int Schema::num_fields() const { return static_cast(impl_->fields_.size()); } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index ed504c3afe5..3594c248e10 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1402,6 +1402,9 @@ class ARROW_EXPORT FieldPath { std::string ToString() const; size_t hash() const; + struct Hash { + size_t operator()(const FieldPath& path) const { return path.hash(); } + }; explicit operator bool() const { return !indices_.empty(); } bool operator!() const { return indices_.empty(); } @@ -1423,10 +1426,14 @@ class ARROW_EXPORT FieldPath { Result> Get(const RecordBatch& batch) const; Result> Get(const Table& table) const; - /// \brief Retrieve the referenced child Array from an Array or ChunkedArray + /// \brief Retrieve the referenced child from an Array, ArrayData, or ChunkedArray Result> Get(const Array& array) const; + Result> Get(const ArrayData& data) const; Result> Get(const ChunkedArray& array) const; + /// \brief Retrieve the reference child from a Datum + Result Get(const Datum& datum) const; + private: std::vector indices_; }; @@ -1517,6 +1524,12 @@ class ARROW_EXPORT FieldRef { std::string ToString() const; size_t hash() const; + struct Hash { + size_t operator()(const FieldRef& ref) const { return ref.hash(); } + }; + + explicit operator bool() const { return Equals(FieldPath{}); } + bool operator!() const { return !Equals(FieldPath{}); } bool IsFieldPath() const { return util::holds_alternative(impl_); } bool IsName() const { return util::holds_alternative(impl_); } @@ -1526,6 +1539,13 @@ class ARROW_EXPORT FieldRef { return true; } + const FieldPath* field_path() const { + return IsFieldPath() ? &util::get(impl_) : NULLPTR; + } + const std::string* name() const { + return IsName() ? &util::get(impl_) : NULLPTR; + } + /// \brief Retrieve FieldPath of every child field which matches this FieldRef. std::vector FindAll(const Schema& schema) const; std::vector FindAll(const Field& field) const; @@ -1533,6 +1553,7 @@ class ARROW_EXPORT FieldRef { std::vector FindAll(const FieldVector& fields) const; /// \brief Convenience function which applies FindAll to arg's type or schema. 
+ std::vector FindAll(const ArrayData& array) const; std::vector FindAll(const Array& array) const; std::vector FindAll(const ChunkedArray& array) const; std::vector FindAll(const RecordBatch& batch) const; @@ -1609,7 +1630,7 @@ class ARROW_EXPORT FieldRef { if (match) { return match.Get(root).ValueOrDie(); } - return NULLPTR; + return GetType(NULLPTR); } private: diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index cf90953527e..f1000d1fe7f 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -71,6 +71,8 @@ class RecordBatch; class RecordBatchReader; class Table; +struct Datum; + using ChunkedArrayVector = std::vector>; using RecordBatchVector = std::vector>; using RecordBatchIterator = Iterator>; From c31069e88d6afbbfcfb81485bcd1c3b714a0d538 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 1 Dec 2020 13:28:17 -0500 Subject: [PATCH 02/38] replace filtering with Expression2 --- .../arrow/dataset-parquet-scan-example.cc | 11 +- .../compute/kernels/scalar_cast_internal.cc | 4 +- cpp/src/arrow/dataset/dataset.cc | 37 +- cpp/src/arrow/dataset/dataset.h | 34 +- cpp/src/arrow/dataset/dataset_internal.h | 8 +- cpp/src/arrow/dataset/dataset_test.cc | 3 +- cpp/src/arrow/dataset/discovery.cc | 2 +- cpp/src/arrow/dataset/discovery.h | 8 +- cpp/src/arrow/dataset/discovery_test.cc | 10 +- cpp/src/arrow/dataset/expression.cc | 804 ++++++++++++------ cpp/src/arrow/dataset/expression.h | 100 ++- cpp/src/arrow/dataset/expression_internal.h | 318 ++++++- cpp/src/arrow/dataset/expression_test.cc | 593 ++++++++----- cpp/src/arrow/dataset/file_base.cc | 36 +- cpp/src/arrow/dataset/file_base.h | 17 +- cpp/src/arrow/dataset/file_csv.cc | 14 +- cpp/src/arrow/dataset/file_ipc_test.cc | 4 +- cpp/src/arrow/dataset/file_parquet.cc | 78 +- cpp/src/arrow/dataset/file_parquet.h | 14 +- cpp/src/arrow/dataset/file_parquet_test.cc | 57 +- cpp/src/arrow/dataset/file_test.cc | 122 ++- cpp/src/arrow/dataset/filter.cc | 41 +- cpp/src/arrow/dataset/filter.h | 55 -- cpp/src/arrow/dataset/filter_test.cc | 236 +---- cpp/src/arrow/dataset/partition.cc | 127 ++- cpp/src/arrow/dataset/partition.h | 33 +- cpp/src/arrow/dataset/partition_test.cc | 107 ++- cpp/src/arrow/dataset/scanner.cc | 48 +- cpp/src/arrow/dataset/scanner.h | 16 +- cpp/src/arrow/dataset/scanner_internal.h | 96 +-- cpp/src/arrow/dataset/scanner_test.cc | 45 +- cpp/src/arrow/dataset/test_util.h | 122 +-- cpp/src/arrow/dataset/type_fwd.h | 6 - cpp/src/arrow/datum.cc | 3 + cpp/src/arrow/datum.h | 3 + cpp/src/arrow/type.cc | 2 +- cpp/src/arrow/type.h | 2 +- cpp/src/arrow/util/variant.h | 16 + 38 files changed, 1962 insertions(+), 1270 deletions(-) diff --git a/cpp/examples/arrow/dataset-parquet-scan-example.cc b/cpp/examples/arrow/dataset-parquet-scan-example.cc index 3cdd298b8fd..5b933d3ca62 100644 --- a/cpp/examples/arrow/dataset-parquet-scan-example.cc +++ b/cpp/examples/arrow/dataset-parquet-scan-example.cc @@ -37,8 +37,6 @@ namespace fs = arrow::fs; namespace ds = arrow::dataset; -using ds::string_literals::operator"" _; - #define ABORT_ON_FAILURE(expr) \ do { \ arrow::Status status_ = (expr); \ @@ -62,7 +60,8 @@ struct Configuration { // Indicates the filter by which rows will be filtered. This optimization can // make use of partition information and/or file metadata if possible. 
- std::shared_ptr filter = ("total_amount"_ > 1000.0f).Copy(); + ds::Expression2 filter = + ds::greater(ds::field_ref("total_amount"), ds::literal(1000.0f)); ds::InspectOptions inspect_options{}; ds::FinishOptions finish_options{}; @@ -147,7 +146,7 @@ std::shared_ptr GetDatasetFromPath( std::shared_ptr GetScannerFromDataset(std::shared_ptr dataset, std::vector columns, - std::shared_ptr filter, + ds::Expression2 filter, bool use_threads) { auto scanner_builder = dataset->NewScan().ValueOrDie(); @@ -155,9 +154,7 @@ std::shared_ptr GetScannerFromDataset(std::shared_ptr ABORT_ON_FAILURE(scanner_builder->Project(columns)); } - if (filter != nullptr) { - ABORT_ON_FAILURE(scanner_builder->Filter(filter)); - } + ABORT_ON_FAILURE(scanner_builder->Filter(filter)); ABORT_ON_FAILURE(scanner_builder->UseThreads(use_threads)); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index cd33de672a6..67f0820402a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -191,6 +191,8 @@ void CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out) { } void CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].is_scalar()) return; + ArrayData* output = out->mutable_array(); std::shared_ptr nulls; Status s = MakeArrayOfNull(output->type, batch.length).Value(&nulls); @@ -251,7 +253,7 @@ static bool CanCastFromDictionary(Type::type type_id) { void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func) { // From null to this type - DCHECK_OK(func->AddKernel(Type::NA, {InputType::Array(null())}, out_ty, CastFromNull)); + DCHECK_OK(func->AddKernel(Type::NA, {null()}, out_ty, CastFromNull)); // From dictionary to this type if (CanCastFromDictionary(out_type_id)) { diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 71755aaf566..56c901f959d 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -32,12 +32,10 @@ namespace arrow { namespace dataset { -Fragment::Fragment(std::shared_ptr partition_expression, +Fragment::Fragment(Expression2 partition_expression, std::shared_ptr physical_schema) : partition_expression_(std::move(partition_expression)), - physical_schema_(std::move(physical_schema)) { - DCHECK_NE(partition_expression_, nullptr); -} + physical_schema_(std::move(physical_schema)) {} Result> Fragment::ReadPhysicalSchema() { { @@ -61,14 +59,14 @@ Result> InMemoryFragment::ReadPhysicalSchemaImpl() { InMemoryFragment::InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - std::shared_ptr partition_expression) + Expression2 partition_expression) : Fragment(std::move(partition_expression), std::move(schema)), record_batches_(std::move(record_batches)) { DCHECK_NE(physical_schema_, nullptr); } InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches, - std::shared_ptr partition_expression) + Expression2 partition_expression) : InMemoryFragment(record_batches.empty() ? 
schema({}) : record_batches[0]->schema(), std::move(record_batches), std::move(partition_expression)) {} @@ -95,11 +93,9 @@ Result InMemoryFragment::Scan(std::shared_ptr opt return MakeMapIterator(fn, std::move(batches_it)); } -Dataset::Dataset(std::shared_ptr schema, - std::shared_ptr partition_expression) - : schema_(std::move(schema)), partition_expression_(std::move(partition_expression)) { - DCHECK_NE(partition_expression_, nullptr); -} +Dataset::Dataset(std::shared_ptr schema, Expression2 partition_expression) + : schema_(std::move(schema)), + partition_expression_(std::move(partition_expression)) {} Result> Dataset::NewScan( std::shared_ptr context) { @@ -110,10 +106,16 @@ Result> Dataset::NewScan() { return NewScan(std::make_shared()); } -FragmentIterator Dataset::GetFragments(std::shared_ptr predicate) { - predicate = predicate->Assume(*partition_expression_); - return predicate->IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) - : MakeEmptyIterator>(); +Result Dataset::GetFragments() { + ARROW_ASSIGN_OR_RAISE(auto predicate, literal(true).Bind(*schema_)); + return GetFragments(std::move(predicate)); +} + +Result Dataset::GetFragments(Expression2::BoundWithState predicate) { + ARROW_ASSIGN_OR_RAISE( + predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); + return predicate.first.IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) + : MakeEmptyIterator>(); } struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator { @@ -153,7 +155,7 @@ Result> InMemoryDataset::ReplaceSchema( return std::make_shared(std::move(schema), get_batches_); } -FragmentIterator InMemoryDataset::GetFragmentsImpl(std::shared_ptr) { +Result InMemoryDataset::GetFragmentsImpl(Expression2::BoundWithState) { auto schema = this->schema(); auto create_fragment = @@ -194,7 +196,8 @@ Result> UnionDataset::ReplaceSchema( new UnionDataset(std::move(schema), std::move(children))); } -FragmentIterator UnionDataset::GetFragmentsImpl(std::shared_ptr predicate) { +Result UnionDataset::GetFragmentsImpl( + Expression2::BoundWithState predicate) { return GetFragmentsFromDatasets(children_, predicate); } diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index be138dd277d..35c91b79fed 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/dataset/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/util/macros.h" @@ -71,21 +72,19 @@ class ARROW_DS_EXPORT Fragment { /// \brief An expression which evaluates to true for all data viewed by this /// Fragment. 
- const std::shared_ptr& partition_expression() const { - return partition_expression_; - } + const Expression2& partition_expression() const { return partition_expression_; } virtual ~Fragment() = default; protected: Fragment() = default; - explicit Fragment(std::shared_ptr partition_expression, + explicit Fragment(Expression2 partition_expression, std::shared_ptr physical_schema); virtual Result> ReadPhysicalSchemaImpl() = 0; util::Mutex physical_schema_mutex_; - std::shared_ptr partition_expression_ = scalar(true); + Expression2 partition_expression_ = literal(true); std::shared_ptr physical_schema_; }; @@ -94,9 +93,9 @@ class ARROW_DS_EXPORT Fragment { class ARROW_DS_EXPORT InMemoryFragment : public Fragment { public: InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - std::shared_ptr = scalar(true)); + Expression2 = literal(true)); explicit InMemoryFragment(RecordBatchVector record_batches, - std::shared_ptr = scalar(true)); + Expression2 = literal(true)); Result Scan(std::shared_ptr options, std::shared_ptr context) override; @@ -123,15 +122,14 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. - FragmentIterator GetFragments(std::shared_ptr predicate = scalar(true)); + Result GetFragments(Expression2::BoundWithState predicate); + Result GetFragments(); const std::shared_ptr& schema() const { return schema_; } /// \brief An expression which evaluates to true for all data viewed by this Dataset. /// May be null, which indicates no information is available. - const std::shared_ptr& partition_expression() const { - return partition_expression_; - } + const Expression2& partition_expression() const { return partition_expression_; } /// \brief The name identifying the kind of Dataset virtual std::string type_name() const = 0; @@ -148,13 +146,13 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { protected: explicit Dataset(std::shared_ptr schema) : schema_(std::move(schema)) {} - Dataset(std::shared_ptr schema, - std::shared_ptr partition_expression); + Dataset(std::shared_ptr schema, Expression2 partition_expression); - virtual FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) = 0; + virtual Result GetFragmentsImpl( + Expression2::BoundWithState predicate) = 0; std::shared_ptr schema_; - std::shared_ptr partition_expression_ = scalar(true); + Expression2 partition_expression_ = literal(true); }; /// \brief A Source which yields fragments wrapping a stream of record batches. 
@@ -183,7 +181,8 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { std::shared_ptr schema) const override; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl( + Expression2::BoundWithState predicate) override; std::shared_ptr get_batches_; }; @@ -207,7 +206,8 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { std::shared_ptr schema) const override; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl( + Expression2::BoundWithState predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 489339e7907..a6ee8a117bc 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -35,18 +35,18 @@ namespace dataset { /// \brief GetFragmentsFromDatasets transforms a vector into a /// flattened FragmentIterator. -inline FragmentIterator GetFragmentsFromDatasets(const DatasetVector& datasets, - std::shared_ptr predicate) { +inline Result GetFragmentsFromDatasets( + const DatasetVector& datasets, Expression2::BoundWithState predicate) { // Iterator auto datasets_it = MakeVectorIterator(datasets); // Dataset -> Iterator - auto fn = [predicate](std::shared_ptr dataset) -> FragmentIterator { + auto fn = [predicate](std::shared_ptr dataset) -> Result { return dataset->GetFragments(predicate); }; // Iterator> - auto fragments_it = MakeMapIterator(fn, std::move(datasets_it)); + auto fragments_it = MakeMaybeMapIterator(fn, std::move(datasets_it)); // Iterator return MakeFlattenIterator(std::move(fragments_it)); diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index 0b138caafe7..58a4a4a0b3f 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -505,7 +505,8 @@ TEST_F(TestEndToEnd, EndToEndSingleDataset) { // The following filter tests both predicate pushdown and post filtering // without partition information because `year` is a partition and `sales` is // not. - auto filter = ("year"_ == 2019 && "sales"_ > 100.0); + auto filter = and_(equal(field_ref("year"), literal(2019)), + greater(field_ref("sales"), literal(100.0))); ASSERT_OK(scanner_builder->Filter(filter)); ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish()); diff --git a/cpp/src/arrow/dataset/discovery.cc b/cpp/src/arrow/dataset/discovery.cc index b7af6dcbf0f..079cc6d4065 100644 --- a/cpp/src/arrow/dataset/discovery.cc +++ b/cpp/src/arrow/dataset/discovery.cc @@ -35,7 +35,7 @@ namespace arrow { namespace dataset { -DatasetFactory::DatasetFactory() : root_partition_(scalar(true)) {} +DatasetFactory::DatasetFactory() : root_partition_(literal(true)) {} Result> DatasetFactory::Inspect(InspectOptions options) { ARROW_ASSIGN_OR_RAISE(auto schemas, InspectSchemas(std::move(options))); diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 5176f4f53a7..0b1c94f5adc 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -90,9 +90,9 @@ class ARROW_DS_EXPORT DatasetFactory { virtual Result> Finish(FinishOptions options) = 0; /// \brief Optional root partition for the resulting Dataset. 
- const std::shared_ptr& root_partition() const { return root_partition_; } - Status SetRootPartition(std::shared_ptr partition) { - root_partition_ = partition; + const Expression2& root_partition() const { return root_partition_; } + Status SetRootPartition(Expression2 partition) { + root_partition_ = std::move(partition); return Status::OK(); } @@ -101,7 +101,7 @@ class ARROW_DS_EXPORT DatasetFactory { protected: DatasetFactory(); - std::shared_ptr root_partition_; + Expression2 root_partition_; }; /// \brief DatasetFactory provides a way to inspect/discover a Dataset's diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index bc330edb53b..cd0ba2bee4d 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -139,7 +139,8 @@ class FileSystemDatasetFactoryTest : public DatasetFactoryTest { } options_ = ScanOptions::Make(schema); ASSERT_OK_AND_ASSIGN(dataset_, factory_->Finish(schema)); - AssertFragmentsAreFromPath(dataset_->GetFragments(), paths); + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset_->GetFragments()); + AssertFragmentsAreFromPath(std::move(fragment_it), paths); } protected: @@ -368,10 +369,13 @@ TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) { // column. In such case, the filename should not be used. MakeFactory({fs::File("one/file.parquet")}); + ASSERT_OK_AND_ASSIGN(auto expected, equal(field_ref("first"), literal("one")).Bind(*s)); + ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish()); - for (const auto& maybe_fragment : dataset->GetFragments()) { + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); + for (const auto& maybe_fragment : fragment_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); - ASSERT_TRUE(fragment->partition_expression()->Equals(("first"_ == "one"))); + EXPECT_EQ(fragment->partition_expression(), expected.first); } } diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index e0a44a34ab8..c762cb89e0f 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -24,9 +24,15 @@ #include "arrow/compute/exec_internal.h" #include "arrow/compute/registry.h" #include "arrow/dataset/expression_internal.h" +#include "arrow/dataset/filter.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" #include "arrow/util/string.h" +#include "arrow/util/value_parsing.h" namespace arrow { @@ -45,6 +51,144 @@ const FieldRef* Expression2::field_ref() const { return util::get_if(impl_.get()); } +Expression2::operator std::shared_ptr() const { + if (auto lit = literal()) { + DCHECK(lit->is_scalar()); + return std::make_shared(lit->scalar()); + } + + if (auto ref = field_ref()) { + DCHECK(ref->name()); + return std::make_shared(*ref->name()); + } + + auto call = CallNotNull(*this); + if (call->function == "invert") { + return std::make_shared(call->arguments[0]); + } + + if (call->function == "cast") { + const auto& options = checked_cast(*call->options); + return std::make_shared(call->arguments[0], options.to_type, options); + } + + if (call->function == "and_kleene") { + return std::make_shared(call->arguments[0], call->arguments[1]); + } + + if (call->function == "or_kleene") { + return std::make_shared(call->arguments[0], call->arguments[1]); + } + + if (auto cmp = Comparison::Get(call->function)) { + compute::CompareOperator op = [&] 
{ + switch (*cmp) { + case Comparison::EQUAL: + return compute::EQUAL; + case Comparison::LESS: + return compute::LESS; + case Comparison::GREATER: + return compute::GREATER; + case Comparison::NOT_EQUAL: + return compute::NOT_EQUAL; + case Comparison::LESS_EQUAL: + return compute::LESS_EQUAL; + case Comparison::GREATER_EQUAL: + return compute::GREATER_EQUAL; + default: + break; + } + return static_cast(-1); + }(); + + return std::make_shared(op, call->arguments[0], + call->arguments[1]); + } + + if (call->function == "is_valid") { + return std::make_shared(call->arguments[0]); + } + + if (call->function == "is_in") { + auto set = checked_cast(*call->options) + .value_set.make_array(); + return std::make_shared(call->arguments[0], std::move(set)); + } + + DCHECK(false) << "untranslatable Expression2: " << ToString(); + return nullptr; +} + +Expression2::Expression2(const Expression& expr) { + switch (expr.type()) { + case ExpressionType::FIELD: + *this = + ::arrow::dataset::field_ref(checked_cast(expr).name()); + return; + + case ExpressionType::SCALAR: + *this = + ::arrow::dataset::literal(checked_cast(expr).value()); + return; + + case ExpressionType::NOT: + *this = ::arrow::dataset::call( + "invert", {checked_cast(expr).operand()}); + return; + + case ExpressionType::CAST: { + const auto& cast_expr = checked_cast(expr); + auto options = cast_expr.options(); + options.to_type = cast_expr.to_type(); + *this = ::arrow::dataset::call("cast", {cast_expr.operand()}, std::move(options)); + return; + } + + case ExpressionType::AND: { + const auto& and_expr = checked_cast(expr); + *this = ::arrow::dataset::call("and_kleene", + {and_expr.left_operand(), and_expr.right_operand()}); + return; + } + + case ExpressionType::OR: { + const auto& or_expr = checked_cast(expr); + *this = ::arrow::dataset::call("or_kleene", + {or_expr.left_operand(), or_expr.right_operand()}); + return; + } + + case ExpressionType::COMPARISON: { + const auto& cmp_expr = checked_cast(expr); + static std::array ops = { + "equal", "not_equal", "greater", "greater_equal", "less", "less_equal", + }; + *this = ::arrow::dataset::call(ops[cmp_expr.op()], + {cmp_expr.left_operand(), cmp_expr.right_operand()}); + return; + } + + case ExpressionType::IS_VALID: { + const auto& is_valid_expr = checked_cast(expr); + *this = ::arrow::dataset::call("is_valid", {is_valid_expr.operand()}); + return; + } + + case ExpressionType::IN: { + const auto& in_expr = checked_cast(expr); + *this = ::arrow::dataset::call( + "is_in", {in_expr.operand()}, + compute::SetLookupOptions{in_expr.set(), /*skip_nulls=*/true}); + return; + } + + default: + break; + } + + DCHECK(false) << "untranslatable Expression: " << expr.ToString(); +} + std::string Expression2::ToString() const { if (auto lit = literal()) { if (lit->is_scalar()) { @@ -74,7 +218,12 @@ std::string Expression2::ToString() const { return out; } -void PrintTo(const Expression2& expr, std::ostream* os) { *os << expr.ToString(); } +void PrintTo(const Expression2& expr, std::ostream* os) { + *os << expr.ToString(); + if (expr.IsBound()) { + *os << "[bound]"; + } +} bool Expression2::Equals(const Expression2& other) const { if (Identical(*this, other)) return true; @@ -181,30 +330,41 @@ bool Expression2::IsNullLiteral() const { return true; } } + return false; } bool Expression2::IsSatisfiable() const { + if (descr_.type && descr_.type->id() == Type::NA) { + return false; + } + if (auto lit = literal()) { if (lit->null_count() == lit->length()) { return false; } + if (lit->is_scalar() && 
lit->type()->id() == Type::BOOL) { return lit->scalar_as().value; } } + + if (auto ref = field_ref()) { + return true; + } + return true; } -bool KernelStateIsImmutable(const std::string& function) { +inline bool KernelStateIsImmutable(const std::string& function) { // XXX maybe just add Kernel::state_is_immutable or so? // known functions with non-null but nevertheless immutable KernelState - static std::unordered_set immutable_state = { + static std::unordered_set names = { "is_in", "index_in", "cast", "struct", "strptime", }; - return immutable_state.find(function) != immutable_state.end(); + return names.find(function) != names.end(); } Result> InitKernelState( @@ -219,62 +379,45 @@ Result> InitKernelState( return std::move(kernel_state); } -Result> ExpressionState::Clone( - const std::shared_ptr& state, const Expression2& expr, - compute::ExecContext* exec_context) { - if (!expr.IsBound()) { - return Status::Invalid("Cannot clone State against an unbound expression."); - } - - if (state == nullptr) return nullptr; - - auto call = CallNotNull(expr); - auto call_state = checked_cast(state.get()); +Result> CloneExpressionState( + const ExpressionState& state, compute::ExecContext* exec_context) { + auto clone = std::make_shared(); + clone->kernel_states.reserve(state.kernel_states.size()); - CallState clone; - clone.argument_states.resize(call_state->argument_states.size()); + for (const auto& sub_state : state.kernel_states) { + auto call = CallNotNull(sub_state.first); - bool recursively_share = true; - for (size_t i = 0; i < call->arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE( - clone.argument_states[i], - Clone(call_state->argument_states[i], call->arguments[i], exec_context)); - - if (clone.argument_states[i] != call_state->argument_states[i]) { - recursively_share = false; + if (KernelStateIsImmutable(call->function)) { + // The kernel's state is immutable so it's safe to just + // share a pointer between threads + clone->kernel_states.insert(sub_state); + continue; } - } - if (call_state->kernel_state == nullptr || KernelStateIsImmutable(call->function)) { - // The kernel's state is immutable so it's safe to just - // share a pointer between threads - if (recursively_share) { - return state; - } - clone.kernel_state = call_state->kernel_state; - } else { // The kernel's state must be re-initialized. - ARROW_ASSIGN_OR_RAISE(clone.kernel_state, InitKernelState(*call, exec_context)); + ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(*call, exec_context)); + clone->kernel_states.emplace(sub_state.first, std::move(kernel_state)); } - return std::make_shared(std::move(clone)); + return clone; } -Result Expression2::Bind( +Result Expression2::Bind( ValueDescr in, compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; return Bind(std::move(in), &exec_context); } - if (literal()) return StateAndBound{*this, nullptr}; + BoundWithState ret{*this, std::make_shared()}; + + if (literal()) return ret; if (auto ref = field_ref()) { ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); - auto bound = *this; - bound.descr_ = + ret.first.descr_ = field ? 
ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return StateAndBound{std::move(bound), nullptr}; + return ret; } auto bound_call = *CallNotNull(*this); @@ -291,31 +434,41 @@ Result Expression2::Bind( } bound_call.function_kind = function->kind(); - std::vector> argument_states( - bound_call.arguments.size()); - for (size_t i = 0; i < argument_states.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_states[i]), + auto state = std::make_shared(); + for (size_t i = 0; i < bound_call.arguments.size(); ++i) { + std::shared_ptr argument_state; + ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_state), bound_call.arguments[i].Bind(in, exec_context)); + state->MoveFrom(argument_state.get()); + } + + if (RequriesDictionaryTransparency(bound_call)) { + RETURN_NOT_OK(EnsureNotDictionary(&bound_call)); } auto descrs = GetDescriptors(bound_call.arguments); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); + for (auto& descr : descrs) { + if (RequriesDictionaryTransparency(bound_call)) { + RETURN_NOT_OK(EnsureNotDictionary(&descr)); + } + } - auto call_state = std::make_shared(); - ARROW_ASSIGN_OR_RAISE(call_state->kernel_state, - InitKernelState(bound_call, exec_context)); - call_state->argument_states = std::move(argument_states); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); - kernel_context.SetState(call_state->kernel_state.get()); + ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(bound_call, exec_context)); + kernel_context.SetState(kernel_state.get()); + + ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( + &kernel_context, descrs)); - auto bound = Expression2(std::make_shared(std::move(bound_call))); - ARROW_ASSIGN_OR_RAISE(bound.descr_, bound_call.kernel->signature->out_type().Resolve( - &kernel_context, descrs)); - return StateAndBound{std::move(bound), std::move(call_state)}; + Expression2 bound(std::make_shared(std::move(bound_call)), std::move(descr)); + + state->kernel_states.emplace(bound, std::move(kernel_state)); + return BoundWithState{std::move(bound), std::move(state)}; } -Result Expression2::Bind( +Result Expression2::Bind( const Schema& in_schema, compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } @@ -354,25 +507,26 @@ Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* } auto call = CallNotNull(expr); - auto call_state = checked_cast(state); std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { - auto argument_state = call_state->argument_states[i].get(); - ARROW_ASSIGN_OR_RAISE( - arguments[i], - ExecuteScalarExpression(call->arguments[i], argument_state, input, exec_context)); + ARROW_ASSIGN_OR_RAISE(arguments[i], ExecuteScalarExpression(call->arguments[i], state, + input, exec_context)); + + if (RequriesDictionaryTransparency(*call)) { + RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); + } } auto executor = compute::detail::KernelExecutor::MakeScalar(); compute::KernelContext kernel_context(exec_context); - kernel_context.SetState(call_state->kernel_state.get()); + kernel_context.SetState(state->Get(expr)); auto kernel = call->kernel; - auto inputs = GetDescriptors(call->arguments); + auto descrs = GetDescriptors(arguments); auto options = call->options.get(); - RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, inputs, 
options})); + RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options})); auto listener = std::make_shared(); RETURN_NOT_OK(executor->Execute(arguments, listener.get())); @@ -388,6 +542,18 @@ ArgumentsAndFlippedArguments(const Expression2::Call& call) { call.arguments[0]}}; } +template ::value_type> +util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { + if (begin == end) return util::nullopt; + + Out folded = std::move(*begin++); + while (begin != end) { + folded = bin_op(std::move(folded), std::move(*begin++)); + } + return folded; +} + util::optional GetNullHandling( const Expression2::Call& call) { if (call.function_kind == compute::Function::SCALAR) { @@ -419,8 +585,23 @@ bool DefinitelyNotNull(const Expression2& expr) { return false; } +std::vector FieldsInExpression(const Expression2& expr) { + if (auto lit = expr.literal()) return {}; + + if (auto ref = expr.field_ref()) { + return {*ref}; + } + + std::vector fields; + for (const Expression2& arg : CallNotNull(expr)->arguments) { + auto argument_fields = FieldsInExpression(arg); + std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields)); + } + return fields; +} + template -Result Modify(Expression2 expr, const PreVisit& pre, +Result Modify(Expression2 expr, ExpressionState* state, const PreVisit& pre, const PostVisitCall& post_call) { ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); @@ -432,7 +613,7 @@ Result Modify(Expression2 expr, const PreVisit& pre, auto modified_argument = modified_call.arguments.begin(); for (const auto& argument : call->arguments) { - ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); + ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, state, pre, post_call)); if (!Identical(*modified_argument, argument)) { at_least_one_modified = true; @@ -444,44 +625,37 @@ Result Modify(Expression2 expr, const PreVisit& pre, // reconstruct the call expression with the modified arguments auto modified_expr = Expression2( std::make_shared(std::move(modified_call)), expr.descr()); + + // if expr had associated kernel state, associate it with modified_expr + state->Replace(expr, modified_expr); return post_call(std::move(modified_expr), &expr); } return post_call(std::move(expr), nullptr); } -Result FoldConstants(Expression2 expr, ExpressionState* state) { - DCHECK(expr.IsBound()); - - struct StateAndIndex { - CallState* state; - int index; - }; - std::vector stack; - - auto root_state = checked_cast(state); - return Modify( - std::move(expr), - [&stack, root_state](Expression2 expr) { - auto call = expr.call(); +template +Result Modify(Expression2::BoundWithState bound, + const PreVisit& pre, + const PostVisitCall& post_call) { + DCHECK(bound.first.IsBound()); - if (stack.empty()) { - stack = {{root_state, 0}}; - return expr; - } + auto expr = std::move(bound.first); + auto state = bound.second.get(); - int i = stack.back().index++; + ARROW_ASSIGN_OR_RAISE(expr, Modify(std::move(expr), state, pre, post_call)); - if (!call) return expr; - auto next_state = stack.back().state->argument_states[i].get(); - stack.push_back({checked_cast(next_state), 0}); + bound.first = std::move(expr); - return expr; - }, - [&stack](Expression2 expr, ...) 
-> Result { - auto state = stack.back().state; - stack.pop_back(); + return bound; +} +Result FoldConstants(Expression2::BoundWithState bound) { + bound.second = std::make_shared(*bound.second); + auto state = bound.second.get(); + return Modify( + std::move(bound), [](Expression2 expr) { return expr; }, + [state](Expression2 expr, ...) -> Result { auto call = CallNotNull(expr); if (std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression2& argument) { return argument.literal(); })) { @@ -489,16 +663,22 @@ Result FoldConstants(Expression2 expr, ExpressionState* state) { static const Datum ignored_input; ARROW_ASSIGN_OR_RAISE(Datum constant, ExecuteScalarExpression(expr, state, ignored_input)); + + state->Drop(expr); return literal(std::move(constant)); } - // XXX the following should probably be in a registry of passes instead of inline + // XXX the following should probably be in a registry of passes instead + // of inline if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { - // kernels which always produce intersected validity can be resolved to null - // *now* if any of their inputs is a null literal + // kernels which always produce intersected validity can be resolved + // to null *now* if any of their inputs is a null literal for (const auto& argument : call->arguments) { - if (argument.IsNullLiteral()) return argument; + if (argument.IsNullLiteral()) { + state->Drop(expr); + return argument; + } } } @@ -526,43 +706,8 @@ Result FoldConstants(Expression2 expr, ExpressionState* state) { return expr; }); - - return expr; } -struct FlattenedAssociativeChain { - bool was_left_folded = true; - std::vector exprs, fringe; - - explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} { - auto call = CallNotNull(exprs.back()); - fringe = call->arguments; - - auto it = fringe.begin(); - - while (it != fringe.end()) { - auto sub_call = it->call(); - if (!sub_call || sub_call->function != call->function) { - ++it; - continue; - } - - if (it != fringe.begin()) { - was_left_folded = false; - } - - exprs.push_back(std::move(*it)); - it = fringe.erase(it); - it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end()); - // NB: no increment so we hit sub_call's first argument next iteration - } - - DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) { - return CallNotNull(expr)->options == nullptr; - })); - } -}; - inline std::vector GuaranteeConjunctionMembers( const Expression2& guaranteed_true_predicate) { auto guarantee = guaranteed_true_predicate.call(); @@ -577,8 +722,12 @@ inline std::vector GuaranteeConjunctionMembers( Status ExtractKnownFieldValuesImpl( std::vector* conjunction_members, std::unordered_map* known_values) { - for (auto&& member : *conjunction_members) { - ARROW_ASSIGN_OR_RAISE(member, Canonicalize(std::move(member))); + { + auto empty_state = std::make_shared(); + for (auto&& member : *conjunction_members) { + ARROW_ASSIGN_OR_RAISE(std::tie(member, std::ignore), + Canonicalize({std::move(member), empty_state})); + } } auto unconsumed_end = @@ -622,17 +771,21 @@ Status ExtractKnownFieldValuesImpl( Result> ExtractKnownFieldValues( const Expression2& guaranteed_true_predicate) { + if (!guaranteed_true_predicate.IsBound()) { + return Status::Invalid("guaranteed_true_predicate was not bound"); + } auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, 
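// [Editorial sketch, not from this patch] FoldConstants as rewritten above
// operates on a BoundWithState pair; kBoringSchema is assumed to be the test
// fixture schema (test_util.h) carrying an int32 field "i32". The state is
// cloned first, and entries for folded-away calls are Drop()ed.
Status SketchFoldConstants() {
  auto expr = call("add", {call("add", {literal(1), literal(2)}),
                           field_ref("i32")});
  ARROW_ASSIGN_OR_RAISE(auto bound, expr.Bind(*kBoringSchema));
  ARROW_ASSIGN_OR_RAISE(bound, FoldConstants(std::move(bound)));
  // bound.first now equals the bound form of add(literal(3), field_ref("i32"))
  return Status::OK();
}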
&known_values)); return known_values; } -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2 expr) { + Expression2::BoundWithState bound) { + bound.second = std::make_shared(*bound.second); return Modify( - std::move(expr), + std::move(bound), [&known_values](Expression2 expr) { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); @@ -645,18 +798,6 @@ Result ReplaceFieldsWithKnownValues( [](Expression2 expr, ...) { return expr; }); } -template ::value_type> -util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { - if (begin == end) return util::nullopt; - - Out folded = std::move(*begin++); - while (begin != end) { - folded = bin_op(std::move(folded), std::move(*begin++)); - } - return folded; -} - inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { static std::unordered_set binary_associative_commutative{ "and", "or", "and_kleene", "or_kleene", "xor", @@ -666,97 +807,13 @@ inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { return it != binary_associative_commutative.end(); } -struct Comparison { - enum type { - NA = 0, - EQUAL = 1, - LESS = 2, - GREATER = 4, - NOT_EQUAL = LESS | GREATER, - LESS_EQUAL = LESS | EQUAL, - GREATER_EQUAL = GREATER | EQUAL, - }; +Result Canonicalize(Expression2::BoundWithState bound, + compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return Canonicalize(std::move(bound), &exec_context); + } - static const type* Get(const std::string& function) { - static std::unordered_map flipped_comparisons{ - {"equal", EQUAL}, {"not_equal", NOT_EQUAL}, - {"less", LESS}, {"less_equal", LESS_EQUAL}, - {"greater", GREATER}, {"greater_equal", GREATER_EQUAL}, - }; - - auto it = flipped_comparisons.find(function); - return it != flipped_comparisons.end() ? &it->second : nullptr; - } - - static const type* Get(const Expression2& expr) { - if (auto call = expr.call()) { - return Comparison::Get(call->function); - } - return nullptr; - } - - static Result Execute(Datum l, Datum r) { - if (!l.is_scalar() || !r.is_scalar()) { - return Status::Invalid("Cannot Execute Comparison on non-scalars"); - } - std::vector arguments{std::move(l), std::move(r)}; - - ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); - - if (!equal.scalar()->is_valid) return NA; - if (equal.scalar_as().value) return EQUAL; - - ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments)); - - if (!less.scalar()->is_valid) return NA; - return less.scalar_as().value ? 
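// [Editorial sketch, not from this patch] The two passes above cooperate:
// ExtractKnownFieldValues mines a bound guarantee for field == literal
// conjunction members, and ReplaceFieldsWithKnownValues substitutes them
// into another bound expression. The schema argument is an assumption.
Status SketchKnownValues(const Schema& schema) {
  ARROW_ASSIGN_OR_RAISE(auto guarantee,
                        equal(field_ref("i32"), literal(3)).Bind(schema));
  ARROW_ASSIGN_OR_RAISE(auto known, ExtractKnownFieldValues(guarantee.first));
  // known == {{"i32", Datum(3)}}
  ARROW_ASSIGN_OR_RAISE(auto bound,
                        less(field_ref("i32"), literal(5)).Bind(schema));
  ARROW_ASSIGN_OR_RAISE(bound,
                        ReplaceFieldsWithKnownValues(known, std::move(bound)));
  // bound.first is, in effect, the bound form of less(literal(3), literal(5))
  return Status::OK();
}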
LESS : GREATER; - } - - static type GetFlipped(type op) { - switch (op) { - case NA: - return NA; - case EQUAL: - return EQUAL; - case LESS: - return GREATER; - case GREATER: - return LESS; - case NOT_EQUAL: - return NOT_EQUAL; - case LESS_EQUAL: - return GREATER_EQUAL; - case GREATER_EQUAL: - return LESS_EQUAL; - }; - DCHECK(false); - return NA; - } - - static std::string GetName(type op) { - switch (op) { - case NA: - DCHECK(false) << "unreachable"; - break; - case EQUAL: - return "equal"; - case LESS: - return "less"; - case GREATER: - return "greater"; - case NOT_EQUAL: - return "not_equal"; - case LESS_EQUAL: - return "less_equal"; - case GREATER_EQUAL: - return "greater_equal"; - }; - DCHECK(false); - return "na"; - } -}; - -Result Canonicalize(Expression2 expr) { // If potentially reconstructing more deeply than a call's immediate arguments // (for example, when reorganizing an associative chain), add expressions to this set to // avoid unnecessary work @@ -773,8 +830,8 @@ Result Canonicalize(Expression2 expr) { } AlreadyCanonicalized; return Modify( - std::move(expr), - [&AlreadyCanonicalized](Expression2 expr) -> Result { + std::move(bound), + [&AlreadyCanonicalized, exec_context](Expression2 expr) -> Result { auto call = expr.call(); if (!call) return expr; @@ -822,13 +879,25 @@ Result Canonicalize(Expression2 expr) { if (call->arguments[0].literal() && !call->arguments[1].literal()) { // ensure that literals are on comparisons' RHS auto flipped_call = *call; - for (auto&& argument : flipped_call.arguments) { - ARROW_ASSIGN_OR_RAISE(argument, Canonicalize(std::move(argument))); - } flipped_call.function = Comparison::GetName(Comparison::GetFlipped(*cmp)); + // look up the flipped kernel + // TODO extract a helper for use here and in Bind + ARROW_ASSIGN_OR_RAISE( + auto function, + exec_context->func_registry()->GetFunction(flipped_call.function)); + + auto descrs = GetDescriptors(flipped_call.arguments); + for (auto& descr : descrs) { + if (RequriesDictionaryTransparency(flipped_call)) { + RETURN_NOT_OK(EnsureNotDictionary(&descr)); + } + } + ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); + std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); return Expression2( - std::make_shared(std::move(flipped_call))); + std::make_shared(std::move(flipped_call)), + expr.descr()); } } @@ -837,10 +906,10 @@ Result Canonicalize(Expression2 expr) { [](Expression2 expr, ...) { return expr; }); } -Result DirectComparisonSimplification(Expression2 expr, - const Expression2::Call& guarantee) { +Result DirectComparisonSimplification( + Expression2::BoundWithState bound, const Expression2::Call& guarantee) { return Modify( - std::move(expr), [](Expression2 expr) { return expr; }, + std::move(bound), [](Expression2 expr) { return expr; }, [&guarantee](Expression2 expr, ...) 
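// [Editorial sketch, not from this patch] Canonicalize keeps literals on the
// right-hand side of comparisons by flipping the call and re-dispatching its
// kernel, as the code above does. kBoringSchema is the assumed test fixture.
Status SketchCanonicalizeComparison() {
  ARROW_ASSIGN_OR_RAISE(
      auto bound, less(literal(1), field_ref("i32")).Bind(*kBoringSchema));
  ARROW_ASSIGN_OR_RAISE(bound, Canonicalize(std::move(bound)));
  // bound.first is now the bound form of greater(field_ref("i32"), literal(1))
  return Status::OK();
}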
-> Result { auto call = expr.call(); if (!call) return expr; @@ -895,36 +964,251 @@ Result DirectComparisonSimplification(Expression2 expr, }); } -Result SimplifyWithGuarantee(Expression2 expr, ExpressionState* state, - const Expression2& guaranteed_true_predicate) { +Result SimplifyWithGuarantee( + Expression2::BoundWithState bound, const Expression2& guaranteed_true_predicate) { + if (!guaranteed_true_predicate.IsBound()) { + return Status::Invalid("guaranteed_true_predicate was not bound"); + } + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); - ARROW_ASSIGN_OR_RAISE(expr, - ReplaceFieldsWithKnownValues(known_values, std::move(expr))); + ARROW_ASSIGN_OR_RAISE(bound, + ReplaceFieldsWithKnownValues(known_values, std::move(bound))); - auto CanonicalizeAndFoldConstants = [&expr, state] { - ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); - ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr), state)); + auto CanonicalizeAndFoldConstants = [&bound] { + ARROW_ASSIGN_OR_RAISE(bound, Canonicalize(std::move(bound))); + ARROW_ASSIGN_OR_RAISE(bound, FoldConstants(std::move(bound))); return Status::OK(); }; RETURN_NOT_OK(CanonicalizeAndFoldConstants()); for (const auto& guarantee : conjunction_members) { if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { - ARROW_ASSIGN_OR_RAISE( - auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); + ARROW_ASSIGN_OR_RAISE(auto simplified, DirectComparisonSimplification( + bound, *CallNotNull(guarantee))); - if (Identical(simplified, expr)) continue; + if (Identical(simplified.first, bound.first)) continue; - expr = std::move(simplified); + bound = std::move(simplified); RETURN_NOT_OK(CanonicalizeAndFoldConstants()); } } - return expr; + return bound; +} + +Result> Serialize(const Expression2& expr) { + struct { + std::shared_ptr metadata_ = std::make_shared(); + ArrayVector columns_; + + Result AddScalar(const Scalar& scalar) { + auto ret = columns_.size(); + ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(scalar, 1)); + columns_.push_back(std::move(array)); + return std::to_string(ret); + } + + Status Visit(const Expression2& expr) { + if (auto lit = expr.literal()) { + if (!lit->is_scalar()) { + return Status::NotImplemented("Serialization of non-scalar literals"); + } + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*lit->scalar())); + metadata_->Append("literal", std::move(value)); + return Status::OK(); + } + + if (auto ref = expr.field_ref()) { + if (!ref->name()) { + return Status::NotImplemented("Serialization of non-name field_refs"); + } + metadata_->Append("field_ref", *ref->name()); + return Status::OK(); + } + + auto call = CallNotNull(expr); + metadata_->Append("call", call->function); + + for (const auto& argument : call->arguments) { + RETURN_NOT_OK(Visit(argument)); + } + + if (call->options) { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, FunctionOptionsToStructScalar(*call)); + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*options_scalar)); + metadata_->Append("options", std::move(value)); + } + + metadata_->Append("end", call->function); + return Status::OK(); + } + + Result> operator()(const Expression2& expr) { + RETURN_NOT_OK(Visit(expr)); + FieldVector fields(columns_.size()); + for (size_t i = 0; i < fields.size(); ++i) { + fields[i] = field("", columns_[i]->type()); + } + return RecordBatch::Make(schema(std::move(fields), 
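// [Editorial sketch, not from this patch] End-to-end use of
// SimplifyWithGuarantee as defined above: a conjunctive filter loses the
// member implied by a (bound) partition guarantee. kBoringSchema assumed.
Status SketchSimplifyWithGuarantee() {
  ARROW_ASSIGN_OR_RAISE(
      auto filter, and_(equal(field_ref("i32"), literal(0)),
                        equal(field_ref("f32"), literal(3.5F)))
                       .Bind(*kBoringSchema));
  ARROW_ASSIGN_OR_RAISE(
      auto guarantee, equal(field_ref("i32"), literal(0)).Bind(*kBoringSchema));
  ARROW_ASSIGN_OR_RAISE(
      auto simplified, SimplifyWithGuarantee(std::move(filter), guarantee.first));
  // simplified.first == bound form of equal(field_ref("f32"), literal(3.5F))
  return Status::OK();
}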
std::move(metadata_)), 1, + std::move(columns_)); + } + } ToRecordBatch; + + ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr)); + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); + ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + RETURN_NOT_OK(writer->Close()); + return stream->Finish(); +} + +Result Deserialize(const Buffer& buffer) { + io::BufferReader stream(buffer); + ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); + ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); + if (batch->schema()->metadata() == nullptr) { + return Status::Invalid("serialized Expression2's batch repr had null metadata"); + } + if (batch->num_rows() != 1) { + return Status::Invalid( + "serialized Expression2's batch repr was not a single row - had ", + batch->num_rows()); + } + + struct FromRecordBatch { + const RecordBatch& batch_; + int index_; + + const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); } + + Result> GetScalar(const std::string& i) { + int32_t column_index; + if (!internal::ParseValue(i.data(), i.length(), &column_index)) { + return Status::Invalid("Couldn't parse column_index"); + } + if (column_index >= batch_.num_columns()) { + return Status::Invalid("column_index out of bounds"); + } + return batch_.column(column_index)->GetScalar(0); + } + + Result GetOne() { + if (index_ >= metadata().size()) { + return Status::Invalid("unterminated serialized Expression2"); + } + + const std::string& key = metadata().key(index_); + const std::string& value = metadata().value(index_); + ++index_; + + if (key == "literal") { + ARROW_ASSIGN_OR_RAISE(auto scalar, GetScalar(value)); + return literal(std::move(scalar)); + } + + if (key == "field_ref") { + return field_ref(value); + } + + if (key != "call") { + return Status::Invalid("Unrecognized serialized Expression2 key ", key); + } + + std::vector arguments; + while (metadata().key(index_) != "end") { + if (metadata().key(index_) == "options") { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_))); + auto expr = call(value, std::move(arguments)); + RETURN_NOT_OK(FunctionOptionsFromStructScalar( + checked_cast(options_scalar.get()), + const_cast(expr.call()))); + index_ += 2; + return expr; + } + + ARROW_ASSIGN_OR_RAISE(auto argument, GetOne()); + arguments.push_back(std::move(argument)); + } + + ++index_; + return call(value, std::move(arguments)); + } + }; + + return FromRecordBatch{*batch, 0}.GetOne(); +} + +Expression2 project(std::vector values, std::vector names) { + return call("struct", std::move(values), compute::StructOptions{std::move(names)}); +} + +Expression2 equal(Expression2 lhs, Expression2 rhs) { + return call("equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 not_equal(Expression2 lhs, Expression2 rhs) { + return call("not_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 less(Expression2 lhs, Expression2 rhs) { + return call("less", {std::move(lhs), std::move(rhs)}); +} + +Expression2 less_equal(Expression2 lhs, Expression2 rhs) { + return call("less_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 greater(Expression2 lhs, Expression2 rhs) { + return call("greater", {std::move(lhs), std::move(rhs)}); +} + +Expression2 greater_equal(Expression2 lhs, Expression2 rhs) { + return call("greater_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression2 and_(Expression2 lhs, Expression2 rhs) { + return 
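// [Editorial sketch, not from this patch] The serialized form produced above
// is a single-row RecordBatch whose schema metadata records a pre-order
// traversal ("literal" / "field_ref" / "call" ... "end") and whose columns
// hold the literal and options scalars. Round trips preserve equality:
Status SketchSerializationRoundTrip() {
  Expression2 expr = greater(field_ref("a"), literal(0.25));
  ARROW_ASSIGN_OR_RAISE(auto buffer, Serialize(expr));
  ARROW_ASSIGN_OR_RAISE(Expression2 roundtripped, Deserialize(*buffer));
  DCHECK(expr == roundtripped);
  return Status::OK();
}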
call("and_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression2 and_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), + operands.end(), and_); + if (folded) { + return std::move(*folded); + } + return literal(true); +} + +Expression2 or_(Expression2 lhs, Expression2 rhs) { + return call("or_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression2 or_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), + operands.end(), or_); + if (folded) { + return std::move(*folded); + } + return literal(false); +} + +Expression2 not_(Expression2 operand) { return call("invert", {std::move(operand)}); } + +Expression2 operator&&(Expression2 lhs, Expression2 rhs) { + return and_(std::move(lhs), std::move(rhs)); +} + +Expression2 operator||(Expression2 lhs, Expression2 rhs) { + return or_(std::move(lhs), std::move(rhs)); +} + +Result InsertImplicitCasts(Expression2 expr, const Schema& s) { + std::shared_ptr e(expr); + return InsertImplicitCasts(*e, s); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 4da41b8b9c4..f452e54b7c2 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -45,9 +45,6 @@ namespace dataset { /// - A literal Datum. /// - A reference to a single (potentially nested) field of the input Datum. /// - A call to a compute function, with arguments specified by other Expressions. -/// - Optionally, a explicitly selected Kernel of the function. If provided, -/// execution will skip function lookup and kernel dispatch and use that Kernel -/// directly. class ARROW_DS_EXPORT Expression2 { public: struct Call { @@ -70,10 +67,10 @@ class ARROW_DS_EXPORT Expression2 { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - using StateAndBound = std::pair>; - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, - compute::ExecContext* = NULLPTR) const; + using BoundWithState = std::pair>; + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, + compute::ExecContext* = NULLPTR) const; /// Return true if all an expression's field references have explicit ValueDescr and all /// of its functions' kernels are looked up. 
@@ -96,33 +93,32 @@ class ARROW_DS_EXPORT Expression2 { const Datum* literal() const; const FieldRef* field_ref() const; + // FIXME remove these + operator std::shared_ptr() const; // NOLINT runtime/explicit + Expression2(const Expression& expr); // NOLINT runtime/explicit + Expression2(std::shared_ptr expr) // NOLINT runtime/explicit + : Expression2(*expr) {} + const ValueDescr& descr() const { return descr_; } - using Impl = util::variant; + using Impl = util::Variant; explicit Expression2(std::shared_ptr impl, ValueDescr descr = {}) : impl_(std::move(impl)), descr_(std::move(descr)) {} + Expression2() = default; + private: std::shared_ptr impl_; ValueDescr descr_; // XXX someday // NullGeneralization::type evaluates_to_null_; - friend bool Identical(const Expression2& l, const Expression2& r); + ARROW_EXPORT friend bool Identical(const Expression2& l, const Expression2& r); ARROW_EXPORT friend void PrintTo(const Expression2&, std::ostream*); }; -struct ExpressionState { - virtual ~ExpressionState() = default; - - // Produce another instance of ExpressionState which may be safely used from a - // different thread - static Result> Clone( - const std::shared_ptr&, const Expression2&, compute::ExecContext*); -}; - inline bool operator==(const Expression2& l, const Expression2& r) { return l.Equals(r); } inline bool operator!=(const Expression2& l, const Expression2& r) { return !l.Equals(r); @@ -147,11 +143,6 @@ Expression2 call(std::string function, std::vector arguments, std::make_shared(std::move(options))); } -inline Expression2 project(std::vector names, - std::vector values) { - return call("struct", std::move(values), compute::StructOptions{std::move(names)}); -} - template Expression2 field_ref(Args&&... args) { return Expression2( @@ -166,36 +157,47 @@ Expression2 literal(Arg&& arg) { std::move(descr)); } -// Simplification passes +ARROW_DS_EXPORT +std::vector FieldsInExpression(const Expression2&); + +ARROW_DS_EXPORT +Result> ExtractKnownFieldValues( + const Expression2& guaranteed_true_predicate); + +/// \defgroup expression-passes Functions for modification of Expression2s +/// +/// @{ +/// +/// These operate on a bound expression and its bound state simultaneously, +/// ensuring that Call Expression2s' KernelState can be utilized or reassociated. /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. /// TODO this could be a strong canonicalization ARROW_DS_EXPORT -Result Canonicalize(Expression2 expr); +Result Canonicalize(Expression2::BoundWithState, + compute::ExecContext* = NULLPTR); /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all /// calls whose arguments are entirely literal. ARROW_DS_EXPORT -Result FoldConstants(Expression2 expr, ExpressionState*); - -ARROW_DS_EXPORT -Result> ExtractKnownFieldValues( - const Expression2& guaranteed_true_predicate); +Result FoldConstants(Expression2::BoundWithState); ARROW_DS_EXPORT -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2 expr); + Expression2::BoundWithState); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. 
For example, this is /// used to remove redundant function calls from a filter expression or to replace a /// reference to a constant-value field with a literal. ARROW_DS_EXPORT -Result SimplifyWithGuarantee(Expression2 expr, ExpressionState*, - const Expression2& guaranteed_true_predicate); +Result SimplifyWithGuarantee( + Expression2::BoundWithState, const Expression2& guaranteed_true_predicate); + +/// @} // Execution @@ -203,7 +205,7 @@ Result SimplifyWithGuarantee(Expression2 expr, ExpressionState*, /// expression must be bound. Result ExecuteScalarExpression(const Expression2&, ExpressionState*, const Datum& input, - compute::ExecContext* exec_context = NULLPTR); + compute::ExecContext* = NULLPTR); // Serialization @@ -213,5 +215,33 @@ Result> Serialize(const Expression2&); ARROW_DS_EXPORT Result Deserialize(const Buffer&); +// Convenience aliases for factories + +ARROW_DS_EXPORT Expression2 project(std::vector values, + std::vector names); + +ARROW_DS_EXPORT Expression2 equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 not_equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 less(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 less_equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 greater(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 greater_equal(Expression2 lhs, Expression2 rhs); + +ARROW_DS_EXPORT Expression2 and_(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression2 and_(const std::vector&); +ARROW_DS_EXPORT Expression2 or_(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression2 or_(const std::vector&); +ARROW_DS_EXPORT Expression2 not_(Expression2 operand); + +// FIXME remove these +ARROW_DS_EXPORT Expression2 operator&&(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression2 operator||(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Result InsertImplicitCasts(Expression2, const Schema&); + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 1998a565ec3..548d36711fb 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -21,6 +21,7 @@ #include #include +#include "arrow/compute/api_vector.h" #include "arrow/record_batch.h" #include "arrow/table.h" #include "arrow/util/logging.h" @@ -62,9 +63,44 @@ inline std::vector GetDescriptors(const std::vector& ex return descrs; } -struct CallState : ExpressionState { - std::vector> argument_states; - std::shared_ptr kernel_state; +inline std::vector GetDescriptors(const std::vector& values) { + std::vector descrs(values.size()); + for (size_t i = 0; i < values.size(); ++i) { + descrs[i] = values[i].descr(); + } + return descrs; +} + +struct ARROW_DS_EXPORT ExpressionState { + std::unordered_map, + Expression2::Hash> + kernel_states; + + compute::KernelState* Get(const Expression2& expr) const { + auto it = kernel_states.find(expr); + if (it == kernel_states.end()) return nullptr; + return it->second.get(); + } + + void Replace(const Expression2& expr, const Expression2& replacement) { + auto it = kernel_states.find(expr); + if (it == kernel_states.end()) return; + + auto kernel_state = std::move(it->second); + kernel_states.erase(it); + kernel_states.emplace(replacement, std::move(kernel_state)); + } + + void Drop(const Expression2& expr) { + auto it = kernel_states.find(expr); + if (it == kernel_states.end()) return; + kernel_states.erase(it); + } + + void 
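// [Editorial sketch, not from this patch] ExpressionState, declared below,
// keys KernelState by bound (sub)expression so rewrite passes can carry
// expensive state (e.g. is_in's hash table) across a rewrite. The re-keying
// contract, assuming `after` carries no state of its own:
void SketchRekeyState(ExpressionState* state, const Expression2& before,
                      const Expression2& after) {
  compute::KernelState* kernel_state = state->Get(before);
  state->Replace(before, after);  // no-op if `before` had no state
  DCHECK_EQ(state->Get(after), kernel_state);
  DCHECK_EQ(state->Get(before), nullptr);
}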
MoveFrom(ExpressionState* other) { + std::move(other->kernel_states.begin(), other->kernel_states.end(), + std::inserter(kernel_states, kernel_states.end())); + } }; struct FieldPathGetDatumImpl { @@ -106,5 +142,281 @@ inline Result GetDatumField(const FieldRef& ref, const Datum& input) { return field; } +struct Comparison { + enum type { + NA = 0, + EQUAL = 1, + LESS = 2, + GREATER = 4, + NOT_EQUAL = LESS | GREATER, + LESS_EQUAL = LESS | EQUAL, + GREATER_EQUAL = GREATER | EQUAL, + }; + + static const type* Get(const std::string& function) { + static std::unordered_map flipped_comparisons{ + {"equal", EQUAL}, {"not_equal", NOT_EQUAL}, + {"less", LESS}, {"less_equal", LESS_EQUAL}, + {"greater", GREATER}, {"greater_equal", GREATER_EQUAL}, + }; + + auto it = flipped_comparisons.find(function); + return it != flipped_comparisons.end() ? &it->second : nullptr; + } + + static const type* Get(const Expression2& expr) { + if (auto call = expr.call()) { + return Comparison::Get(call->function); + } + return nullptr; + } + + static Result Execute(Datum l, Datum r) { + if (!l.is_scalar() || !r.is_scalar()) { + return Status::Invalid("Cannot Execute Comparison on non-scalars"); + } + std::vector arguments{std::move(l), std::move(r)}; + + ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); + + if (!equal.scalar()->is_valid) return NA; + if (equal.scalar_as().value) return EQUAL; + + ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments)); + + if (!less.scalar()->is_valid) return NA; + return less.scalar_as().value ? LESS : GREATER; + } + + static type GetFlipped(type op) { + switch (op) { + case NA: + return NA; + case EQUAL: + return EQUAL; + case LESS: + return GREATER; + case GREATER: + return LESS; + case NOT_EQUAL: + return NOT_EQUAL; + case LESS_EQUAL: + return GREATER_EQUAL; + case GREATER_EQUAL: + return LESS_EQUAL; + }; + DCHECK(false); + return NA; + } + + static std::string GetName(type op) { + switch (op) { + case NA: + DCHECK(false) << "unreachable"; + break; + case EQUAL: + return "equal"; + case LESS: + return "less"; + case GREATER: + return "greater"; + case NOT_EQUAL: + return "not_equal"; + case LESS_EQUAL: + return "less_equal"; + case GREATER_EQUAL: + return "greater_equal"; + }; + DCHECK(false); + return "na"; + } +}; + +inline bool IsSetLookup(const std::string& function) { + return function == "is_in" || function == "index_in"; +} + +inline const compute::SetLookupOptions* GetSetLookupOptions( + const Expression2::Call& call) { + if (!IsSetLookup(call.function)) return nullptr; + return checked_cast(call.options.get()); +} + +inline bool RequriesDictionaryTransparency(const Expression2::Call& call) { + // TODO move this functionality into compute:: + + // Functions which don't provide kernels for dictionary types. Dictionaries will be + // decoded for these functions. 
+ if (Comparison::Get(call.function)) return true; + + if (IsSetLookup(call.function)) return true; + + return false; +} + +inline Status EnsureNotDictionary(ValueDescr* descr) { + const auto& type = descr->type; + if (type && type->id() == Type::DICTIONARY) { + descr->type = checked_cast(*type).value_type(); + } + return Status::OK(); +} + +inline Status EnsureNotDictionary(Datum* datum) { + if (datum->type()->id() != Type::DICTIONARY) { + return Status::OK(); + } + + if (datum->is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + *datum, + checked_cast(*datum->scalar()).GetEncodedValue()); + return Status::OK(); + } + + DCHECK_EQ(datum->kind(), Datum::ARRAY); + ArrayData indices = *datum->array(); + indices.type = checked_cast(*datum->type()).index_type(); + auto values = std::move(indices.dictionary); + + ARROW_ASSIGN_OR_RAISE( + *datum, compute::Take(values, indices, compute::TakeOptions::NoBoundsCheck())); + return Status::OK(); +} + +inline Status EnsureNotDictionary(Expression2::Call* call) { + if (auto options = GetSetLookupOptions(*call)) { + auto new_options = *options; + RETURN_NOT_OK(EnsureNotDictionary(&new_options.value_set)); + call->options.reset(new compute::SetLookupOptions(std::move(new_options))); + } + return Status::OK(); +} + +inline Result> FunctionOptionsToStructScalar( + const Expression2::Call& call) { + if (call.options == nullptr) { + return nullptr; + } + + auto Finish = [](ScalarVector values, std::vector names) { + FieldVector fields(names.size()); + for (size_t i = 0; i < fields.size(); ++i) { + fields[i] = field(std::move(names[i]), values[i]->type); + } + return std::make_shared(std::move(values), struct_(std::move(fields))); + }; + + if (auto options = GetSetLookupOptions(call)) { + if (!options->value_set.is_array()) { + return Status::NotImplemented("chunked value_set"); + } + return Finish( + { + std::make_shared(options->value_set.make_array()), + MakeScalar(options->skip_nulls), + }, + {"value_set", "skip_nulls"}); + } + + if (call.function == "cast") { + auto options = checked_cast(call.options.get()); + return Finish( + { + MakeNullScalar(options->to_type), + MakeScalar(options->allow_int_overflow), + MakeScalar(options->allow_time_truncate), + MakeScalar(options->allow_time_overflow), + MakeScalar(options->allow_decimal_truncate), + MakeScalar(options->allow_float_truncate), + MakeScalar(options->allow_invalid_utf8), + }, + { + "to_type_holder", + "allow_int_overflow", + "allow_time_truncate", + "allow_time_overflow", + "allow_decimal_truncate", + "allow_float_truncate", + "allow_invalid_utf8", + }); + } + + return Status::NotImplemented("conversion of options for ", call.function); +} + +inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, + Expression2::Call* call) { + if (repr == nullptr) { + call->options = nullptr; + return Status::OK(); + } + + if (IsSetLookup(call->function)) { + ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set")); + ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls")); + call->options = std::make_shared( + checked_cast(*value_set).value, + checked_cast(*skip_nulls).value); + return Status::OK(); + } + + if (call->function == "cast") { + auto options = std::make_shared(); + ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder")); + options->to_type = to_type_holder->type; + + int i = 1; + for (bool* opt : { + &options->allow_int_overflow, + &options->allow_time_truncate, + &options->allow_time_overflow, + &options->allow_decimal_truncate, + &options->allow_float_truncate, 
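// [Editorial sketch, not from this patch] The dictionary "transparency"
// decode path defined above: a dictionary array is decoded with
// Take(values, indices) so comparison and set-lookup kernels see plain
// values. ArrayFromJSON is the testing helper used throughout this patch;
// the input data is an illustrative assumption.
Status SketchDecodeDictionary() {
  auto dict_array =
      ArrayFromJSON(dictionary(int32(), utf8()), R"(["a", "b", "a"])");
  Datum datum(dict_array);
  RETURN_NOT_OK(EnsureNotDictionary(&datum));
  // datum is now a plain utf8 array: ["a", "b", "a"]
  DCHECK(datum.type()->Equals(utf8()));
  return Status::OK();
}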
+ &options->allow_invalid_utf8, + }) { + *opt = checked_cast(*repr->value[i++]).value; + } + + call->options = std::move(options); + return Status::OK(); + } + + return Status::NotImplemented("conversion of options for ", call->function); +} + +struct FlattenedAssociativeChain { + bool was_left_folded = true; + std::vector exprs, fringe; + + explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} { + auto call = CallNotNull(exprs.back()); + fringe = call->arguments; + + auto it = fringe.begin(); + + while (it != fringe.end()) { + auto sub_call = it->call(); + if (!sub_call || sub_call->function != call->function) { + ++it; + continue; + } + + if (it != fringe.begin()) { + was_left_folded = false; + } + + exprs.push_back(std::move(*it)); + it = fringe.erase(it); + it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end()); + // NB: no increment so we hit sub_call's first argument next iteration + } + + DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) { + return CallNotNull(expr)->options == nullptr; + })); + } +}; + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 7cc98a1e151..6274147d208 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -28,6 +28,7 @@ #include "arrow/compute/registry.h" #include "arrow/dataset/expression_internal.h" +#include "arrow/dataset/test_util.h" #include "arrow/testing/gtest_util.h" using testing::UnorderedElementsAreArray; @@ -39,11 +40,6 @@ using internal::checked_pointer_cast; namespace dataset { -const Schema kBoringSchema{{ - field("i32", int32()), - field("f32", float32()), -}}; - #define EXPECT_OK ARROW_EXPECT_OK TEST(Expression2, ToString) { @@ -91,6 +87,78 @@ TEST(Expression2, Hash) { EXPECT_EQ(set.size(), 6); } +TEST(Expression2, IsScalarExpression) { + EXPECT_TRUE(literal(true).IsScalarExpression()); + + auto arr = ArrayFromJSON(int8(), "[]"); + EXPECT_FALSE(literal(arr).IsScalarExpression()); + + EXPECT_TRUE(field_ref("a").IsScalarExpression()); + + EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsScalarExpression()); + + EXPECT_FALSE(equal(field_ref("a"), literal(arr)).IsScalarExpression()); + + EXPECT_TRUE(call("is_in", {field_ref("a")}, compute::SetLookupOptions{arr, true}) + .IsScalarExpression()); + + // non scalar function + EXPECT_FALSE(call("take", {field_ref("a"), literal(arr)}).IsScalarExpression()); +} + +TEST(Expression2, IsSatisfiable) { + EXPECT_TRUE(literal(true).IsSatisfiable()); + EXPECT_FALSE(literal(false).IsSatisfiable()); + + auto null = std::make_shared(); + EXPECT_FALSE(literal(null).IsSatisfiable()); + + EXPECT_TRUE(field_ref("a").IsSatisfiable()); + + EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsSatisfiable()); + + // NB: no constant folding here + EXPECT_TRUE(equal(literal(0), literal(1)).IsSatisfiable()); + + // When a top level conjunction contains an Expression2 which is certain to evaluate to + // null, it can only evaluate to null or false. + auto null_or_false = and_(literal(null), field_ref("a")); + // This may appear in satisfiable filters if coalesced + EXPECT_TRUE(call("is_null", {null_or_false}).IsSatisfiable()); + // ... but at the top level it is not satisfiable. + // This special case arises when (for example) an absent column has made + // one member of the conjunction always-null. This is fairly common and + // would be a worthwhile optimization to support. 
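// [Editorial sketch, not from this patch] FlattenedAssociativeChain, moved
// into expression_internal.h above, flattens nested calls to the same
// associative function into one argument list, recording whether the chain
// was already in canonical left-folded shape:
void SketchFlattenChain() {
  FlattenedAssociativeChain chain(
      and_(and_(field_ref("a"), field_ref("b")), field_ref("c")));
  DCHECK_EQ(chain.fringe.size(), 3);  // the leaves {a, b, c}
  DCHECK(chain.was_left_folded);      // and_(and_(a, b), c) is left-folded
}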
+ // EXPECT_FALSE(null_or_false).IsSatisfiable()); +} + +TEST(Expression2, FieldsInExpression) { + auto ExpectFieldsAre = [](Expression2 expr, std::vector expected) { + EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected)); + }; + + ExpectFieldsAre(literal(true), {}); + + ExpectFieldsAre(field_ref("a"), {"a"}); + + ExpectFieldsAre(equal(field_ref("a"), literal(1)), {"a"}); + + ExpectFieldsAre(equal(field_ref("a"), field_ref("b")), {"a", "b"}); + + ExpectFieldsAre( + or_(equal(field_ref("a"), literal(1)), equal(field_ref("a"), literal(2))), + {"a", "a"}); + + ExpectFieldsAre( + or_(equal(field_ref("a"), literal(1)), equal(field_ref("b"), literal(2))), + {"a", "b"}); + + ExpectFieldsAre(or_(and_(not_(equal(field_ref("a"), literal(1))), + equal(field_ref("b"), literal(2))), + not_(less(field_ref("c"), literal(3)))), + {"a", "b", "c"}); +} + TEST(Expression2, BindLiteral) { for (Datum dat : { Datum(3), @@ -161,6 +229,18 @@ TEST(Expression2, BindCall) { expr.Bind(Schema({field("a", int32()), field("b", int32())}))); } +TEST(Expression2, BindDictionaryTransparent) { + auto expr = call("equal", {field_ref("a"), field_ref("b")}); + EXPECT_FALSE(expr.IsBound()); + + ASSERT_OK_AND_ASSIGN( + std::tie(expr, std::ignore), + expr.Bind(Schema({field("a", utf8()), field("b", dictionary(int32(), utf8()))}))); + + EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean())); + EXPECT_TRUE(expr.IsBound()); +} + TEST(Expression2, BindNestedCall) { auto expr = call("add", {field_ref("a"), @@ -208,35 +288,48 @@ TEST(Expression2, ExecuteFieldRef) { {"a": 0.0}, {"a": -1} ])"), - MakeNullScalar(float64())); + MakeNullScalar(null())); } -Result NaiveExecuteScalarExpression(const Expression2& expr, - ExpressionState* state, const Datum& input) { +Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& input) { auto call = expr.call(); if (call == nullptr) { // already tested execution of field_ref, execution of literal is trivial - return ExecuteScalarExpression(expr, state, input); + return ExecuteScalarExpression(expr, /*state=*/nullptr, input); } - auto call_state = checked_cast(state); - std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { - auto argument_state = call_state->argument_states[i].get(); - ARROW_ASSIGN_OR_RAISE(arguments[i], NaiveExecuteScalarExpression( - call->arguments[i], argument_state, input)); + ARROW_ASSIGN_OR_RAISE(arguments[i], + NaiveExecuteScalarExpression(call->arguments[i], input)); - EXPECT_EQ(call->arguments[i].descr(), arguments[i].descr()); + if (RequriesDictionaryTransparency(*call)) { + RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); + } } ARROW_ASSIGN_OR_RAISE(auto function, compute::GetFunctionRegistry()->GetFunction(call->function)); - ARROW_ASSIGN_OR_RAISE(auto expected_kernel, - function->DispatchExact(GetDescriptors(call->arguments))); + + auto descrs = GetDescriptors(call->arguments); + for (size_t i = 0; i < arguments.size(); ++i) { + if (RequriesDictionaryTransparency(*call)) { + RETURN_NOT_OK(EnsureNotDictionary(&descrs[i])); + } + EXPECT_EQ(arguments[i].descr(), descrs[i]); + } + + ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs)); EXPECT_EQ(call->kernel, expected_kernel); + auto options = call->options; + if (RequriesDictionaryTransparency(*call)) { + auto non_dict_call = *call; + RETURN_NOT_OK(EnsureNotDictionary(&non_dict_call)); + options = non_dict_call.options; + } + compute::ExecContext exec_context; return function->Execute(arguments, call->options.get(), 
&exec_context); } @@ -247,8 +340,7 @@ void AssertExecute(Expression2 expr, Datum in) { ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, state.get(), in)); - ASSERT_OK_AND_ASSIGN(Datum expected, - NaiveExecuteScalarExpression(expr, state.get(), in)); + ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in)); AssertDatumsEqual(actual, expected, /*verbose=*/true); } @@ -277,7 +369,7 @@ TEST(Expression2, ExecuteCall) { {"a": "12/11/1900"} ])")); - AssertExecute(project({"a + 3.5"}, {call("add", {field_ref("a"), literal(3.5)})}), + AssertExecute(project({call("add", {field_ref("a"), literal(3.5)})}, {"a + 3.5"}), ArrayFromJSON(struct_({field("a", float64())}), R"([ {"a": 6.125}, {"a": 0.0}, @@ -285,18 +377,51 @@ TEST(Expression2, ExecuteCall) { ])")); } +TEST(Expression2, ExecuteDictionaryTransparent) { + AssertExecute( + equal(field_ref("a"), field_ref("b")), + ArrayFromJSON( + struct_({field("a", dictionary(int32(), utf8())), field("b", utf8())}), R"([ + {"a": "hi", "b": "hi"}, + {"a": "", "b": ""}, + {"a": "hi", "b": "hello"} + ])")); + + Datum dict_set = ArrayFromJSON(dictionary(int32(), utf8()), R"(["a"])"); + AssertExecute(call("is_in", {field_ref("a")}, + compute::SetLookupOptions{dict_set, + /*skip_nulls=*/false}), + ArrayFromJSON(struct_({field("a", utf8())}), R"([ + {"a": "a"}, + {"a": "good"}, + {"a": null} + ])")); +} + struct { void operator()(Expression2 expr, Expression2 expected) { - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(kBoringSchema)); - ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto actual, FoldConstants(expr, state.get())); - EXPECT_EQ(actual, expected); - if (actual == expr) { + this->operator()(expr, expected, + [](Expression2::BoundWithState, Expression2::BoundWithState, + Expression2::BoundWithState) {}); + } + + template + void operator()(Expression2 expr, Expression2 unbound_expected, + const ExtraExpectations& expect) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(bound)); + + EXPECT_EQ(folded.first, expected.first); + + if (folded.first == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual, expr)); + EXPECT_TRUE(Identical(folded.first, expr)); } + + expect(bound, folded, expected); } + } ExpectFoldsTo; TEST(Expression2, FoldConstants) { @@ -345,6 +470,21 @@ TEST(Expression2, FoldConstants) { }), literal(2), })); + + compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]"), true); + ExpectFoldsTo( + call("is_in", + {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})}, + in_123), + call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123), + [](Expression2::BoundWithState bound, Expression2::BoundWithState folded, + Expression2::BoundWithState) { + const compute::KernelState* state = bound.second->Get(bound.first); + const compute::KernelState* folded_state = folded.second->Get(folded.first); + EXPECT_EQ(folded_state, state) << "The kernel state associated with is_in (the " + "hash table for looking up membership) " + "must be associated with the folded is_in call"; + }); } TEST(Expression2, FoldConstantsBoolean) { @@ -356,113 +496,113 @@ TEST(Expression2, FoldConstantsBoolean) { auto true_ = literal(true); auto false_ = literal(false); - ExpectFoldsTo(call("and_kleene", {false_, whatever}), false_); - 
ExpectFoldsTo(call("and_kleene", {true_, whatever}), whatever); + ExpectFoldsTo(and_(false_, whatever), false_); + ExpectFoldsTo(and_(true_, whatever), whatever); - ExpectFoldsTo(call("or_kleene", {true_, whatever}), true_); - ExpectFoldsTo(call("or_kleene", {false_, whatever}), whatever); + ExpectFoldsTo(or_(true_, whatever), true_); + ExpectFoldsTo(or_(false_, whatever), whatever); } TEST(Expression2, ExtractKnownFieldValues) { struct { void operator()(Expression2 guarantee, std::unordered_map expected) { + ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), + guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); - EXPECT_THAT(actual, UnorderedElementsAreArray(expected)); + EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) + << " guarantee: " << guarantee.ToString(); } } ExpectKnown; - ExpectKnown(call("equal", {field_ref("a"), literal(3)}), {{"a", Datum(3)}}); + ExpectKnown(equal(field_ref("i32"), literal(3)), {{"i32", Datum(3)}}); - ExpectKnown(call("greater", {field_ref("a"), literal(3)}), {}); + ExpectKnown(greater(field_ref("i32"), literal(3)), {}); // FIXME known null should be expressed with is_null rather than equality auto null_int32 = std::make_shared(); - ExpectKnown(call("equal", {field_ref("a"), literal(null_int32)}), - {{"a", Datum(null_int32)}}); - - ExpectKnown(call("and_kleene", {call("equal", {field_ref("a"), literal(3)}), - call("equal", {literal(1), field_ref("b")})}), - {{"a", Datum(3)}, {"b", Datum(1)}}); + ExpectKnown(equal(field_ref("i32"), literal(null_int32)), {{"i32", Datum(null_int32)}}); - ExpectKnown(call("and_kleene", - {call("equal", {field_ref("a"), literal(3)}), - call("and_kleene", {call("equal", {field_ref("b"), literal(2)}), - call("equal", {literal(1), field_ref("c")})})}), - {{"a", Datum(3)}, {"b", Datum(2)}, {"c", Datum(1)}}); + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}), + {{"i32", Datum(3)}, {"f32", Datum(1.5F)}}); - ExpectKnown(call("and_kleene", - {call("or_kleene", {call("equal", {field_ref("a"), literal(3)}), - call("equal", {field_ref("a"), literal(4)})}), - call("equal", {literal(1), field_ref("b")})}), - {{"b", Datum(1)}}); + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(literal(2.F), field_ref("f32")), + equal(literal(1), field_ref("i32_req"))}), + {{"i32", Datum(3)}, {"f32", Datum(2.F)}, {"i32_req", Datum(1)}}); ExpectKnown( - call("and_kleene", - {call("equal", {field_ref("a"), literal(3)}), - call("and_kleene", {call("and_kleene", {field_ref("b"), field_ref("d")}), - call("equal", {literal(1), field_ref("c")})})}), - {{"a", Datum(3)}, {"c", Datum(1)}}); + and_(or_(equal(field_ref("i32"), literal(3)), equal(field_ref("i32"), literal(4))), + equal(literal(2.F), field_ref("f32"))), + {{"f32", Datum(2.F)}}); + + ExpectKnown(and_({equal(field_ref("i32"), literal(3)), + equal(field_ref("f32"), field_ref("f32_req")), + equal(literal(1), field_ref("i32_req"))}), + {{"i32", Datum(3)}, {"i32_req", Datum(1)}}); } TEST(Expression2, ReplaceFieldsWithKnownValues) { - auto ExpectSimplifiesTo = + auto ExpectReplacesTo = [](Expression2 expr, std::unordered_map known_values, - Expression2 expected) { - ASSERT_OK_AND_ASSIGN(auto actual, - ReplaceFieldsWithKnownValues(known_values, expr)); - EXPECT_EQ(actual, expected); + Expression2 unbound_expected) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto 
replaced, + ReplaceFieldsWithKnownValues(known_values, bound)); + + EXPECT_EQ(replaced.first, expected.first); - if (actual == expr) { + if (replaced.first == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual, expr)); + EXPECT_TRUE(Identical(replaced.first, expr)); } }; - std::unordered_map a_is_3{{"a", Datum(3)}}; - - ExpectSimplifiesTo(literal(1), a_is_3, literal(1)); - - ExpectSimplifiesTo(field_ref("a"), a_is_3, literal(3)); - - ExpectSimplifiesTo(field_ref("b"), a_is_3, field_ref("b")); - - ExpectSimplifiesTo(call("equal", {field_ref("a"), literal(1)}), a_is_3, - call("equal", {literal(3), literal(1)})); - - ExpectSimplifiesTo(call("add", - { - call("subtract", - { - field_ref("a"), - call("multiply", {literal(2), literal(3)}), - }), - literal(2), - }), - a_is_3, - call("add", { - call("subtract", - { - literal(3), - call("multiply", {literal(2), literal(3)}), - }), - literal(2), - })); + std::unordered_map i32_is_3{{"i32", Datum(3)}}; + + ExpectReplacesTo(literal(1), i32_is_3, literal(1)); + + ExpectReplacesTo(field_ref("i32"), i32_is_3, literal(3)); + + ExpectReplacesTo(field_ref("b"), i32_is_3, field_ref("b")); + + ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3, + equal(literal(3), literal(1))); + + ExpectReplacesTo(call("add", + { + call("subtract", + { + field_ref("i32"), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + }), + i32_is_3, + call("add", { + call("subtract", + { + literal(3), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + })); } struct { - void operator()(Expression2 expr, Expression2 expected) const { - ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(expr)); - EXPECT_EQ(actual, expected); + void operator()(Expression2 expr, Expression2 unbound_expected) const { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); - if (expr.IsBound()) { - EXPECT_TRUE(actual.IsBound()); - } + EXPECT_EQ(actual.first, expected.first); - if (actual == expr) { + if (actual.first == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual, expr)); + EXPECT_TRUE(Identical(actual.first, expr)); } } } ExpectCanonicalizesTo; @@ -472,8 +612,8 @@ TEST(Expression2, CanonicalizeTrivial) { ExpectCanonicalizesTo(field_ref("b"), field_ref("b")); - ExpectCanonicalizesTo(call("equal", {field_ref("a"), field_ref("b")}), - call("equal", {field_ref("a"), field_ref("b")})); + ExpectCanonicalizesTo(equal(field_ref("i32"), field_ref("i32_req")), + equal(field_ref("i32"), field_ref("i32_req"))); } TEST(Expression2, CanonicalizeAnd) { @@ -481,95 +621,79 @@ TEST(Expression2, CanonicalizeAnd) { auto true_ = literal(true); auto null_ = literal(std::make_shared()); - auto a = field_ref("a"); + auto b = field_ref("bool"); auto c = call("equal", {literal(1), literal(2)}); - auto and_ = [](Expression2 l, Expression2 r) { return call("and_kleene", {l, r}); }; - // no change possible: - ExpectCanonicalizesTo(and_(a, c), and_(a, c)); + ExpectCanonicalizesTo(and_(b, c), and_(b, c)); // literals are placed innermost - ExpectCanonicalizesTo(and_(a, true_), and_(true_, a)); - ExpectCanonicalizesTo(and_(true_, a), and_(true_, a)); + ExpectCanonicalizesTo(and_(b, true_), and_(true_, b)); + ExpectCanonicalizesTo(and_(true_, b), and_(true_, b)); - ExpectCanonicalizesTo(and_(a, and_(true_, c)), and_(and_(true_, a), c)); - ExpectCanonicalizesTo(and_(a, and_(and_(true_, true_), c)), - 
and_(and_(and_(true_, true_), a), c)); - ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), c)), - and_(and_(and_(null_, true_), a), c)); - ExpectCanonicalizesTo(and_(a, and_(and_(true_, null_), and_(c, null_))), - and_(and_(and_(and_(null_, null_), true_), a), c)); + ExpectCanonicalizesTo(and_(b, and_(true_, c)), and_(and_(true_, b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, true_), c)), + and_(and_(and_(true_, true_), b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), c)), + and_(and_(and_(null_, true_), b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), and_(c, null_))), + and_(and_(and_(and_(null_, null_), true_), b), c)); // catches and_kleene even when it's a subexpression - ExpectCanonicalizesTo(call("is_valid", {and_(a, true_)}), - call("is_valid", {and_(true_, a)})); + ExpectCanonicalizesTo(call("is_valid", {and_(b, true_)}), + call("is_valid", {and_(true_, b)})); } TEST(Expression2, CanonicalizeComparison) { - // some aliases for brevity: - auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); }; - auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); }; - auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); }; + ExpectCanonicalizesTo(equal(literal(1), field_ref("i32")), + equal(field_ref("i32"), literal(1))); - ExpectCanonicalizesTo(equal(literal(1), field_ref("a")), - equal(field_ref("a"), literal(1))); + ExpectCanonicalizesTo(equal(field_ref("i32"), literal(1)), + equal(field_ref("i32"), literal(1))); - ExpectCanonicalizesTo(equal(field_ref("a"), literal(1)), - equal(field_ref("a"), literal(1))); + ExpectCanonicalizesTo(less(literal(1), field_ref("i32")), + greater(field_ref("i32"), literal(1))); - ExpectCanonicalizesTo(less(literal(1), field_ref("a")), - greater(field_ref("a"), literal(1))); - - ExpectCanonicalizesTo(less(field_ref("a"), literal(1)), - less(field_ref("a"), literal(1))); + ExpectCanonicalizesTo(less(field_ref("i32"), literal(1)), + less(field_ref("i32"), literal(1))); } struct Simplify { - Expression2 filter; + Expression2 expr; struct Expectable { - Expression2 filter, guarantee; + Expression2 expr, guarantee; + + void Expect(Expression2 unbound_expected) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), + guarantee.Bind(*kBoringSchema)); - void Expect(Expression2 expected) { - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); - ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), expected.Bind(kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto simplified, - SimplifyWithGuarantee(filter, state.get(), guarantee)); - EXPECT_EQ(simplified, expected) << " original: " << filter.ToString() << "\n" - << " guarantee: " << guarantee.ToString() << "\n" - << (simplified == filter ? " (no change)\n" : ""); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + EXPECT_EQ(simplified.first, expected.first) + << " original: " << expr.ToString() << "\n" + << " guarantee: " << guarantee.ToString() << "\n" + << (simplified.first == bound.first ? 
" (no change)\n" : ""); - if (simplified == filter) { - EXPECT_TRUE(Identical(simplified, filter)); + if (simplified.first == bound.first) { + EXPECT_TRUE(Identical(simplified.first, bound.first)); } } - void ExpectUnchanged() { Expect(filter); } + void ExpectUnchanged() { Expect(expr); } void Expect(bool constant) { Expect(literal(constant)); } }; - Expectable WithGuarantee(Expression2 guarantee) { return {filter, guarantee}; } + Expectable WithGuarantee(Expression2 guarantee) { return {expr, guarantee}; } }; TEST(Expression2, SingleComparisonGuarantees) { - // some aliases for brevity: - auto equal = [](Expression2 l, Expression2 r) { return call("equal", {l, r}); }; - auto less = [](Expression2 l, Expression2 r) { return call("less", {l, r}); }; - auto greater = [](Expression2 l, Expression2 r) { return call("greater", {l, r}); }; - auto not_equal = [](Expression2 l, Expression2 r) { return call("not_equal", {l, r}); }; - auto less_equal = [](Expression2 l, Expression2 r) { - return call("less_equal", {l, r}); - }; - auto greater_equal = [](Expression2 l, Expression2 r) { - return call("greater_equal", {l, r}); - }; auto i32 = field_ref("i32"); // i32 is guaranteed equal to 3, so the projection can just materialize that constant // and need not incur IO - Simplify{project({"i32 + 1"}, {call("add", {i32, literal(1)})})} + Simplify{project({call("add", {i32, literal(1)})}, {"i32 + 1"})} .WithGuarantee(equal(i32, literal(3))) .Expect(literal( std::make_shared(ScalarVector{std::make_shared(4)}, @@ -658,7 +782,7 @@ TEST(Expression2, SingleComparisonGuarantees) { {"i32"})); std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(kBoringSchema)); + ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(filter, state.get(), input)); @@ -681,48 +805,135 @@ TEST(Expression2, SingleComparisonGuarantees) { TEST(Expression2, SimplifyWithGuarantee) { // drop both members of a conjunctive filter - Simplify{call("and_kleene", - { - call("equal", {field_ref("i32"), literal(2)}), - call("equal", {field_ref("f32"), literal(3.5F)}), - })} - .WithGuarantee(call("and_kleene", - { - call("greater_equal", {field_ref("i32"), literal(0)}), - call("less_equal", {field_ref("i32"), literal(1)}), - })) + Simplify{ + and_(equal(field_ref("i32"), literal(2)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), + less_equal(field_ref("i32"), literal(1)))) .Expect(false); // drop one member of a conjunctive filter - Simplify{call("and_kleene", - { - call("equal", {field_ref("i32"), literal(0)}), - call("equal", {field_ref("f32"), literal(3.5F)}), - })} - .WithGuarantee(call("equal", {field_ref("i32"), literal(0)})) - .Expect(call("equal", {field_ref("f32"), literal(3.5F)})); + Simplify{ + and_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(equal(field_ref("i32"), literal(0))) + .Expect(equal(field_ref("f32"), literal(3.5F))); // drop both members of a disjunctive filter - Simplify{call("or_kleene", - { - call("equal", {field_ref("i32"), literal(0)}), - call("equal", {field_ref("f32"), literal(3.5F)}), - })} - .WithGuarantee(call("equal", {field_ref("i32"), literal(0)})) + Simplify{ + or_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(equal(field_ref("i32"), literal(0))) .Expect(true); // drop one member of a disjunctive filter - Simplify{call("or_kleene", - { - 
call("equal", {field_ref("i32"), literal(0)}), - call("equal", {field_ref("i32"), literal(3)}), - })} - .WithGuarantee(call("and_kleene", - { - call("greater_equal", {field_ref("i32"), literal(0)}), - call("less_equal", {field_ref("i32"), literal(1)}), - })) - .Expect(call("equal", {field_ref("i32"), literal(0)})); + Simplify{or_(equal(field_ref("i32"), literal(0)), equal(field_ref("i32"), literal(3)))} + .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), + less_equal(field_ref("i32"), literal(1)))) + .Expect(equal(field_ref("i32"), literal(0))); +} + +TEST(Expression2, SimplifyThenExecute) { + auto filter = + and_({equal(field_ref("f32"), literal(0.F)), + call("is_in", {field_ref("i32")}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})}); + + ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto guarantee, + equal(field_ref("f32"), literal(0.F)).Bind(*kBoringSchema)); + + ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee.first)); + + auto input = RecordBatchFromJSON(kBoringSchema, R"([ + {"i32": 0, "f32": -0.1}, + {"i32": 0, "f32": 0.3}, + {"i32": 1, "f32": 0.2}, + {"i32": 2, "f32": -0.1}, + {"i32": 0, "f32": 0.1}, + {"i32": 0, "f32": null}, + {"i32": 0, "f32": 1.0} + ])"); + ASSERT_OK_AND_ASSIGN(Datum evaluated, + ExecuteScalarExpression(bound.first, bound.second.get(), input)); +} + +TEST(Expression2, Filter) { + auto ExpectFilter = [](Expression2 filter, std::string batch_json) { + ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean()))); + auto batch = RecordBatchFromJSON(s, batch_json); + auto expected_mask = batch->column(0); + + std::shared_ptr state; + ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, state.get(), batch)); + + AssertDatumsEqual(expected_mask, mask); + }; + + ExpectFilter(equal(field_ref("i32"), literal(0)), R"([ + {"i32": 0, "f32": -0.1, "in": 1}, + {"i32": 0, "f32": 0.3, "in": 1}, + {"i32": 1, "f32": 0.2, "in": 0}, + {"i32": 2, "f32": -0.1, "in": 0}, + {"i32": 0, "f32": 0.1, "in": 1}, + {"i32": 0, "f32": null, "in": 1}, + {"i32": 0, "f32": 1.0, "in": 1} + ])"); +} + +TEST(Expression2, SerializationRoundTrips) { + auto ExpectRoundTrips = [](const Expression2& expr) { + ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(expr)); + ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, Deserialize(*serialized)); + EXPECT_EQ(expr, roundtripped); + }; + + ExpectRoundTrips(literal(MakeNullScalar(null()))); + + ExpectRoundTrips(literal(MakeNullScalar(int32()))); + + ExpectRoundTrips( + literal(MakeNullScalar(struct_({field("i", int32()), field("s", utf8())})))); + + ExpectRoundTrips(literal(true)); + + ExpectRoundTrips(literal(false)); + + ExpectRoundTrips(literal(1)); + + ExpectRoundTrips(literal(1.125)); + + ExpectRoundTrips(literal("stringy strings")); + + ExpectRoundTrips(field_ref("field")); + + ExpectRoundTrips(greater(field_ref("a"), literal(0.25))); + + ExpectRoundTrips( + or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")), + equal(field_ref("b"), literal("foo bar"))})); + + ExpectRoundTrips(not_(field_ref("alpha"))); + + compute::CastOptions to_float64; + to_float64.to_type = float64(); + ExpectRoundTrips( + call("is_in", {call("cast", {field_ref("version")}, to_float64)}, + compute::SetLookupOptions{ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]"), true})); + + ExpectRoundTrips(call("is_valid", {field_ref("validity")})); + + 
ExpectRoundTrips(and_({and_(greater_equal(field_ref("x"), literal(-1.5)), + less(field_ref("x"), literal(0.0))), + and_(greater_equal(field_ref("y"), literal(0.0)), + less(field_ref("y"), literal(1.5))), + and_(greater(field_ref("z"), literal(1.5)), + less_equal(field_ref("z"), literal(3.0)))})); + + ExpectRoundTrips(and_({equal(field_ref("year"), literal(int16_t(1999))), + equal(field_ref("month"), literal(int8_t(12))), + equal(field_ref("day"), literal(int8_t(31))), + equal(field_ref("hour"), literal(int8_t(0))), + equal(field_ref("alpha"), literal(int32_t(0))), + equal(field_ref("beta"), literal(3.25f))})); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index ef044ea3a03..f8350a0292c 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -57,16 +57,16 @@ Result> FileSource::Open() const { Result> FileFormat::MakeFragment( FileSource source, std::shared_ptr physical_schema) { - return MakeFragment(std::move(source), scalar(true), std::move(physical_schema)); + return MakeFragment(std::move(source), literal(true), std::move(physical_schema)); } Result> FileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression) { + FileSource source, Expression2 partition_expression) { return MakeFragment(std::move(source), std::move(partition_expression), nullptr); } Result> FileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr( new FileFragment(std::move(source), shared_from_this(), @@ -83,7 +83,7 @@ Result FileFragment::Scan(std::shared_ptr options } FileSystemDataset::FileSystemDataset(std::shared_ptr schema, - std::shared_ptr root_partition, + Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) @@ -93,7 +93,7 @@ FileSystemDataset::FileSystemDataset(std::shared_ptr schema, fragments_(std::move(fragments)) {} Result> FileSystemDataset::Make( - std::shared_ptr schema, std::shared_ptr root_partition, + std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) { return std::shared_ptr(new FileSystemDataset( @@ -129,20 +129,23 @@ std::string FileSystemDataset::ToString() const { repr += "\n" + fragment->source().path(); const auto& partition = fragment->partition_expression(); - if (!partition->Equals(true)) { - repr += ": " + partition->ToString(); + if (partition != literal(true)) { + repr += ": " + partition.ToString(); } } return repr; } -FragmentIterator FileSystemDataset::GetFragmentsImpl( - std::shared_ptr predicate) { +Result FileSystemDataset::GetFragmentsImpl( + Expression2::BoundWithState predicate) { FragmentVector fragments; for (const auto& fragment : fragments_) { - if (predicate->IsSatisfiableWith(fragment->partition_expression())) { + ARROW_ASSIGN_OR_RAISE( + auto simplified, + SimplifyWithGuarantee(predicate, fragment->partition_expression())); + if (simplified.first.IsSatisfiable()) { fragments.push_back(fragment); } } @@ -273,7 +276,8 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio // // NB: neither of these will have any impact whatsoever on the common case of writing // an in-memory table to disk. 
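   // For reference: the write path below groups each scanned batch by its
   // partition values, and-s each group's expression with the fragment's own
   // partition expression, binds the conjunction, and asks the Partitioning to
   // format it into a directory. A minimal sketch (hypothetical schema with
   // int32 fields "a" and "b", DirectoryPartitioning):
   //
   //   ARROW_ASSIGN_OR_RAISE(auto expr, and_(equal(field_ref("a"), literal(1)),
   //                                         equal(field_ref("b"), literal(2)))
   //                                        .Bind(*schema));
   //   ARROW_ASSIGN_OR_RAISE(auto path, partitioning->Format(expr.first));
   //   // -> "1/2" for a DirectoryPartitioning on {"a", "b"}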
- ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, scanner->GetFragments().ToVector()); + ARROW_ASSIGN_OR_RAISE(auto fragment_it, scanner->GetFragments()); + ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, fragment_it.ToVector()); ScanTaskVector scan_tasks; std::vector fragment_for_task; @@ -313,12 +317,14 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio std::unordered_set need_flushed; for (size_t i = 0; i < groups.batches.size(); ++i) { - AndExpression partition_expression(std::move(groups.expressions[i]), - fragment->partition_expression()); + ARROW_ASSIGN_OR_RAISE( + auto partition_expression, + and_(std::move(groups.expressions[i]), fragment->partition_expression()) + .Bind(*scanner->schema())); auto batch = std::move(groups.batches[i]); - ARROW_ASSIGN_OR_RAISE(auto part, - write_options.partitioning->Format(partition_expression)); + ARROW_ASSIGN_OR_RAISE( + auto part, write_options.partitioning->Format(partition_expression.first)); WriteQueue* queue; { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 192921e7cf0..1d0059ec089 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -143,11 +143,11 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema); - Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression); + Result> MakeFragment(FileSource source, + Expression2 partition_expression); Result> MakeFragment( FileSource source, std::shared_ptr physical_schema = NULLPTR); @@ -173,8 +173,7 @@ class ARROW_DS_EXPORT FileFragment : public Fragment { protected: FileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, - std::shared_ptr physical_schema) + Expression2 partition_expression, std::shared_ptr physical_schema) : Fragment(std::move(partition_expression), std::move(physical_schema)), source_(std::move(source)), format_(std::move(format)) {} @@ -207,7 +206,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { /// /// \return A constructed dataset. static Result> Make( - std::shared_ptr schema, std::shared_ptr root_partition, + std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); @@ -234,10 +233,10 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { std::string ToString() const; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl( + Expression2::BoundWithState predicate) override; - FileSystemDataset(std::shared_ptr schema, - std::shared_ptr root_partition, + FileSystemDataset(std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index f889b12fddd..6f1ca9acbbb 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -34,6 +34,7 @@ #include "arrow/result.h" #include "arrow/type.h" #include "arrow/util/iterator.h" +#include "arrow/util/logging.h" namespace arrow { namespace dataset { @@ -90,13 +91,16 @@ static inline Result GetConvertOptions( } // FIXME(bkietz) also acquire types of fields materialized but not projected. 
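  // (A field referenced only by the filter must still be read: the loop below
  // appends any filter field absent from scan_options->schema() to
  // include_columns so the CSV reader materializes it.)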
- for (auto&& name : FieldsInExpression(scan_options->filter)) { - ARROW_ASSIGN_OR_RAISE(auto match, - FieldRef(name).FindOneOrNone(*scan_options->schema())); - if (match.indices().empty()) { - convert_options.include_columns.push_back(std::move(name)); + // This requires that scan_options include the full dataset schema (not just + // the projected schema). + for (const FieldRef& ref : FieldsInExpression(scan_options->filter2)) { + DCHECK(ref.name()); + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*scan_options->schema())); + if (!match) { + convert_options.include_columns.push_back(*ref.name()); } } + return convert_options; } diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index 25bbe559cf4..0fb76803778 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -244,7 +244,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + opts_->filter2 = equal(field_ref("i32"), literal(0)); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -280,7 +280,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + opts_->filter2 = equal(field_ref("i32"), literal(0)); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index e06cee649ef..f1e9ad48bb0 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -125,7 +125,7 @@ static Result> GetSchemaManifest( return manifest; } -static std::shared_ptr ColumnChunkStatisticsAsExpression( +static util::optional ColumnChunkStatisticsAsExpression( const SchemaField& schema_field, const parquet::RowGroupMetaData& metadata) { // For the remaining of this function, failure to extract/parse statistics // are ignored by returning nullptr. The goal is two fold. First @@ -134,13 +134,13 @@ static std::shared_ptr ColumnChunkStatisticsAsExpression( // For now, only leaf (primitive) types are supported. 
 if (!schema_field.is_leaf()) {
-    return nullptr;
+    return util::nullopt;
   }
 
   auto column_metadata = metadata.ColumnChunk(schema_field.column_index);
   auto statistics = column_metadata->statistics();
   if (statistics == nullptr) {
-    return nullptr;
+    return util::nullopt;
   }
 
   const auto& field = schema_field.field;
@@ -148,12 +148,12 @@ static std::shared_ptr<Expression> ColumnChunkStatisticsAsExpression(
 
   // Optimize for corner case where all values are nulls
   if (statistics->num_values() == statistics->null_count()) {
-    return equal(std::move(field_expr), scalar(MakeNullScalar(field->type())));
+    return equal(std::move(field_expr), literal(MakeNullScalar(field->type())));
   }
 
   std::shared_ptr<Scalar> min, max;
   if (!StatisticsAsScalars(*statistics, &min, &max).ok()) {
-    return nullptr;
+    return util::nullopt;
   }
 
   auto maybe_min = min->CastTo(field->type());
@@ -161,11 +161,11 @@ static std::shared_ptr<Expression> ColumnChunkStatisticsAsExpression(
   if (maybe_min.ok() && maybe_max.ok()) {
     min = maybe_min.MoveValueUnsafe();
     max = maybe_max.MoveValueUnsafe();
-    return and_(greater_equal(field_expr, scalar(min)),
-                less_equal(field_expr, scalar(max)));
+    return and_(greater_equal(field_expr, literal(min)),
+                less_equal(field_expr, literal(max)));
   }
 
-  return nullptr;
+  return util::nullopt;
 }
 
 static void AddColumnIndices(const SchemaField& schema_field,
@@ -291,18 +291,18 @@ Result<ScanTaskIterator> ParquetFileFormat::ScanFile(std::shared_ptr<ScanOptions
   std::vector<int> row_groups;
 
   bool pre_filtered = false;
-  auto empty = [] { return MakeEmptyIterator<std::shared_ptr<ScanTask>>(); };
+  auto MakeEmpty = [] { return MakeEmptyIterator<std::shared_ptr<ScanTask>>(); };
 
   // If RowGroup metadata is cached completely we can pre-filter RowGroups before opening
   // a FileReader, potentially avoiding IO altogether if all RowGroups are excluded due to
   // prior statistics knowledge. In the case where a RowGroup doesn't have statistics
   // metadata, it will not be excluded.
   if (parquet_fragment->metadata() != nullptr) {
-    ARROW_ASSIGN_OR_RAISE(row_groups,
-                          parquet_fragment->FilterRowGroups(*options->filter));
+    ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups(
+                                          {options->filter2, context->expression_state}));
 
     pre_filtered = true;
-    if (row_groups.empty()) empty();
+    if (row_groups.empty()) return MakeEmpty();
   }
 
   // Open the reader and pay the real IO cost.
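   // A minimal sketch of the pruning contract used above (and again once the
   // reader is open): a row group survives only if the bound predicate is
   // still satisfiable after simplification against that group's statistics:
   //
   //   ARROW_ASSIGN_OR_RAISE(auto simplified,
   //                         SimplifyWithGuarantee(predicate, statistics_expr));
   //   if (simplified.first.IsSatisfiable()) row_groups.push_back(i);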
@@ -314,10 +314,10 @@ Result<ScanTaskIterator> ParquetFileFormat::ScanFile(std::shared_ptr<ScanOptions
-    ARROW_ASSIGN_OR_RAISE(row_groups,
-                          parquet_fragment->FilterRowGroups(*options->filter));
+    ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups(
+                                          {options->filter2, context->expression_state}));
 
-    if (row_groups.empty()) empty();
+    if (row_groups.empty()) return MakeEmpty();
   }
 
   auto column_projection = InferColumnProjection(*reader, *options);
@@ -332,7 +332,7 @@ Result<ScanTaskIterator> ParquetFileFormat::ScanFile(std::shared_ptr<ScanOptions
 }
 
 Result<std::shared_ptr<FileFragment>> ParquetFileFormat::MakeFragment(
-    FileSource source, std::shared_ptr<Expression> partition_expression,
+    FileSource source, Expression2 partition_expression,
     std::shared_ptr<Schema> physical_schema, std::vector<int> row_groups) {
   return std::shared_ptr<FileFragment>(new ParquetFileFragment(
       std::move(source), shared_from_this(), std::move(partition_expression),
@@ -340,7 +340,7 @@ Result<std::shared_ptr<FileFragment>> ParquetFileFormat::MakeFragment(
 }
 
 Result<std::shared_ptr<FileFragment>> ParquetFileFormat::MakeFragment(
-    FileSource source, std::shared_ptr<Expression> partition_expression,
+    FileSource source, Expression2 partition_expression,
     std::shared_ptr<Schema> physical_schema) {
   return std::shared_ptr<FileFragment>(new ParquetFileFragment(
       std::move(source), shared_from_this(), std::move(partition_expression),
@@ -395,7 +395,7 @@ Status ParquetFileWriter::Finish() { return parquet_writer_->Close(); }
 
 ParquetFileFragment::ParquetFileFragment(FileSource source, std::shared_ptr<FileFormat> format,
-                                         std::shared_ptr<Expression> partition_expression,
+                                         Expression2 partition_expression,
                                          std::shared_ptr<Schema> physical_schema,
                                          util::optional<std::vector<int>> row_groups)
     : FileFragment(std::move(source), std::move(format), std::move(partition_expression),
@@ -442,7 +442,7 @@ Status ParquetFileFragment::SetMetadata(
   metadata_ = std::move(metadata);
   manifest_ = std::move(manifest);
 
-  statistics_expressions_.resize(row_groups_->size(), scalar(true));
+  statistics_expressions_.resize(row_groups_->size(), literal(true));
   statistics_expressions_complete_.resize(physical_schema_->num_fields(), false);
 
   for (int row_group : *row_groups_) {
@@ -458,9 +458,9 @@ Status ParquetFileFragment::SetMetadata(
 }
 
 Result<FragmentVector> ParquetFileFragment::SplitByRowGroup(
-    const std::shared_ptr<Expression>& predicate) {
+    Expression2::BoundWithState predicate) {
   RETURN_NOT_OK(EnsureCompleteMetadata());
-  ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(*predicate));
+  ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate));
 
   FragmentVector fragments(row_groups.size());
   int i = 0;
@@ -477,9 +477,9 @@ Result<FragmentVector> ParquetFileFragment::SplitByRowGroup(
 }
 
 Result<std::shared_ptr<Fragment>> ParquetFileFragment::Subset(
-    const std::shared_ptr<Expression>& predicate) {
+    Expression2::BoundWithState predicate) {
   RETURN_NOT_OK(EnsureCompleteMetadata());
-  ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(*predicate));
+  ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate));
 
   return Subset(std::move(row_groups));
 }
@@ -494,8 +494,8 @@ Result<std::shared_ptr<Fragment>> ParquetFileFragment::Subset(
   return new_fragment;
 }
 
-inline void FoldingAnd(std::shared_ptr<Expression>* l, std::shared_ptr<Expression> r) {
-  if ((*l)->Equals(true)) {
+inline void FoldingAnd(Expression2* l, Expression2 r) {
+  if (*l == literal(true)) {
     *l = std::move(r);
   } else {
     *l = and_(std::move(*l), std::move(r));
@@ -503,13 +503,18 @@ inline void FoldingAnd(std::shared_ptr<Expression>* l, std::shared_ptr<Expression>
 
 Result<std::vector<int>> ParquetFileFragment::FilterRowGroups(
-    const Expression& predicate) {
+    Expression2::BoundWithState predicate) {
   auto lock = physical_schema_mutex_.Lock();
 
   DCHECK_NE(metadata_, nullptr);
-  RETURN_NOT_OK(predicate.Validate(*physical_schema_));
+  ARROW_ASSIGN_OR_RAISE(
+      predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_));
+
+  if
(!predicate.first.IsSatisfiable()) { + return std::vector{}; + } - for (FieldRef ref : FieldsInExpression(predicate)) { + for (const FieldRef& ref : FieldsInExpression(predicate.first)) { ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOneOrNone(*physical_schema_)); if (!path) continue; @@ -523,21 +528,20 @@ Result> ParquetFileFragment::FilterRowGroups( if (auto minmax = ColumnChunkStatisticsAsExpression(schema_field, *row_group_metadata)) { - FoldingAnd(&statistics_expressions_[i], std::move(minmax)); + FoldingAnd(&statistics_expressions_[i], std::move(*minmax)); + ARROW_ASSIGN_OR_RAISE(std::tie(statistics_expressions_[i], std::ignore), + statistics_expressions_[i].Bind(*physical_schema_)); } ++i; } } - auto simplified_predicate = predicate.Assume(partition_expression_); - if (!simplified_predicate->IsSatisfiable()) { - return std::vector{}; - } - std::vector row_groups; for (size_t i = 0; i < row_groups_->size(); ++i) { - if (simplified_predicate->IsSatisfiableWith(statistics_expressions_[i])) { + ARROW_ASSIGN_OR_RAISE(auto row_group_predicate, + SimplifyWithGuarantee(predicate, statistics_expressions_[i])); + if (row_group_predicate.first.IsSatisfiable()) { row_groups.push_back(row_groups_->at(i)); } } @@ -661,7 +665,7 @@ ParquetDatasetFactory::CollectParquetFragments(const Partitioning& partitioning) auto partition_expression = partitioning.Parse(StripPrefixAndFilename(path, options_.partition_base_dir)) - .ValueOr(scalar(true)); + .ValueOr(literal(true)); ARROW_ASSIGN_OR_RAISE( auto fragment, @@ -712,7 +716,7 @@ Result> ParquetDatasetFactory::Finish(FinishOptions opt } ARROW_ASSIGN_OR_RAISE(auto fragments, CollectParquetFragments(*partitioning)); - return FileSystemDataset::Make(std::move(schema), scalar(true), format_, filesystem_, + return FileSystemDataset::Make(std::move(schema), literal(true), format_, filesystem_, std::move(fragments)); } diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index b82241a574b..9de19186ad3 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -117,12 +117,12 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// \brief Create a Fragment targeting all RowGroups. Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema) override; /// \brief Create a Fragment, restricted to the specified row groups. Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression2 partition_expression, std::shared_ptr physical_schema, std::vector row_groups); /// \brief Return a FileReader on the given source. @@ -150,7 +150,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// significant performance boost when scanning high latency file systems. class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { public: - Result SplitByRowGroup(const std::shared_ptr& predicate); + Result SplitByRowGroup(Expression2::BoundWithState predicate); /// \brief Return the RowGroups selected by this fragment. const std::vector& row_groups() const { @@ -166,12 +166,12 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR); /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups. 
- Result> Subset(const std::shared_ptr& predicate); + Result> Subset(Expression2::BoundWithState predicate); Result> Subset(std::vector row_group_ids); private: ParquetFileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, + Expression2 partition_expression, std::shared_ptr physical_schema, util::optional> row_groups); @@ -185,7 +185,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { } // Return a filtered subset of row group indices. - Result> FilterRowGroups(const Expression& predicate); + Result> FilterRowGroups(Expression2::BoundWithState predicate); ParquetFileFormat& parquet_format_; @@ -193,7 +193,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { // or util::nullopt if all row groups are selected. util::optional> row_groups_; - ExpressionVector statistics_expressions_; + std::vector statistics_expressions_; std::vector statistics_expressions_complete_; std::shared_ptr metadata_; std::shared_ptr manifest_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index c2e0d6b632d..5a6bbddaba3 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -160,6 +160,11 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { return Batches(std::move(scan_task_it)); } + void SetFilter(Expression2 filter) { + ASSERT_OK_AND_ASSIGN(std::tie(opts_->filter2, ctx_->expression_state), + filter.Bind(*schema_)); + } + std::shared_ptr SingleBatch(Fragment* fragment) { auto batches = IteratorToVector(Batches(fragment)); EXPECT_EQ(batches.size(), 1); @@ -188,9 +193,12 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { void CountRowGroupsInFragment(const std::shared_ptr& fragment, std::vector expected_row_groups, - const Expression& filter) { + Expression2 filter) { + schema_ = opts_->schema(); + ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); + ctx_->expression_state = bound.second; auto parquet_fragment = checked_pointer_cast(fragment); - ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(filter.Copy())) + ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(bound)) EXPECT_EQ(fragments.size(), expected_row_groups.size()); for (size_t i = 0; i < fragments.size(); i++) { @@ -214,6 +222,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReader) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + SetFilter(literal(true)); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); int64_t row_count = 0; @@ -233,6 +242,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderDictEncoded) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + SetFilter(literal(true)); format_->reader_options.dict_columns = {"utf8"}; ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); @@ -275,7 +285,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + SetFilter(equal(field_ref("i32"), scalar(0))); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -311,7 +321,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = 
ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + SetFilter(equal(field_ref("i32"), scalar(0))); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { @@ -404,34 +414,35 @@ TEST_F(TestParquetFileFormat, PredicatePushdown) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + schema_ = reader->schema(); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); - opts_->filter = scalar(true); + SetFilter(literal(true)); CountRowsAndBatchesInScan(fragment, kTotalNumRows, kNumRowGroups); for (int64_t i = 1; i <= kNumRowGroups; i++) { - opts_->filter = ("i64"_ == int64_t(i)).Copy(); + SetFilter(("i64"_ == int64_t(i))); CountRowsAndBatchesInScan(fragment, i, 1); } // Out of bound filters should skip all RowGroups. - opts_->filter = scalar(false); + SetFilter(literal(false)); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(kNumRowGroups + 1)).Copy(); + SetFilter(("i64"_ == int64_t(kNumRowGroups + 1))); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(-1)).Copy(); + SetFilter(("i64"_ == int64_t(-1))); CountRowsAndBatchesInScan(fragment, 0, 0); // No rows match 1 and 2. - opts_->filter = ("i64"_ == int64_t(1) and "u8"_ == uint8_t(2)).Copy(); + SetFilter(("i64"_ == int64_t(1) and "u8"_ == uint8_t(2))); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(2) or "i64"_ == int64_t(4)).Copy(); + SetFilter(("i64"_ == int64_t(2) or "i64"_ == int64_t(4))); CountRowsAndBatchesInScan(fragment, 2 + 4, 2); - opts_->filter = ("i64"_ < int64_t(6)).Copy(); + SetFilter(("i64"_ < int64_t(6))); CountRowsAndBatchesInScan(fragment, 5 * (5 + 1) / 2, 5); - opts_->filter = ("i64"_ >= int64_t(6)).Copy(); + SetFilter(("i64"_ >= int64_t(6))); CountRowsAndBatchesInScan(fragment, kTotalNumRows - (5 * (5 + 1) / 2), kNumRowGroups - 5); } @@ -446,15 +457,17 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragments) { ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); auto all_row_groups = internal::Iota(static_cast(kNumRowGroups)); - CountRowGroupsInFragment(fragment, all_row_groups, *scalar(true)); - CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); + CountRowGroupsInFragment(fragment, all_row_groups, literal(true)); + + // FIXME this is only meaningful if "not here" is a virtual column + // CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); for (int i = 0; i < kNumRowGroups; ++i) { CountRowGroupsInFragment(fragment, {i}, "i64"_ == int64_t(i + 1)); } // Out of bound filters should skip all RowGroups. 
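   // (Row group i of this file holds i + 1 rows, all with i64 == i + 1, so for
   // example "i64"_ < 6 selects the five groups with values 1..5, i.e.
   // 1 + 2 + 3 + 4 + 5 = 5 * (5 + 1) / 2 = 15 rows across 5 row groups.)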
- CountRowGroupsInFragment(fragment, {}, *scalar(false)); + CountRowGroupsInFragment(fragment, {}, literal(false)); CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(kNumRowGroups + 1)); CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(-1)); @@ -503,10 +516,12 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + schema_ = reader->schema(); + SetFilter(literal(true)); auto row_groups_fragment = [&](std::vector row_groups) { EXPECT_OK_AND_ASSIGN(auto fragment, - format_->MakeFragment(*source, scalar(true), + format_->MakeFragment(*source, literal(true), /*physical_schema=*/nullptr, row_groups)); return fragment; }; @@ -514,7 +529,7 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { // select all row groups EXPECT_OK_AND_ASSIGN( auto all_row_groups_fragment, - format_->MakeFragment(*source, scalar(true)) + format_->MakeFragment(*source, literal(true)) .Map([](std::shared_ptr f) { return internal::checked_pointer_cast(f); })); @@ -532,17 +547,17 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { for (int i = 0; i < kNumRowGroups; ++i) { // conflicting selection/filter - opts_->filter = ("i64"_ == int64_t(i)).Copy(); + SetFilter(("i64"_ == int64_t(i))); CountRowsAndBatchesInScan(row_groups_fragment({i}), 0, 0); } for (int i = 0; i < kNumRowGroups; ++i) { // identical selection/filter - opts_->filter = ("i64"_ == int64_t(i + 1)).Copy(); + SetFilter(("i64"_ == int64_t(i + 1))); CountRowsAndBatchesInScan(row_groups_fragment({i}), i + 1, 1); } - opts_->filter = ("i64"_ > int64_t(3)).Copy(); + SetFilter(("i64"_ > int64_t(3))); CountRowsAndBatchesInScan(row_groups_fragment({2, 3, 4, 5}), 4 + 5 + 6, 3); EXPECT_RAISES_WITH_MESSAGE_THAT( diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index 8ee2613576f..c390f777ad4 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -82,15 +82,15 @@ TEST(FileSource, BufferBased) { TEST_F(TestFileSystemDataset, Basic) { MakeDataset({}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {}); MakeDataset({fs::File("a"), fs::File("b"), fs::File("c")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"a", "b", "c"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b", "c"}); AssertFilesAre(dataset_, {"a", "b", "c"}); // Should not create fragment from directories. 
MakeDataset({fs::Dir("A"), fs::Dir("A/B"), fs::File("A/a"), fs::File("A/B/b")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"A/a", "A/B/b"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"A/a", "A/B/b"}); AssertFilesAre(dataset_, {"A/a", "A/B/b"}); } @@ -98,7 +98,7 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { auto schm = schema({field("i32", int32()), field("f64", float64())}); auto format = std::make_shared(schm); ASSERT_OK_AND_ASSIGN(auto dataset, - FileSystemDataset::Make(schm, scalar(true), format, nullptr, {})); + FileSystemDataset::Make(schm, literal(true), format, nullptr, {})); // drop field ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status()); @@ -119,45 +119,64 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { } TEST_F(TestFileSystemDataset, RootPartitionPruning) { - auto root_partition = ("a"_ == 5).Copy(); + auto root_partition = equal(field_ref("i32"), literal(5)); MakeDataset({fs::File("a"), fs::File("b")}, root_partition); + auto GetFragments = [&](Expression2 filter) { + return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); + }; + // Default filter should always return all data. - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"a", "b"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b"}); // filter == partition - AssertFragmentsAreFromPath(dataset_->GetFragments(root_partition), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(root_partition), {"a", "b"}); // Same partition key, but non matching filter - AssertFragmentsAreFromPath(dataset_->GetFragments(("a"_ == 6).Copy()), {}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("i32"), literal(6))), {}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("a"_ > 1).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(greater(field_ref("i32"), literal(1))), + {"a", "b"}); // different key shouldn't prune - AssertFragmentsAreFromPath(dataset_->GetFragments(("b"_ == 6).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), + {"a", "b"}); // No partition should match MakeDataset({fs::File("a"), fs::File("b")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("b"_ == 6).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), + {"a", "b"}); } TEST_F(TestFileSystemDataset, TreePartitionPruning) { - auto root_partition = ("country"_ == "US").Copy(); + auto root_partition = equal(field_ref("country"), literal("US")); + std::vector regions = { fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"), fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - ExpressionVector partitions = { - ("state"_ == "NY").Copy(), - ("state"_ == "NY" and "city"_ == "New York").Copy(), - ("state"_ == "NY" and "city"_ == "Franklin").Copy(), - ("state"_ == "CA").Copy(), - ("state"_ == "CA" and "city"_ == "San Francisco").Copy(), - ("state"_ == "CA" and "city"_ == "Franklin").Copy(), + std::vector partitions = { + equal(field_ref("state"), literal("NY")), + + equal(field_ref("state"), literal("NY")) and + equal(field_ref("city"), literal("New York")), + + equal(field_ref("state"), literal("NY")) and + equal(field_ref("city"), literal("Franklin")), + + equal(field_ref("state"), literal("CA")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("San Francisco")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("Franklin")), }; - 
MakeDataset(regions, root_partition, partitions); + MakeDataset( + regions, root_partition, partitions, + schema({field("country", utf8()), field("state", utf8()), field("city", utf8())})); std::vector all_cities = {"CA/San Francisco", "CA/Franklin", "NY/New York", "NY/Franklin"}; @@ -165,52 +184,67 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { std::vector franklins = {"CA/Franklin", "NY/Franklin"}; // Default filter should always return all data. - AssertFragmentsAreFromPath(dataset_->GetFragments(), all_cities); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), all_cities); + + auto GetFragments = [&](Expression2 filter) { + return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); + }; // Dataset's partitions are respected - AssertFragmentsAreFromPath(dataset_->GetFragments(("country"_ == "US").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("US"))), all_cities); - AssertFragmentsAreFromPath(dataset_->GetFragments(("country"_ == "FR").Copy()), {}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("FR"))), + {}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("state"_ == "CA").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("state"), literal("CA"))), ca_cities); // Filter where no decisions can be made on inner nodes when filter don't // apply to inner partitions. - AssertFragmentsAreFromPath(dataset_->GetFragments(("city"_ == "Franklin").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("city"), literal("Franklin"))), franklins); } TEST_F(TestFileSystemDataset, FragmentPartitions) { - auto root_partition = ("country"_ == "US").Copy(); + auto root_partition = equal(field_ref("country"), literal("US")); std::vector regions = { fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"), fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - ExpressionVector partitions = { - ("state"_ == "NY").Copy(), - ("state"_ == "NY" and "city"_ == "New York").Copy(), - ("state"_ == "NY" and "city"_ == "Franklin").Copy(), - ("state"_ == "CA").Copy(), - ("state"_ == "CA" and "city"_ == "San Francisco").Copy(), - ("state"_ == "CA" and "city"_ == "Franklin").Copy(), - }; + std::vector partitions = { + equal(field_ref("state"), literal("NY")), + + equal(field_ref("state"), literal("NY")) and + equal(field_ref("city"), literal("New York")), - MakeDataset(regions, root_partition, partitions); + equal(field_ref("state"), literal("NY")) and + equal(field_ref("city"), literal("Franklin")), - auto with_root = [&](const Expression& state, const Expression& city) { - return and_(state.Copy(), city.Copy()); + equal(field_ref("state"), literal("CA")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("San Francisco")), + + equal(field_ref("state"), literal("CA")) and + equal(field_ref("city"), literal("Franklin")), }; + MakeDataset( + regions, root_partition, partitions, + schema({field("country", utf8()), field("state", utf8()), field("city", utf8())})); + AssertFragmentsHavePartitionExpressions( - dataset_->GetFragments(), - { - with_root("state"_ == "CA", "city"_ == "San Francisco"), - with_root("state"_ == "CA", "city"_ == "Franklin"), - with_root("state"_ == "NY", "city"_ == "New York"), - with_root("state"_ == "NY", "city"_ == "Franklin"), - }); + dataset_, { + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), + and_(equal(field_ref("state"), literal("CA")), + 
equal(field_ref("city"), literal("Franklin"))), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), + }); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index caab28e3f7c..e231737bc5f 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -29,6 +29,7 @@ #include "arrow/buffer.h" #include "arrow/compute/api.h" #include "arrow/dataset/dataset.h" +#include "arrow/dataset/expression.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" @@ -808,38 +809,6 @@ std::shared_ptr ScalarExpression::Copy() const { return std::make_shared(*this); } -std::shared_ptr and_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr and_(const ExpressionVector& subexpressions) { - auto acc = scalar(true); - for (const auto& next : subexpressions) { - if (next->Equals(false)) return next; - acc = acc->Equals(true) ? next : and_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr or_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr or_(const ExpressionVector& subexpressions) { - auto acc = scalar(false); - for (const auto& next : subexpressions) { - if (next->Equals(true)) return next; - acc = acc->Equals(false) ? next : or_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr not_(std::shared_ptr operand) { - return std::make_shared(std::move(operand)); -} - AndExpression operator&&(const Expression& lhs, const Expression& rhs) { return AndExpression(lhs.Copy(), rhs.Copy()); } @@ -848,8 +817,6 @@ OrExpression operator||(const Expression& lhs, const Expression& rhs) { return OrExpression(lhs.Copy(), rhs.Copy()); } -NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } - CastExpression Expression::CastTo(std::shared_ptr type, compute::CastOptions options) const { return CastExpression(Copy(), type, std::move(options)); @@ -890,8 +857,8 @@ Status EnsureNullOrBool(const std::string& msg_prefix, return Status::TypeError(msg_prefix, *type); } -Result> ValidateBoolean(const ExpressionVector& operands, - const Schema& schema) { +Result> ValidateBoolean( + const std::vector>& operands, const Schema& schema) { for (const auto& operand : operands) { ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); RETURN_NOT_OK( @@ -1476,7 +1443,7 @@ struct DeserializeImpl { switch (expression_type) { case ExpressionType::FIELD: { ARROW_ASSIGN_OR_RAISE(auto name, GetView(struct_array, 0)); - return field_ref(std::string(name)); + return std::make_shared(std::string(name)); } case ExpressionType::SCALAR: { diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index c6b55b419ff..750c67c048d 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -464,24 +464,6 @@ class ARROW_DS_EXPORT CustomExpression : public Expression { CustomExpression() : Expression(ExpressionType::CUSTOM) {} }; -ARROW_DS_EXPORT std::shared_ptr and_(std::shared_ptr lhs, - std::shared_ptr rhs); - -ARROW_DS_EXPORT std::shared_ptr and_(const ExpressionVector& subexpressions); - -ARROW_DS_EXPORT AndExpression operator&&(const Expression& lhs, const Expression& rhs); - -ARROW_DS_EXPORT std::shared_ptr or_(std::shared_ptr lhs, - std::shared_ptr rhs); 
- -ARROW_DS_EXPORT std::shared_ptr or_(const ExpressionVector& subexpressions); - -ARROW_DS_EXPORT OrExpression operator||(const Expression& lhs, const Expression& rhs); - -ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); - -ARROW_DS_EXPORT NotExpression operator!(const Expression& rhs); - inline std::shared_ptr scalar(std::shared_ptr value) { return std::make_shared(std::move(value)); } @@ -491,43 +473,6 @@ auto scalar(T&& value) -> decltype(scalar(MakeScalar(std::forward(value)))) { return scalar(MakeScalar(std::forward(value))); } -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - inline std::shared_ptr FACTORY_NAME( \ - const std::shared_ptr& lhs, const std::shared_ptr& rhs) { \ - return std::make_shared(CompareOperator::NAME, lhs, rhs); \ - } \ - \ - template ::type>::value>::type> \ - ComparisonExpression operator OP(const Expression& lhs, T&& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), \ - scalar(std::forward(rhs))); \ - } \ - \ - inline ComparisonExpression operator OP(const Expression& lhs, \ - const Expression& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), rhs.Copy()); \ - } -COMPARISON_FACTORY(EQUAL, equal, ==) -COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) -COMPARISON_FACTORY(GREATER, greater, >) -COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) -COMPARISON_FACTORY(LESS, less, <) -COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) -#undef COMPARISON_FACTORY - -inline std::shared_ptr field_ref(std::string name) { - return std::make_shared(std::move(name)); -} - -inline namespace string_literals { -// clang-format off -inline FieldExpression operator"" _(const char* name, size_t name_length) { - // clang-format on - return FieldExpression({name, name_length}); -} -} // namespace string_literals - template bool Expression::Equals(T&& t) const { if (type_ != ExpressionType::SCALAR) { diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 7723912eeab..fce2d9cfcd9 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -45,30 +45,55 @@ using string_literals::operator"" _; using internal::checked_cast; using internal::checked_pointer_cast; +// A frozen shared_ptr with behavior expected by GTest +struct TestExpression : util::EqualityComparable, + util::ToStringOstreamable { + // NOLINTNEXTLINE runtime/explicit + TestExpression(std::shared_ptr e) : expression(std::move(e)) {} + + // NOLINTNEXTLINE runtime/explicit + TestExpression(const Expression& e) : expression(e.Copy()) {} + + // NOLINTNEXTLINE runtime/explicit + TestExpression(const Expression2& e) : expression(e) {} + + std::shared_ptr expression; + + using util::EqualityComparable::operator==; + bool Equals(const TestExpression& other) const { + return expression->Equals(other.expression); + } + + std::string ToString() const { return expression->ToString(); } + + friend bool operator==(const std::shared_ptr& lhs, + const TestExpression& rhs) { + return TestExpression(lhs) == rhs; + } + + friend void PrintTo(const TestExpression& expr, std::ostream* os) { + *os << expr.ToString(); + } +}; + using E = TestExpression; class ExpressionsTest : public ::testing::Test { public: - void AssertSimplifiesTo(const Expression& expr, const Expression& given, - const Expression& expected) { - ASSERT_OK_AND_ASSIGN(auto expr_type, expr.Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto given_type, given.Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto expected_type, 
expected.Validate(*schema_)); + void AssertSimplifiesTo(E expr, E given, E expected) { + ASSERT_OK_AND_ASSIGN(auto expr_type, expr.expression->Validate(*schema_)); + ASSERT_OK_AND_ASSIGN(auto given_type, given.expression->Validate(*schema_)); + ASSERT_OK_AND_ASSIGN(auto expected_type, expected.expression->Validate(*schema_)); EXPECT_TRUE(expr_type->Equals(expected_type)); EXPECT_TRUE(given_type->Equals(boolean())); - auto simplified = expr.Assume(given); - ASSERT_EQ(E{simplified}, E{expected}) + auto simplified = expr.expression->Assume(given.expression); + ASSERT_EQ(simplified, expected) << " simplification of: " << expr.ToString() << std::endl << " given: " << given.ToString() << std::endl; } - void AssertSimplifiesTo(const Expression& expr, const Expression& given, - const std::shared_ptr& expected) { - AssertSimplifiesTo(expr, given, *expected); - } - std::shared_ptr ns = timestamp(TimeUnit::NANO); std::shared_ptr schema_ = schema({field("a", int32()), field("b", int32()), field("f", float64()), @@ -78,16 +103,6 @@ class ExpressionsTest : public ::testing::Test { std::shared_ptr never = scalar(false); }; -TEST_F(ExpressionsTest, StringRepresentation) { - ASSERT_EQ("a"_.ToString(), "a"); - ASSERT_EQ(("a"_ > int32_t(3)).ToString(), "(a > 3:int32)"); - ASSERT_EQ(("a"_ > int32_t(3) and "a"_ < int32_t(4)).ToString(), - "((a > 3:int32) and (a < 4:int32))"); - ASSERT_EQ(("f"_ > double(4)).ToString(), "(f > 4:double)"); - ASSERT_EQ("f"_.CastTo(float64()).ToString(), "(cast f to double)"); - ASSERT_EQ("f"_.CastLike("a"_).ToString(), "(cast f like a)"); -} - TEST_F(ExpressionsTest, Equality) { ASSERT_EQ(E{"a"_}, E{"a"_}); ASSERT_NE(E{"a"_}, E{"b"_}); @@ -114,7 +129,7 @@ TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); - AssertSimplifiesTo("f"_ > 0.5 and "f"_ < 1.5, not("f"_ < 0.0 or "f"_ > 1.0), + AssertSimplifiesTo("f"_ > 0.5 and "f"_ < 1.5, not_("f"_ < 0.0 or "f"_ > 1.0), "f"_ > 0.5); AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); @@ -125,7 +140,7 @@ TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 3, "a"_ == 3); AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 0, *never); - AssertSimplifiesTo("a"_ == 0 or not"b"_.IsValid(), "b"_ == 3, "a"_ == 0); + AssertSimplifiesTo("a"_ == 0 or not_("b"_.IsValid()), "b"_ == 3, "a"_ == 0); } TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { @@ -144,20 +159,6 @@ TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { *always); } -TEST_F(ExpressionsTest, SimplificationToNull) { - auto null = scalar(std::make_shared()); - auto null32 = scalar(std::make_shared()); - - AssertSimplifiesTo(*equal(field_ref("b"), null32), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32), "b"_ == 3, *null); - - // Kleene logic applies here - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 3, "b"_ == 3, *never); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 2, "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) or "b"_ > 3, "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) or "b"_ > 2, "b"_ == 3, *always); -} - class FilterTest : public ::testing::Test { public: FilterTest() { evaluator_ = std::make_shared(); } @@ -444,165 +445,6 @@ TEST_F(FilterTest, KleeneTruthTables) { ])"); } -class TakeExpression : public 
CustomExpression { - public: - TakeExpression(std::shared_ptr operand, std::shared_ptr dictionary) - : operand_(std::move(operand)), dictionary_(std::move(dictionary)) {} - - std::string ToString() const override { - return dictionary_->ToString() + "[" + operand_->ToString() + "]"; - } - - std::shared_ptr Copy() const override { - return std::make_shared(*this); - } - - bool Equals(const Expression& other) const override { - // in a real CustomExpression this would need to be more sophisticated - return other.type() == ExpressionType::CUSTOM && ToString() == other.ToString(); - } - - Result> Validate(const Schema& schema) const override { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - if (!is_integer(operand_type->id())) { - return Status::TypeError("Take indices must be integral, not ", *operand_type); - } - return dictionary_->type(); - } - - class Evaluator : public TreeEvaluator { - public: - using TreeEvaluator::TreeEvaluator; - - using TreeEvaluator::Evaluate; - - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override { - if (expr.type() == ExpressionType::CUSTOM) { - const auto& take_expr = checked_cast(expr); - return EvaluateTake(take_expr, batch, pool); - } - return TreeEvaluator::Evaluate(expr, batch, pool); - } - - Result EvaluateTake(const TakeExpression& take_expr, const RecordBatch& batch, - MemoryPool* pool) const { - ARROW_ASSIGN_OR_RAISE(auto indices, Evaluate(*take_expr.operand_, batch, pool)); - - if (indices.kind() == Datum::SCALAR) { - ARROW_ASSIGN_OR_RAISE(auto indices_array, - MakeArrayFromScalar(*indices.scalar(), batch.num_rows(), - default_memory_pool())); - indices = Datum(indices_array->data()); - } - - DCHECK_EQ(indices.kind(), Datum::ARRAY); - compute::ExecContext ctx(pool); - ARROW_ASSIGN_OR_RAISE(Datum out, - compute::Take(take_expr.dictionary_->data(), indices, - compute::TakeOptions(), &ctx)); - return std::move(out); - } - }; - - private: - std::shared_ptr operand_; - std::shared_ptr dictionary_; -}; - -TEST_F(ExpressionsTest, TakeAssumeYieldsNothing) { - auto dict = ArrayFromJSON(float64(), "[0.0, 0.25, 0.5, 0.75, 1.0]"); - auto take_b_is_half = (TakeExpression(field_ref("b"), dict) == 0.5); - - // no special Assume logic was provided for TakeExpression so we should just ignore it - // (logically the below *could* be simplified to false but we haven't implemented that) - AssertSimplifiesTo(take_b_is_half, "b"_ == 3, take_b_is_half); - - // custom expressions will not interfere with simplification of other subexpressions and - // can be dropped if other subexpressions simplify trivially - - // ("b"_ > 5).Assume("b"_ == 3) simplifies to false regardless of take, so the and will - // be false - AssertSimplifiesTo("b"_ > 5 and take_b_is_half, "b"_ == 3, *never); - - // ("b"_ > 5).Assume("b"_ == 6) simplifies to true regardless of take, so it can be - // dropped - AssertSimplifiesTo("b"_ > 5 and take_b_is_half, "b"_ == 6, take_b_is_half); - - // ("b"_ > 5).Assume("b"_ == 6) simplifies to true regardless of take, so the or will be - // true - AssertSimplifiesTo(take_b_is_half or "b"_ > 5, "b"_ == 6, *always); - - // ("b"_ > 5).Assume("b"_ == 3) simplifies to true regardless of take, so it can be - // dropped - AssertSimplifiesTo(take_b_is_half or "b"_ > 5, "b"_ == 3, take_b_is_half); -} - -TEST_F(FilterTest, EvaluateTakeExpression) { - evaluator_ = std::make_shared(); - - auto dict = ArrayFromJSON(float64(), "[0.0, 0.25, 0.5, 0.75, 1.0]"); - - AssertFilter(TakeExpression(field_ref("b"), 
dict) == 0.5, - {field("b", int32()), field("f", float64())}, R"([ - {"b": 3, "f": -0.1, "in": 0}, - {"b": 2, "f": 0.3, "in": 1}, - {"b": 1, "f": 0.2, "in": 0}, - {"b": 2, "f": -0.1, "in": 1}, - {"b": 4, "f": 0.1, "in": 0}, - {"b": null, "f": 0.0, "in": null}, - {"b": 0, "f": 1.0, "in": 0} - ])"); -} - -void AssertFieldsInExpression(std::shared_ptr expr, - std::vector expected) { - EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected)); -} - -TEST(FieldsInExpressionTest, Basic) { - AssertFieldsInExpression(scalar(true), {}); - - AssertFieldsInExpression(("a"_).Copy(), {"a"}); - AssertFieldsInExpression(("a"_ == 1).Copy(), {"a"}); - AssertFieldsInExpression(("a"_ == "b"_).Copy(), {"a", "b"}); - - AssertFieldsInExpression(("a"_ == 1 || "a"_ == 2).Copy(), {"a", "a"}); - AssertFieldsInExpression(("a"_ == 1 || "b"_ == 2).Copy(), {"a", "b"}); - AssertFieldsInExpression((not("a"_ == 1) && ("b"_ == 2 || not("c"_ < 3))).Copy(), - {"a", "b", "c"}); -} - -TEST(ExpressionSerializationTest, RoundTrips) { - std::vector exprs{ - scalar(MakeNullScalar(null())), - scalar(MakeNullScalar(int32())), - scalar(MakeNullScalar(struct_({field("i", int32()), field("s", utf8())}))), - scalar(true), - scalar(false), - scalar(1), - scalar(1.125), - scalar("stringy strings"), - "field"_, - "a"_ > 0.25, - "a"_ == 1 or "b"_ != "hello" or "b"_ == "foo bar", - not"alpha"_, - "valid"_ and "a"_.CastLike("b"_) >= "b"_, - "version"_.CastTo(float64()).In(ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]")), - "validity"_.IsValid(), - ("x"_ >= -1.5 and "x"_ < 0.0) and ("y"_ >= 0.0 and "y"_ < 1.5) and - ("z"_ > 1.5 and "z"_ <= 3.0), - "year"_ == int16_t(1999) and "month"_ == int8_t(12) and "day"_ == int8_t(31) and - "hour"_ == int8_t(0) and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, - }; - - for (const auto& expr : exprs) { - ASSERT_OK_AND_ASSIGN(auto serialized, expr.expression->Serialize()); - ASSERT_OK_AND_ASSIGN(E roundtripped, Expression::Deserialize(*serialized)); - ASSERT_EQ(expr, roundtripped); - } -} - void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, const std::string& expected_json) { FieldVector fields_with_ids = by_fields; diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index a39a8adaae0..b6178ddeea6 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -49,81 +49,47 @@ std::shared_ptr Partitioning::Default() { std::string type_name() const override { return "default"; } - Result> Parse(const std::string& path) const override { - return scalar(true); + Result Parse(const std::string& path) const override { + return literal(true); } - Result Format(const Expression& expr) const override { + Result Format(const Expression2& expr) const override { return Status::NotImplemented("formatting paths from ", type_name(), " Partitioning"); } Result Partition( const std::shared_ptr& batch) const override { - return PartitionedBatches{{batch}, {scalar(true)}}; + return PartitionedBatches{{batch}, {literal(true)}}; } }; return std::make_shared(); } -Status KeyValuePartitioning::VisitKeys( - const Expression& expr, - const std::function& value)>& visitor) { - return VisitConjunctionMembers(expr, [visitor](const Expression& expr) { - if (expr.type() != ExpressionType::COMPARISON) { - return Status::OK(); - } - - const auto& cmp = checked_cast(expr); - if (cmp.op() != compute::CompareOperator::EQUAL) { - return Status::OK(); - } - - auto lhs = cmp.left_operand().get(); - auto rhs = cmp.right_operand().get(); - if 
(lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs); - - if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) { - return Status::OK(); +Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr, + RecordBatchProjector* projector) { + ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); + for (const auto& ref_value : known_values) { + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); } - return visitor(checked_cast(lhs)->name(), - checked_cast(rhs)->value()); - }); -} + ARROW_ASSIGN_OR_RAISE(auto match, + ref_value.first.FindOneOrNone(*projector->schema())); -Result>> -KeyValuePartitioning::GetKeys(const Expression& expr) { - std::unordered_map> keys; - RETURN_NOT_OK( - VisitKeys(expr, [&](const std::string& name, const std::shared_ptr& value) { - keys.emplace(name, value); - return Status::OK(); - })); - return keys; -} - -Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr, - RecordBatchProjector* projector) { - return KeyValuePartitioning::VisitKeys( - expr, [projector](const std::string& name, const std::shared_ptr& value) { - ARROW_ASSIGN_OR_RAISE(auto match, - FieldRef(name).FindOneOrNone(*projector->schema())); - if (!match) { - return Status::OK(); - } - return projector->SetDefaultValue(match, value); - }); + if (!match) continue; + RETURN_NOT_OK(projector->SetDefaultValue(match, ref_value.second.scalar())); + } + return Status::OK(); } inline std::shared_ptr ConjunctionFromGroupingRow(Scalar* row) { ScalarVector* values = &checked_cast(row)->value; - ExpressionVector equality_expressions(values->size()); + std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { const std::string& name = row->type->field(static_cast(i))->name(); - equality_expressions[i] = equal(field_ref(name), scalar(std::move(values->at(i)))); + equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); } return and_(std::move(equality_expressions)); } @@ -147,7 +113,7 @@ Result KeyValuePartitioning::Partition( if (by_fields.empty()) { // no fields to group by; return the whole batch - return PartitionedBatches{{batch}, {scalar(true)}}; + return PartitionedBatches{{batch}, {literal(true)}}; } ARROW_ASSIGN_OR_RAISE(auto by, @@ -168,11 +134,10 @@ Result KeyValuePartitioning::Partition( return out; } -Result> KeyValuePartitioning::ConvertKey( - const Key& key) const { +Result KeyValuePartitioning::ConvertKey(const Key& key) const { ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_)); if (!match) { - return scalar(true); + return literal(true); } auto field_index = match[0]; @@ -209,41 +174,49 @@ Result> KeyValuePartitioning::ConvertKey( ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), key.value)); } - return equal(field_ref(field->name()), scalar(std::move(converted))); + return equal(field_ref(field->name()), literal(std::move(converted))); } -Result> KeyValuePartitioning::Parse( - const std::string& path) const { - ExpressionVector expressions; +Result KeyValuePartitioning::Parse(const std::string& path) const { + std::vector expressions; for (const Key& key : ParseKeys(path)) { ARROW_ASSIGN_OR_RAISE(auto expr, ConvertKey(key)); - - if (expr->Equals(true)) continue; - + if (expr == literal(true)) continue; expressions.push_back(std::move(expr)); } - return and_(std::move(expressions)); + auto parsed = and_(std::move(expressions)); + 
ARROW_ASSIGN_OR_RAISE(std::tie(parsed, std::ignore), parsed.Bind(*schema_)) + return parsed; } -Result KeyValuePartitioning::Format(const Expression& expr) const { +Result KeyValuePartitioning::Format(const Expression2& expr) const { + if (!expr.IsBound()) { + return Status::Invalid("formatted partition expressions must be bound"); + } + std::vector values{static_cast(schema_->num_fields()), nullptr}; - RETURN_NOT_OK(VisitKeys(expr, [&](const std::string& name, - const std::shared_ptr& value) { - ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(name).FindOneOrNone(*schema_)); - if (match) { - const auto& field = schema_->field(match[0]); - if (!value->type->Equals(field->type())) { - return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, - ") is invalid for ", field->ToString()); - } + ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); + for (const auto& ref_value : known_values) { + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); + } + + ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_)); + if (!match) continue; - values[match[0]] = value.get(); + const auto& value = ref_value.second.scalar(); + + const auto& field = schema_->field(match[0]); + if (!value->type->Equals(field->type())) { + return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, + ") is invalid for ", field->ToString()); } - return Status::OK(); - })); + + values[match[0]] = value.get(); + } return FormatValues(values); } diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index 165fcfb5248..1bad8956b65 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -26,7 +26,7 @@ #include #include -#include "arrow/dataset/filter.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/util/optional.h" @@ -63,15 +63,15 @@ class ARROW_DS_EXPORT Partitioning { /// produce sub-batches which satisfy mutually exclusive Expressions. 
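   /// For example, partitioning a batch on field "a" yields one sub-batch per
   /// distinct value of "a", with the matching expression (e.g. a == 1) stored
   /// at the same index in `expressions`.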
struct PartitionedBatches { RecordBatchVector batches; - ExpressionVector expressions; + std::vector expressions; }; virtual Result Partition( const std::shared_ptr& batch) const = 0; /// \brief Parse a path into a partition expression - virtual Result> Parse(const std::string& path) const = 0; + virtual Result Parse(const std::string& path) const = 0; - virtual Result Format(const Expression& expr) const = 0; + virtual Result Format(const Expression2& expr) const = 0; /// \brief A default Partitioning which always yields scalar(true) static std::shared_ptr Default(); @@ -122,23 +122,15 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { std::string name, value; }; - static Status VisitKeys( - const Expression& expr, - const std::function& value)>& visitor); - - static Result>> GetKeys( - const Expression& expr); - - static Status SetDefaultValuesFromKeys(const Expression& expr, + static Status SetDefaultValuesFromKeys(const Expression2& expr, RecordBatchProjector* projector); Result Partition( const std::shared_ptr& batch) const override; - Result> Parse(const std::string& path) const override; + Result Parse(const std::string& path) const override; - Result Format(const Expression& expr) const override; + Result Format(const Expression2& expr) const override; protected: KeyValuePartitioning(std::shared_ptr schema, ArrayVector dictionaries) @@ -153,7 +145,7 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { virtual Result FormatValues(const std::vector& values) const = 0; /// Convert a Key to a full expression. - Result> ConvertKey(const Key& key) const; + Result ConvertKey(const Key& key) const; ArrayVector dictionaries_; }; @@ -215,10 +207,9 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { /// \brief Implementation provided by lambda or other callable class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { public: - using ParseImpl = - std::function>(const std::string&)>; + using ParseImpl = std::function(const std::string&)>; - using FormatImpl = std::function(const Expression&)>; + using FormatImpl = std::function(const Expression2&)>; FunctionPartitioning(std::shared_ptr schema, ParseImpl parse_impl, FormatImpl format_impl = NULLPTR, std::string name = "function") @@ -229,11 +220,11 @@ class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { std::string type_name() const override { return name_; } - Result> Parse(const std::string& path) const override { + Result Parse(const std::string& path) const override { return parse_impl_(path); } - Result Format(const Expression& expr) const override { + Result Format(const Expression2& expr) const override { if (format_impl_) { return format_impl_(expr); } diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index b47f5ba9926..37e65cacac4 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -37,32 +37,40 @@ using internal::checked_pointer_cast; namespace dataset { -using E = TestExpression; - class TestPartitioning : public ::testing::Test { public: void AssertParseError(const std::string& path) { ASSERT_RAISES(Invalid, partitioning_->Parse(path)); } - void AssertParse(const std::string& path, E expected) { + void AssertParse(const std::string& path, Expression2 expected) { ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path)); - ASSERT_EQ(E{parsed}, expected); + ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), + expected.Bind(*partitioning_->schema())); + 
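+    // Parse returns an expression bound to the partitioning schema, so the
+    // expected expression is bound to the same schema before the comparison
+    // below.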
ASSERT_EQ(parsed, expected); } template - void AssertFormatError(E expr) { - ASSERT_EQ(partitioning_->Format(*expr.expression).status().code(), code); + void AssertFormatError(Expression2 expr) { + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } - void AssertFormat(E expr, const std::string& expected) { - ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(*expr.expression)); + void AssertFormat(Expression2 expr, const std::string& expected) { + // formatted partition expressions are bound to the schema of the dataset being + // written + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + + ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr)); ASSERT_EQ(formatted, expected); // ensure the formatted path round trips the relevant components of the partition // expression: roundtripped should be a subset of expr - ASSERT_OK_AND_ASSIGN(auto roundtripped, partitioning_->Parse(formatted)); - ASSERT_EQ(E{roundtripped->Assume(*expr.expression)}, E{scalar(true)}); + ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, partitioning_->Parse(formatted)); + + ASSERT_OK_AND_ASSIGN(auto bound, roundtripped.Bind(*partitioning_->schema())); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, expr)); + ASSERT_EQ(simplified.first, literal(true)); } void AssertInspect(const std::vector& paths, @@ -95,6 +103,7 @@ class TestPartitioning : public ::testing::Test { std::shared_ptr partitioning_; std::shared_ptr factory_; + std::shared_ptr written_schema_; }; TEST_F(TestPartitioning, DirectoryPartitioning) { @@ -103,10 +112,10 @@ TEST_F(TestPartitioning, DirectoryPartitioning) { AssertParse("/0/hello", "alpha"_ == int32_t(0) and "beta"_ == "hello"); AssertParse("/3", "alpha"_ == int32_t(3)); - AssertParseError("/world/0"); // reversed order - AssertParseError("/0.0/foo"); // invalid alpha - AssertParseError("/3.25"); // invalid alpha with missing beta - AssertParse("", scalar(true)); // no segments to parse + AssertParseError("/world/0"); // reversed order + AssertParseError("/0.0/foo"); // invalid alpha + AssertParseError("/3.25"); // invalid alpha with missing beta + AssertParse("", literal(true)); // no segments to parse // gotcha someday: AssertParse("/0/dat.parquet", "alpha"_ == int32_t(0) and "beta"_ == "dat.parquet"); @@ -118,15 +127,22 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormat) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); + written_schema_ = partitioning_->schema(); + AssertFormat("alpha"_ == int32_t(0) and "beta"_ == "hello", "0/hello"); AssertFormat("beta"_ == "hello" and "alpha"_ == int32_t(0), "0/hello"); AssertFormat("alpha"_ == int32_t(0), "0"); AssertFormatError("beta"_ == "hello"); - AssertFormat(scalar(true), ""); + AssertFormat(literal(true), ""); - AssertFormatError("alpha"_ == 0.0 and "beta"_ == "hello"); + ASSERT_OK_AND_ASSIGN(written_schema_, + written_schema_->AddField(0, field("gamma", utf8()))); AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == "hello", "0/hello"); + + // written_schema_ is incompatible with partitioning_'s schema + written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); + AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); } TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { @@ -216,7 +232,7 @@ TEST_F(TestPartitioning, HivePartitioning) { AssertParse("/beta=3.25/alpha=0", "beta"_ == 3.25f and 
"alpha"_ == int32_t(0)); AssertParse("/alpha=0", "alpha"_ == int32_t(0)); AssertParse("/beta=3.25", "beta"_ == 3.25f); - AssertParse("", scalar(true)); + AssertParse("", literal(true)); AssertParse("/alpha=0/unexpected/beta=3.25", "alpha"_ == int32_t(0) and "beta"_ == 3.25f); @@ -224,7 +240,7 @@ TEST_F(TestPartitioning, HivePartitioning) { AssertParse("/alpha=0/beta=3.25/ignored=2341", "alpha"_ == int32_t(0) and "beta"_ == 3.25f); - AssertParse("/ignored=2341", scalar(true)); + AssertParse("/ignored=2341", literal(true)); AssertParseError("/alpha=0.0/beta=3.25"); // conversion of "0.0" to int32 fails } @@ -233,15 +249,22 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", float32())})); + written_schema_ = partitioning_->schema(); + AssertFormat("alpha"_ == int32_t(0) and "beta"_ == 3.25f, "alpha=0/beta=3.25"); AssertFormat("beta"_ == 3.25f and "alpha"_ == int32_t(0), "alpha=0/beta=3.25"); AssertFormat("alpha"_ == int32_t(0), "alpha=0"); AssertFormat("beta"_ == 3.25f, "alpha/beta=3.25"); - AssertFormat(scalar(true), ""); + AssertFormat(literal(true), ""); - AssertFormatError("alpha"_ == "yo" and "beta"_ == 3.25f); + ASSERT_OK_AND_ASSIGN(written_schema_, + written_schema_->AddField(0, field("gamma", utf8()))); AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, "alpha=0/beta=3.25"); + + // written_schema_ is incompatible with partitioning_'s schema + written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); + AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); } TEST_F(TestPartitioning, DiscoverHiveSchema) { @@ -322,10 +345,12 @@ TEST_F(TestPartitioning, EtlThenHive) { FieldVector alphabeta_fields{field("alpha", int32()), field("beta", float32())}; HivePartitioning alphabeta_part(schema(alphabeta_fields)); - partitioning_ = std::make_shared( + auto schm = schema({field("year", int16()), field("month", int8()), field("day", int8()), - field("hour", int8()), field("alpha", int32()), field("beta", float32())}), - [&](const std::string& path) -> Result> { + field("hour", int8()), field("alpha", int32()), field("beta", float32())}); + + partitioning_ = std::make_shared( + schm, [&](const std::string& path) -> Result { auto segments = fs::internal::SplitAbstractPath(path); if (segments.size() < etl_fields.size() + alphabeta_fields.size()) { return Status::Invalid("path ", path, " can't be parsed"); @@ -341,7 +366,9 @@ TEST_F(TestPartitioning, EtlThenHive) { fs::internal::JoinAbstractPath(etl_segments_end, alphabeta_segments_end); ARROW_ASSIGN_OR_RAISE(auto alphabeta_expr, alphabeta_part.Parse(alphabeta_path)); - return and_(etl_expr, alphabeta_expr); + auto expr = and_(etl_expr, alphabeta_expr); + ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), expr.Bind(*schm)); + return expr; }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", @@ -350,7 +377,7 @@ TEST_F(TestPartitioning, EtlThenHive) { ("alpha"_ == int32_t(0) and "beta"_ == 3.25f)); AssertParseError("/20X6/03/21/05/alpha=0/beta=3.25"); -} // namespace dataset +} TEST_F(TestPartitioning, Set) { auto ints = [](std::vector ints) { @@ -359,12 +386,13 @@ TEST_F(TestPartitioning, Set) { return out; }; + auto schm = schema({field("x", int32())}); + // An adhoc partitioning which parses segments like "/x in [1 4 5]" // into ("x"_ == 1 or "x"_ == 4 or "x"_ == 5) partitioning_ = std::make_shared( - schema({field("x", int32())}), - [&](const std::string& path) -> Result> { - ExpressionVector subexpressions; + schm, 
[&](const std::string& path) -> Result { + std::vector subexpressions; for (auto segment : fs::internal::SplitAbstractPath(path)) { std::smatch matches; @@ -380,7 +408,13 @@ TEST_F(TestPartitioning, Set) { set.push_back(checked_cast(*s).value); } - subexpressions.push_back(field_ref(matches[1])->In(ints(set)).Copy()); + auto is_in_expr = call("is_in", {field_ref(matches[1])}, + compute::SetLookupOptions{ints(set), true}); + + ARROW_ASSIGN_OR_RAISE(std::tie(is_in_expr, std::ignore), + is_in_expr.Bind(*schm)); + + subexpressions.push_back(is_in_expr); } return and_(std::move(subexpressions)); }); @@ -398,8 +432,8 @@ class RangePartitioning : public Partitioning { std::string type_name() const override { return "range"; } - Result> Parse(const std::string& path) const override { - ExpressionVector ranges; + Result Parse(const std::string& path) const override { + std::vector ranges; for (auto segment : fs::internal::SplitAbstractPath(path)) { auto key = HivePartitioning::ParseKey(segment); @@ -419,10 +453,13 @@ class RangePartitioning : public Partitioning { ARROW_ASSIGN_OR_RAISE(auto min, Scalar::Parse(type, min_repr)); ARROW_ASSIGN_OR_RAISE(auto max, Scalar::Parse(type, max_repr)); - ranges.push_back(and_(min_cmp(field_ref(key->name), scalar(min)), - max_cmp(field_ref(key->name), scalar(max)))); + ranges.push_back(and_(min_cmp(field_ref(key->name), literal(min)), + max_cmp(field_ref(key->name), literal(max)))); } - return and_(ranges); + + Expression2 expr; + ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), and_(ranges).Bind(*schema_)); + return expr; } static Status DoRegex(const std::string& segment, std::smatch* matches) { @@ -442,7 +479,7 @@ class RangePartitioning : public Partitioning { return Status::OK(); } - Result Format(const Expression&) const override { return ""; } + Result Format(const Expression2&) const override { return ""; } Result Partition( const std::shared_ptr&) const override { return Status::OK(); diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 019416041aa..67df665733b 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -27,6 +27,7 @@ #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" #include "arrow/util/iterator.h" +#include "arrow/util/logging.h" #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" @@ -34,14 +35,12 @@ namespace arrow { namespace dataset { ScanOptions::ScanOptions(std::shared_ptr schema) - : evaluator(ExpressionEvaluator::Null()), - projector(RecordBatchProjector(std::move(schema))) {} + : projector(RecordBatchProjector(std::move(schema))) {} std::shared_ptr ScanOptions::ReplaceSchema( std::shared_ptr schema) const { auto copy = ScanOptions::Make(std::move(schema)); - copy->filter = filter; - copy->evaluator = evaluator; + copy->filter2 = filter2; copy->batch_size = batch_size; return copy; } @@ -53,8 +52,9 @@ std::vector ScanOptions::MaterializedFields() const { fields.push_back(f->name()); } - for (auto&& name : FieldsInExpression(filter)) { - fields.push_back(std::move(name)); + for (const FieldRef& ref : FieldsInExpression(filter2)) { + DCHECK(ref.name()); + fields.push_back(*ref.name()); } return fields; @@ -64,7 +64,7 @@ Result InMemoryScanTask::Execute() { return MakeVectorIterator(record_batches_); } -FragmentIterator Scanner::GetFragments() { +Result Scanner::GetFragments() { if (fragment_ != nullptr) { return MakeVectorIterator(FragmentVector{fragment_}); } @@ -72,14 +72,16 @@ FragmentIterator Scanner::GetFragments() { // 
Transform Datasets in a flat Iterator<Fragment>. This
   // iterator is lazily constructed, i.e. Dataset::GetFragments is
   // not invoked until a Fragment is requested.
-  return GetFragmentsFromDatasets({dataset_}, scan_options_->filter);
+  return GetFragmentsFromDatasets(
+      {dataset_}, {scan_options_->filter2, scan_context_->expression_state});
 }
 
 Result<ScanTaskIterator> Scanner::Scan() {
   // Transforms Iterator<Fragment> into a unified
   // Iterator<ScanTask>. The first Iterator::Next invocation is going to do
   // all the work of unwinding the chained iterators.
-  return GetScanTaskIterator(GetFragments(), scan_options_, scan_context_);
+  ARROW_ASSIGN_OR_RAISE(auto fragment_it, GetFragments());
+  return GetScanTaskIterator(std::move(fragment_it), scan_options_, scan_context_);
 }
 
 Result<ScanTaskIterator> ScanTaskIteratorFromRecordBatch(
@@ -95,15 +97,24 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Dataset> dataset,
     : dataset_(std::move(dataset)),
       fragment_(nullptr),
       scan_options_(ScanOptions::Make(dataset_->schema())),
-      scan_context_(std::move(scan_context)) {}
+      scan_context_(std::move(scan_context)) {
+  DCHECK_OK(Filter(literal(true)));
+}
 
 ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
                                std::shared_ptr<Fragment> fragment,
                                std::shared_ptr<ScanContext> scan_context)
     : dataset_(nullptr),
       fragment_(std::move(fragment)),
-      scan_options_(ScanOptions::Make(schema)),
-      scan_context_(std::move(scan_context)) {}
+      fragment_schema_(schema),
+      scan_options_(ScanOptions::Make(std::move(schema))),
+      scan_context_(std::move(scan_context)) {
+  DCHECK_OK(Filter(literal(true)));
+}
+
+const std::shared_ptr<Schema>& ScannerBuilder::schema() const {
+  return fragment_ ? fragment_schema_ : dataset_->schema();
+}
 
 Status ScannerBuilder::Project(std::vector<std::string> columns) {
   RETURN_NOT_OK(schema()->CanReferenceFieldsByNames(columns));
@@ -112,15 +123,12 @@ Status ScannerBuilder::Project(std::vector<std::string> columns) {
   return Status::OK();
 }
 
-Status ScannerBuilder::Filter(std::shared_ptr<Expression> filter) {
-  RETURN_NOT_OK(schema()->CanReferenceFieldsByNames(FieldsInExpression(*filter)));
-  RETURN_NOT_OK(filter->Validate(*schema()));
-  scan_options_->filter = std::move(filter);
+Status ScannerBuilder::Filter(const Expression2& filter) {
+  ARROW_ASSIGN_OR_RAISE(std::tie(scan_options_->filter2, scan_context_->expression_state),
+                        filter.Bind(*schema()));
   return Status::OK();
 }
 
-Status ScannerBuilder::Filter(const Expression& filter) { return Filter(filter.Copy()); }
-
 Status ScannerBuilder::UseThreads(bool use_threads) {
   scan_context_->use_threads = use_threads;
   return Status::OK();
@@ -143,10 +151,6 @@ Result<std::shared_ptr<Scanner>> ScannerBuilder::Finish() const {
     scan_options = std::make_shared<ScanOptions>(*scan_options_);
   }
 
-  if (!scan_options->filter->Equals(true)) {
-    scan_options->evaluator = std::make_shared<TreeEvaluator>();
-  }
-
   if (dataset_ == nullptr) {
     return std::make_shared<Scanner>(fragment_, std::move(scan_options), scan_context_);
   }
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 950d416f615..27c2e3e9ed7 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "arrow/dataset/dataset.h"
+#include "arrow/dataset/expression.h"
 #include "arrow/dataset/projector.h"
 #include "arrow/dataset/type_fwd.h"
 #include "arrow/dataset/visibility.h"
@@ -47,6 +48,8 @@ struct ARROW_DS_EXPORT ScanContext {
 
   /// Return a threaded or serial TaskGroup according to use_threads.
std::shared_ptr TaskGroup() const; + + std::shared_ptr expression_state; }; class ARROW_DS_EXPORT ScanOptions { @@ -62,10 +65,7 @@ class ARROW_DS_EXPORT ScanOptions { std::shared_ptr ReplaceSchema(std::shared_ptr schema) const; // Filter - std::shared_ptr filter = scalar(true); - - // Evaluator for Filter - std::shared_ptr evaluator; + Expression2 filter2 = literal(true); // Schema to which record batches will be reconciled const std::shared_ptr& schema() const { return projector.schema(); } @@ -172,7 +172,7 @@ class ARROW_DS_EXPORT Scanner { Result> ToTable(); /// \brief GetFragments returns an iterator over all Fragments in this scan. - FragmentIterator GetFragments(); + Result GetFragments(); const std::shared_ptr& schema() const { return scan_options_->schema(); } @@ -222,8 +222,7 @@ class ARROW_DS_EXPORT ScannerBuilder { /// /// \return Failure if any referenced columns does not exist in the dataset's /// Schema. - Status Filter(std::shared_ptr filter); - Status Filter(const Expression& filter); + Status Filter(const Expression2& filter); /// \brief Indicate if the Scanner should make use of the available /// ThreadPool found in ScanContext; @@ -240,11 +239,12 @@ class ARROW_DS_EXPORT ScannerBuilder { /// \brief Return the constructed now-immutable Scanner object Result> Finish() const; - std::shared_ptr schema() const { return scan_options_->schema(); } + const std::shared_ptr& schema() const; private: std::shared_ptr dataset_; std::shared_ptr fragment_; + std::shared_ptr fragment_schema_; std::shared_ptr scan_options_; std::shared_ptr scan_context_; bool has_projection_ = false; diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index 94df94470fa..f01ced8396b 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -20,6 +20,9 @@ #include #include +#include "arrow/array/array_nested.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" @@ -29,13 +32,27 @@ namespace arrow { namespace dataset { inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, - const ExpressionEvaluator& evaluator, - const Expression& filter, MemoryPool* pool) { + Expression2::BoundWithState filter, + MemoryPool* pool) { return MakeMaybeMapIterator( - [&filter, &evaluator, pool](std::shared_ptr in) { - return evaluator.Evaluate(filter, *in, pool).Map([&](Datum selection) { - return evaluator.Filter(selection, in); - }); + [=](std::shared_ptr in) -> Result> { + compute::ExecContext exec_context{pool}; + ARROW_ASSIGN_OR_RAISE(Datum mask, + ExecuteScalarExpression(filter.first, filter.second.get(), + Datum(in), &exec_context)); + + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + return std::move(in); + } + return in->Slice(0, 0); + } + + ARROW_ASSIGN_OR_RAISE( + Datum filtered, + compute::Filter(in, mask, compute::FilterOptions::Defaults(), &exec_context)); + return filtered.record_batch(); }, std::move(it)); } @@ -56,39 +73,41 @@ inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, class FilterAndProjectScanTask : public ScanTask { public: - explicit FilterAndProjectScanTask(std::shared_ptr task, - std::shared_ptr partition) + explicit FilterAndProjectScanTask(std::shared_ptr task, Expression2 partition) : ScanTask(task->options(), task->context()), task_(std::move(task)), 
partition_(std::move(partition)),
-        filter_(options()->filter->Assume(partition_)),
+        filter_(options()->filter2),
         projector_(options()->projector) {}
 
   Result<RecordBatchIterator> Execute() override {
     ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute());
 
-    auto filter_it =
-        FilterRecordBatch(std::move(it), *options_->evaluator, *filter_, context_->pool);
+    ARROW_ASSIGN_OR_RAISE(
+        Expression2::BoundWithState simplified_filter,
+        SimplifyWithGuarantee({filter_, context_->expression_state}, partition_));
+
+    RecordBatchIterator filter_it =
+        FilterRecordBatch(std::move(it), simplified_filter, context_->pool);
+
+    RETURN_NOT_OK(
+        KeyValuePartitioning::SetDefaultValuesFromKeys(partition_, &projector_));
 
-    if (partition_) {
-      RETURN_NOT_OK(
-          KeyValuePartitioning::SetDefaultValuesFromKeys(*partition_, &projector_));
-    }
     return ProjectRecordBatch(std::move(filter_it), &projector_, context_->pool);
   }
 
  private:
   std::shared_ptr<ScanTask> task_;
-  std::shared_ptr<Expression> partition_;
-  std::shared_ptr<Expression> filter_;
+  Expression2 partition_;
+  Expression2 filter_;
   RecordBatchProjector projector_;
 };
 
 /// \brief GetScanTaskIterator transforms an Iterator<Fragment> into a
 /// flattened Iterator<ScanTask>.
-inline ScanTaskIterator GetScanTaskIterator(FragmentIterator fragments,
-                                            std::shared_ptr<ScanOptions> options,
-                                            std::shared_ptr<ScanContext> context) {
+inline Result<ScanTaskIterator> GetScanTaskIterator(
+    FragmentIterator fragments, std::shared_ptr<ScanOptions> options,
+    std::shared_ptr<ScanContext> context) {
   // Fragment -> ScanTaskIterator
   auto fn = [options, context](std::shared_ptr<Fragment> fragment)
       -> Result<ScanTaskIterator> {
@@ -112,42 +131,5 @@ inline ScanTaskIterator GetScanTaskIterator(FragmentIterator fragments,
   return MakeFlattenIterator(std::move(maybe_scantask_it));
 }
 
-struct FragmentRecordBatchReader : RecordBatchReader {
- public:
-  std::shared_ptr<Schema> schema() const override { return options_->schema(); }
-
-  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
-    return iterator_.Next().Value(batch);
-  }
-
-  static Result<std::shared_ptr<RecordBatchReader>> Make(
-      std::shared_ptr<Fragment> fragment, std::shared_ptr<Schema> schema,
-      std::shared_ptr<ScanContext> context) {
-    // ensure schema is cached in fragment
-    auto options = ScanOptions::Make(std::move(schema));
-    RETURN_NOT_OK(KeyValuePartitioning::SetDefaultValuesFromKeys(
-        *fragment->partition_expression(), &options->projector));
-
-    auto pool = context->pool;
-    ARROW_ASSIGN_OR_RAISE(auto scan_tasks, fragment->Scan(options, std::move(context)));
-
-    auto reader = std::make_shared<FragmentRecordBatchReader>();
-    reader->options_ = std::move(options);
-    reader->fragment_ = std::move(fragment);
-    reader->iterator_ = ProjectRecordBatch(
-        MakeFlattenIterator(MakeMaybeMapIterator(
-            [](std::shared_ptr<ScanTask> task) { return task->Execute(); },
-            std::move(scan_tasks))),
-        &reader->options_->projector, pool);
-
-    return reader;
-  }
-
- private:
-  std::shared_ptr<ScanOptions> options_;
-  std::shared_ptr<Fragment> fragment_;
-  RecordBatchIterator iterator_;
-};
-
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 36966760346..d79cc49faac 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -55,7 +55,7 @@ class TestScanner : public DatasetFixtureMixin {
     // structures of the scanner, i.e.
Scanner[Dataset[ScanTask[RecordBatch]]] AssertScannerEquals(expected.get(), &scanner); } -}; // namespace dataset +}; constexpr int64_t TestScanner::kNumberChildDatasets; constexpr int64_t TestScanner::kNumberBatches; @@ -88,8 +88,7 @@ TEST_F(TestScanner, FilteredScan) { value += 1.0; })); - options_->filter = ("f64"_ > 0.0).Copy(); - options_->evaluator = std::make_shared(); + SetFilter(greater(field_ref("f64"), literal(0.0))); auto batch = RecordBatch::Make(schema_, f64->length(), {f64}); @@ -148,7 +147,7 @@ TEST_F(TestScanner, ToTable) { } class TestScannerBuilder : public ::testing::Test { - void SetUp() { + void SetUp() override { DatasetVector sources; schema_ = schema({ @@ -163,7 +162,7 @@ class TestScannerBuilder : public ::testing::Test { } protected: - std::shared_ptr ctx_; + std::shared_ptr ctx_ = std::make_shared(); std::shared_ptr schema_; std::shared_ptr dataset_; }; @@ -184,14 +183,28 @@ TEST_F(TestScannerBuilder, TestProject) { TEST_F(TestScannerBuilder, TestFilter) { ScannerBuilder builder(dataset_, ctx_); - ASSERT_OK(builder.Filter(scalar(true))); - ASSERT_OK(builder.Filter("i64"_ == int64_t(10))); - ASSERT_OK(builder.Filter("i64"_ == int64_t(10) || "b"_ == true)); - - ASSERT_RAISES(TypeError, builder.Filter("i64"_ == int32_t(10))); - ASSERT_RAISES(Invalid, builder.Filter("not_a_column"_ == true)); - ASSERT_RAISES(Invalid, - builder.Filter("i64"_ == int64_t(10) || "not_a_column"_ == true)); + ASSERT_OK(builder.Filter(literal(true))); + ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); + ASSERT_OK(builder.Filter( + call("or_kleene", { + call("equal", {field_ref("i64"), literal(10)}), + call("equal", {field_ref("b"), literal(true)}), + }))); + + ASSERT_RAISES(NotImplemented, + builder.Filter(call("equal", {field_ref("i64"), literal(10)}))); + + ASSERT_RAISES( + NotImplemented, + builder.Filter(call("equal", {field_ref("not_a_column"), literal(10)}))); + + ASSERT_RAISES( + NotImplemented, + builder.Filter( + call("or_kleene", { + call("equal", {field_ref("i64"), literal(10)}), + call("equal", {field_ref("not_a_column"), literal(true)}), + }))); } using testing::ElementsAre; @@ -204,7 +217,7 @@ TEST(ScanOptions, TestMaterializedFields) { auto opts = ScanOptions::Make(schema({})); EXPECT_THAT(opts->MaterializedFields(), IsEmpty()); - opts->filter = ("i32"_ == 10).Copy(); + opts->filter2 = call("equal", {field_ref("i32"), literal(10)}); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); opts = ScanOptions::Make(schema({i32, i64})); @@ -213,10 +226,10 @@ TEST(ScanOptions, TestMaterializedFields) { opts = opts->ReplaceSchema(schema({i32})); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); - opts->filter = ("i32"_ == 10).Copy(); + opts->filter2 = call("equal", {field_ref("i32"), literal(10)}); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32")); - opts->filter = ("i64"_ == 10).Copy(); + opts->filter2 = call("equal", {field_ref("i64"), literal(10)}); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64")); } diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 9cfba41a624..70c9b7548e8 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -50,6 +50,43 @@ namespace arrow { namespace dataset { +const std::shared_ptr kBoringSchema = schema({ + field("i32", int32()), + field("i32_req", int32(), /*nullable=*/false), + field("f32", float32()), + field("f32_req", float32(), /*nullable=*/false), + field("bool", boolean()), +}); + +inline namespace 
string_literals {
+// clang-format off
+inline FieldExpression operator"" _(const char* name, size_t name_length) {
+  // clang-format on
+  return FieldExpression({name, name_length});
+}
+}  // namespace string_literals
+
+#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP)                                   \
+  template <typename T, typename Enable = typename std::enable_if<!std::is_base_of<  \
+                            Expression, typename std::decay<T>::type>::value>::type> \
+  ComparisonExpression operator OP(const Expression& lhs, T&& rhs) {                 \
+    return ComparisonExpression(CompareOperator::NAME, lhs.Copy(),                   \
+                                scalar(std::forward<T>(rhs)));                       \
+  }                                                                                  \
+                                                                                     \
+  inline ComparisonExpression operator OP(const Expression& lhs,                     \
+                                          const Expression& rhs) {                   \
+    return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), rhs.Copy());      \
+  }
+
+COMPARISON_FACTORY(EQUAL, equal, ==)
+COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=)
+COMPARISON_FACTORY(GREATER, greater, >)
+COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=)
+COMPARISON_FACTORY(LESS, less, <)
+COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=)
+#undef COMPARISON_FACTORY
+
 using fs::internal::GetAbstractPathExtension;
 using internal::checked_cast;
 using internal::checked_pointer_cast;
@@ -66,7 +103,7 @@ template <typename Gen>
 class GeneratedRecordBatch : public RecordBatchReader {
  public:
   GeneratedRecordBatch(std::shared_ptr<Schema> schema, Gen gen)
-      : schema_(schema), gen_(gen) {}
+      : schema_(std::move(schema)), gen_(gen) {}
 
   std::shared_ptr<Schema> schema() const override { return schema_; }
 
@@ -101,8 +138,6 @@ void EnsureRecordBatchReaderDrained(RecordBatchReader* reader) {
 
 class DatasetFixtureMixin : public ::testing::Test {
  public:
-  DatasetFixtureMixin() : ctx_(std::make_shared<ScanContext>()) {}
-
   /// \brief Ensure that record batches found in reader are equal to the
   /// record batches yielded by the data fragment.
   void AssertScanTaskEquals(RecordBatchReader* expected, ScanTask* task,
@@ -141,7 +176,8 @@ class DatasetFixtureMixin : public ::testing::Test {
   /// record batches yielded by the data fragments of a dataset.
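+  ///
+  /// (The fixture's current filter is bound against the dataset's schema and
+  /// pushed down to Dataset::GetFragments as the predicate; see the body below.)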
void AssertDatasetFragmentsEqual(RecordBatchReader* expected, Dataset* dataset, bool ensure_drained = true) { - auto it = dataset->GetFragments(options_->filter); + ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter2.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(auto it, dataset->GetFragments(predicate)); ARROW_EXPECT_OK(it.Visit([&](std::shared_ptr fragment) -> Status { AssertFragmentEquals(expected, fragment.get(), false); @@ -186,11 +222,17 @@ class DatasetFixtureMixin : public ::testing::Test { void SetSchema(std::vector> fields) { schema_ = schema(std::move(fields)); options_ = ScanOptions::Make(schema_); + SetFilter(literal(true)); + } + + void SetFilter(Expression2 filter) { + ASSERT_OK_AND_ASSIGN(std::tie(options_->filter2, ctx_->expression_state), + filter.Bind(*schema_)); } std::shared_ptr schema_; std::shared_ptr options_; - std::shared_ptr ctx_; + std::shared_ptr ctx_ = std::make_shared(); }; /// \brief A dummy FileFormat implementation @@ -321,15 +363,14 @@ struct MakeFileSystemDatasetMixin { } void MakeDataset(const std::vector& infos, - std::shared_ptr root_partition = scalar(true), - ExpressionVector partitions = {}) { + Expression2 root_partition = literal(true), + std::vector partitions = {}, + std::shared_ptr s = kBoringSchema) { auto n_fragments = infos.size(); if (partitions.empty()) { - partitions.resize(n_fragments, scalar(true)); + partitions.resize(n_fragments, literal(true)); } - auto s = schema({}); - MakeFileSystem(infos); auto format = std::make_shared(s); @@ -340,21 +381,17 @@ struct MakeFileSystemDatasetMixin { continue; } + ASSERT_OK_AND_ASSIGN(std::tie(partitions[i], std::ignore), partitions[i].Bind(*s)); ASSERT_OK_AND_ASSIGN(auto fragment, format->MakeFragment({info, fs_}, partitions[i])); fragments.push_back(std::move(fragment)); } + ASSERT_OK_AND_ASSIGN(std::tie(root_partition, std::ignore), root_partition.Bind(*s)); ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format, fs_, std::move(fragments))); } - void MakeDatasetFromPathlist(const std::string& pathlist, - std::shared_ptr root_partition = scalar(true), - ExpressionVector partitions = {}) { - MakeDataset(ParsePathList(pathlist), root_partition, partitions); - } - std::shared_ptr fs_; std::shared_ptr dataset_; std::shared_ptr options_ = ScanOptions::Make(schema({})); @@ -387,49 +424,24 @@ void AssertFragmentsAreFromPath(FragmentIterator it, std::vector ex testing::UnorderedElementsAreArray(expected)); } -// A frozen shared_ptr with behavior expected by GTest -struct TestExpression : util::EqualityComparable, - util::ToStringOstreamable { - // NOLINTNEXTLINE runtime/explicit - TestExpression(std::shared_ptr e) : expression(std::move(e)) {} - - // NOLINTNEXTLINE runtime/explicit - TestExpression(const Expression& e) : expression(e.Copy()) {} - - std::shared_ptr expression; - - using util::EqualityComparable::operator==; - bool Equals(const TestExpression& other) const { - return expression->Equals(other.expression); - } - - std::string ToString() const { return expression->ToString(); } - - friend bool operator==(const std::shared_ptr& lhs, - const TestExpression& rhs) { - return TestExpression(lhs) == rhs; - } - - friend void PrintTo(const TestExpression& expr, std::ostream* os) { - *os << expr.ToString(); - } -}; - -static std::vector PartitionExpressionsOf( - const FragmentVector& fragments) { - std::vector partition_expressions; +static std::vector PartitionExpressionsOf(const FragmentVector& fragments) { + std::vector partition_expressions; 
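+  // Expression2 is equality-comparable and printable directly, so the old
+  // TestExpression wrapper is no longer needed for gtest matching.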
std::transform(fragments.begin(), fragments.end(), std::back_inserter(partition_expressions), [](const std::shared_ptr& fragment) { - return TestExpression(fragment->partition_expression()); + return fragment->partition_expression(); }); return partition_expressions; } -void AssertFragmentsHavePartitionExpressions(FragmentIterator it, - ExpressionVector expected) { +void AssertFragmentsHavePartitionExpressions(std::shared_ptr dataset, + std::vector expected) { + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); + for (auto& expr : expected) { + ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*dataset->schema())); + } // Ordering is not guaranteed. - EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(it))), + EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(fragment_it))), testing::UnorderedElementsAreArray(expected)); } @@ -456,7 +468,7 @@ struct ArithmeticDatasetFixture { /// 2}, static std::string JSONRecordFor(int64_t n) { std::stringstream ss; - int32_t n_i32 = static_cast(n); + auto n_i32 = static_cast(n); ss << "{"; ss << "\"i64\": " << n << ", "; @@ -579,6 +591,9 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish()); scan_options_ = ScanOptions::Make(source_schema_); + ASSERT_OK_AND_ASSIGN( + std::tie(scan_options_->filter2, scan_context_->expression_state), + literal(true).Bind(*source_schema_)); } void SetWriteOptions(std::shared_ptr file_write_options) { @@ -741,7 +756,8 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { } EXPECT_THAT(actual_paths, testing::UnorderedElementsAreArray(expected_paths)); - for (auto maybe_fragment : written_->GetFragments()) { + ASSERT_OK_AND_ASSIGN(auto written_fragments_it, written_->GetFragments()); + for (auto maybe_fragment : written_fragments_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); ASSERT_OK_AND_ASSIGN(auto actual_physical_schema, fragment->ReadPhysicalSchema()); diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 522def49818..e4bd322ba10 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -69,12 +69,6 @@ class ParquetFileWriter; class ParquetFileWriteOptions; class Expression; -using ExpressionVector = std::vector>; -class ExpressionEvaluator; - -/// forward declared to facilitate scalar(true) as a default for Expression parameters -ARROW_DS_EXPORT -const std::shared_ptr& scalar(bool); struct ExpressionState; class Expression2; diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index c1666adc4c5..0f2fa8ebe90 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -70,6 +70,7 @@ Datum::Datum(float value) : value(std::make_shared(value)) {} Datum::Datum(double value) : value(std::make_shared(value)) {} Datum::Datum(std::string value) : value(std::make_shared(std::move(value))) {} +Datum::Datum(const char* value) : value(std::make_shared(value)) {} Datum::Datum(const ChunkedArray& value) : value(std::make_shared(value.chunks(), value.type())) {} @@ -198,6 +199,8 @@ static std::string FormatValueDescr(const ValueDescr& descr) { std::string ValueDescr::ToString() const { return FormatValueDescr(*this); } +void PrintTo(const ValueDescr& descr, std::ostream* os) { *os << descr.ToString(); } + std::string Datum::ToString() const { switch (this->kind()) { case Datum::NONE: diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 5f80551f337..9c620279c9e 100644 --- a/cpp/src/arrow/datum.h +++ 
b/cpp/src/arrow/datum.h @@ -89,6 +89,8 @@ struct ARROW_EXPORT ValueDescr { bool operator!=(const ValueDescr& other) const { return !(*this == other); } std::string ToString() const; + + ARROW_EXPORT friend void PrintTo(const ValueDescr&, std::ostream*); }; /// \brief For use with scalar functions, returns the broadcasted Value::Shape @@ -161,6 +163,7 @@ struct ARROW_EXPORT Datum { explicit Datum(float value); explicit Datum(double value); explicit Datum(std::string value); + explicit Datum(const char* value); Datum::Kind kind() const { switch (this->value.index()) { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 5604c5e2717..20dc0f14ca0 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1501,7 +1501,7 @@ std::shared_ptr Schema::WithMetadata( return std::make_shared(impl_->fields_, metadata); } -std::shared_ptr Schema::metadata() const { +const std::shared_ptr& Schema::metadata() const { return impl_->metadata_; } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 3594c248e10..cadb819b9f2 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1690,7 +1690,7 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable, /// \brief The custom key-value metadata, if any /// /// \return metadata may be null - std::shared_ptr metadata() const; + const std::shared_ptr& metadata() const; /// \brief Render a string representation of the schema suitable for debugging /// \param[in] show_metadata when true, if KeyValueMetadata is non-empty, diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h index 4713976d485..89f39ab8917 100644 --- a/cpp/src/arrow/util/variant.h +++ b/cpp/src/arrow/util/variant.h @@ -389,6 +389,22 @@ U& get(Variant& v) { return *v.template get(); } +/// \brief Get a const pointer to the value held by the variant +/// +/// If the type given as template argument doesn't match, a nullptr is returned. +template +const U* get_if(const Variant* v) { + return v->template get(); +} + +/// \brief Get a pointer to the value held by the variant +/// +/// If the type given as template argument doesn't match, a nullptr is returned. 
+template +U* get_if(Variant* v) { + return v->template get(); +} + namespace detail { template From 67dcd0eab64242a18e8fcf9a16bf44ad82f09e8e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 9 Dec 2020 13:32:16 -0500 Subject: [PATCH 03/38] remove ExpressionState --- cpp/src/arrow/dataset/dataset.cc | 11 +- cpp/src/arrow/dataset/dataset.h | 9 +- cpp/src/arrow/dataset/dataset_internal.h | 4 +- cpp/src/arrow/dataset/discovery_test.cc | 2 +- cpp/src/arrow/dataset/expression.cc | 170 ++++++-------------- cpp/src/arrow/dataset/expression.h | 37 +++-- cpp/src/arrow/dataset/expression_internal.h | 43 ++--- cpp/src/arrow/dataset/expression_test.cc | 104 +++++------- cpp/src/arrow/dataset/file_base.cc | 9 +- cpp/src/arrow/dataset/file_base.h | 2 +- cpp/src/arrow/dataset/file_parquet.cc | 25 ++- cpp/src/arrow/dataset/file_parquet.h | 6 +- cpp/src/arrow/dataset/file_parquet_test.cc | 4 +- cpp/src/arrow/dataset/partition.cc | 4 +- cpp/src/arrow/dataset/partition_test.cc | 20 +-- cpp/src/arrow/dataset/scanner.cc | 6 +- cpp/src/arrow/dataset/scanner.h | 2 - cpp/src/arrow/dataset/scanner_internal.h | 11 +- cpp/src/arrow/dataset/test_util.h | 12 +- cpp/src/arrow/dataset/type_fwd.h | 2 - 20 files changed, 173 insertions(+), 310 deletions(-) diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 56c901f959d..5da90e48ebc 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -111,11 +111,11 @@ Result Dataset::GetFragments() { return GetFragments(std::move(predicate)); } -Result Dataset::GetFragments(Expression2::BoundWithState predicate) { +Result Dataset::GetFragments(Expression2 predicate) { ARROW_ASSIGN_OR_RAISE( predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); - return predicate.first.IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) - : MakeEmptyIterator>(); + return predicate.IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) + : MakeEmptyIterator>(); } struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator { @@ -155,7 +155,7 @@ Result> InMemoryDataset::ReplaceSchema( return std::make_shared(std::move(schema), get_batches_); } -Result InMemoryDataset::GetFragmentsImpl(Expression2::BoundWithState) { +Result InMemoryDataset::GetFragmentsImpl(Expression2) { auto schema = this->schema(); auto create_fragment = @@ -196,8 +196,7 @@ Result> UnionDataset::ReplaceSchema( new UnionDataset(std::move(schema), std::move(children))); } -Result UnionDataset::GetFragmentsImpl( - Expression2::BoundWithState predicate) { +Result UnionDataset::GetFragmentsImpl(Expression2 predicate) { return GetFragmentsFromDatasets(children_, predicate); } diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 35c91b79fed..88782831670 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -122,7 +122,7 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. 
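+  ///
+  /// Caller-side sketch (hypothetical field name "x"; not part of this diff):
+  ///
+  ///   ARROW_ASSIGN_OR_RAISE(auto predicate, greater(field_ref("x"), literal(0))
+  ///                                             .Bind(*dataset->schema()));
+  ///   ARROW_ASSIGN_OR_RAISE(auto fragments, dataset->GetFragments(predicate));
+  ///   // fragments whose partition expression contradicts the predicate are
+  ///   // skipped entirely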
- Result GetFragments(Expression2::BoundWithState predicate); + Result GetFragments(Expression2 predicate); Result GetFragments(); const std::shared_ptr& schema() const { return schema_; } @@ -148,8 +148,7 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Dataset(std::shared_ptr schema, Expression2 partition_expression); - virtual Result GetFragmentsImpl( - Expression2::BoundWithState predicate) = 0; + virtual Result GetFragmentsImpl(Expression2 predicate) = 0; std::shared_ptr schema_; Expression2 partition_expression_ = literal(true); @@ -182,7 +181,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2::BoundWithState predicate) override; + Expression2 predicate) override; std::shared_ptr get_batches_; }; @@ -207,7 +206,7 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2::BoundWithState predicate) override; + Expression2 predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index a6ee8a117bc..9e070192cd0 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -35,8 +35,8 @@ namespace dataset { /// \brief GetFragmentsFromDatasets transforms a vector into a /// flattened FragmentIterator. -inline Result GetFragmentsFromDatasets( - const DatasetVector& datasets, Expression2::BoundWithState predicate) { +inline Result GetFragmentsFromDatasets(const DatasetVector& datasets, + Expression2 predicate) { // Iterator auto datasets_it = MakeVectorIterator(datasets); diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index cd0ba2bee4d..219ee070fdc 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -375,7 +375,7 @@ TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) { ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); for (const auto& maybe_fragment : fragment_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); - EXPECT_EQ(fragment->partition_expression(), expected.first); + EXPECT_EQ(fragment->partition_expression(), expected); } } diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index c762cb89e0f..8f9a3ce5663 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -22,7 +22,6 @@ #include "arrow/chunked_array.h" #include "arrow/compute/exec_internal.h" -#include "arrow/compute/registry.h" #include "arrow/dataset/expression_internal.h" #include "arrow/dataset/filter.h" #include "arrow/io/memory.h" @@ -379,67 +378,29 @@ Result> InitKernelState( return std::move(kernel_state); } -Result> CloneExpressionState( - const ExpressionState& state, compute::ExecContext* exec_context) { - auto clone = std::make_shared(); - clone->kernel_states.reserve(state.kernel_states.size()); - - for (const auto& sub_state : state.kernel_states) { - auto call = CallNotNull(sub_state.first); - - if (KernelStateIsImmutable(call->function)) { - // The kernel's state is immutable so it's safe to just - // share a pointer between threads - clone->kernel_states.insert(sub_state); - continue; - } - - // The kernel's state must be re-initialized. 
- ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(*call, exec_context)); - clone->kernel_states.emplace(sub_state.first, std::move(kernel_state)); - } - - return clone; -} - -Result Expression2::Bind( - ValueDescr in, compute::ExecContext* exec_context) const { +Result Expression2::Bind(ValueDescr in, + compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; return Bind(std::move(in), &exec_context); } - BoundWithState ret{*this, std::make_shared()}; - - if (literal()) return ret; + if (literal()) return *this; if (auto ref = field_ref()) { ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); - ret.first.descr_ = - field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return ret; + auto out = *this; + out.descr_ = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); + return out; } auto bound_call = *CallNotNull(*this); - std::shared_ptr function; - if (bound_call.function == "cast") { - // XXX this special case is strange; why not make "cast" a ScalarFunction? - const auto& to_type = - checked_cast(*bound_call.options).to_type; - ARROW_ASSIGN_OR_RAISE(function, compute::GetCastFunction(to_type)); - } else { - ARROW_ASSIGN_OR_RAISE( - function, exec_context->func_registry()->GetFunction(bound_call.function)); - } + ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(bound_call, exec_context)); bound_call.function_kind = function->kind(); - auto state = std::make_shared(); - for (size_t i = 0; i < bound_call.arguments.size(); ++i) { - std::shared_ptr argument_state; - ARROW_ASSIGN_OR_RAISE(std::tie(bound_call.arguments[i], argument_state), - bound_call.arguments[i].Bind(in, exec_context)); - state->MoveFrom(argument_state.get()); + for (auto&& argument : bound_call.arguments) { + ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); } if (RequriesDictionaryTransparency(bound_call)) { @@ -447,7 +408,7 @@ Result Expression2::Bind( } auto descrs = GetDescriptors(bound_call.arguments); - for (auto& descr : descrs) { + for (auto&& descr : descrs) { if (RequriesDictionaryTransparency(bound_call)) { RETURN_NOT_OK(EnsureNotDictionary(&descr)); } @@ -456,29 +417,26 @@ Result Expression2::Bind( ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); - ARROW_ASSIGN_OR_RAISE(auto kernel_state, InitKernelState(bound_call, exec_context)); - kernel_context.SetState(kernel_state.get()); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, + InitKernelState(bound_call, exec_context)); + kernel_context.SetState(bound_call.kernel_state.get()); ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( &kernel_context, descrs)); - Expression2 bound(std::make_shared(std::move(bound_call)), std::move(descr)); - - state->kernel_states.emplace(bound, std::move(kernel_state)); - return BoundWithState{std::move(bound), std::move(state)}; + return Expression2(std::make_shared(std::move(bound_call)), std::move(descr)); } -Result Expression2::Bind( - const Schema& in_schema, compute::ExecContext* exec_context) const { +Result Expression2::Bind(const Schema& in_schema, + compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } -Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* state, - const Datum& input, +Result ExecuteScalarExpression(const Expression2& expr, const Datum& input, compute::ExecContext* exec_context) { 
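+  // Kernel state is no longer threaded through a separate ExpressionState;
+  // each bound Call carries its own kernel_state (see Expression2::Call).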
if (exec_context == nullptr) { compute::ExecContext exec_context; - return ExecuteScalarExpression(expr, state, input, &exec_context); + return ExecuteScalarExpression(expr, input, &exec_context); } if (!expr.IsBound()) { @@ -510,8 +468,8 @@ Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(arguments[i], ExecuteScalarExpression(call->arguments[i], state, - input, exec_context)); + ARROW_ASSIGN_OR_RAISE( + arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context)); if (RequriesDictionaryTransparency(*call)) { RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); @@ -521,7 +479,7 @@ Result ExecuteScalarExpression(const Expression2& expr, ExpressionState* auto executor = compute::detail::KernelExecutor::MakeScalar(); compute::KernelContext kernel_context(exec_context); - kernel_context.SetState(state->Get(expr)); + kernel_context.SetState(call->kernel_state.get()); auto kernel = call->kernel; auto descrs = GetDescriptors(arguments); @@ -601,7 +559,7 @@ std::vector FieldsInExpression(const Expression2& expr) { } template -Result Modify(Expression2 expr, ExpressionState* state, const PreVisit& pre, +Result Modify(Expression2 expr, const PreVisit& pre, const PostVisitCall& post_call) { ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); @@ -613,7 +571,7 @@ Result Modify(Expression2 expr, ExpressionState* state, const PreVi auto modified_argument = modified_call.arguments.begin(); for (const auto& argument : call->arguments) { - ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, state, pre, post_call)); + ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); if (!Identical(*modified_argument, argument)) { at_least_one_modified = true; @@ -626,45 +584,24 @@ Result Modify(Expression2 expr, ExpressionState* state, const PreVi auto modified_expr = Expression2( std::make_shared(std::move(modified_call)), expr.descr()); - // if expr had associated kernel state, associate it with modified_expr - state->Replace(expr, modified_expr); return post_call(std::move(modified_expr), &expr); } return post_call(std::move(expr), nullptr); } -template -Result Modify(Expression2::BoundWithState bound, - const PreVisit& pre, - const PostVisitCall& post_call) { - DCHECK(bound.first.IsBound()); - - auto expr = std::move(bound.first); - auto state = bound.second.get(); - - ARROW_ASSIGN_OR_RAISE(expr, Modify(std::move(expr), state, pre, post_call)); - - bound.first = std::move(expr); - - return bound; -} - -Result FoldConstants(Expression2::BoundWithState bound) { - bound.second = std::make_shared(*bound.second); - auto state = bound.second.get(); +Result FoldConstants(Expression2 expr) { return Modify( - std::move(bound), [](Expression2 expr) { return expr; }, - [state](Expression2 expr, ...) -> Result { + std::move(expr), [](Expression2 expr) { return expr; }, + [](Expression2 expr, ...) 
-> Result { auto call = CallNotNull(expr); if (std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression2& argument) { return argument.literal(); })) { // all arguments are literal; we can evaluate this subexpression *now* static const Datum ignored_input; ARROW_ASSIGN_OR_RAISE(Datum constant, - ExecuteScalarExpression(expr, state, ignored_input)); + ExecuteScalarExpression(expr, ignored_input)); - state->Drop(expr); return literal(std::move(constant)); } @@ -676,7 +613,6 @@ Result FoldConstants(Expression2::BoundWithState bo // to null *now* if any of their inputs is a null literal for (const auto& argument : call->arguments) { if (argument.IsNullLiteral()) { - state->Drop(expr); return argument; } } @@ -722,12 +658,8 @@ inline std::vector GuaranteeConjunctionMembers( Status ExtractKnownFieldValuesImpl( std::vector* conjunction_members, std::unordered_map* known_values) { - { - auto empty_state = std::make_shared(); - for (auto&& member : *conjunction_members) { - ARROW_ASSIGN_OR_RAISE(std::tie(member, std::ignore), - Canonicalize({std::move(member), empty_state})); - } + for (auto&& member : *conjunction_members) { + ARROW_ASSIGN_OR_RAISE(member, Canonicalize({std::move(member)})); } auto unconsumed_end = @@ -780,12 +712,11 @@ Result> ExtractKnownFieldVal return known_values; } -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2::BoundWithState bound) { - bound.second = std::make_shared(*bound.second); + Expression2 expr) { return Modify( - std::move(bound), + std::move(expr), [&known_values](Expression2 expr) { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); @@ -807,11 +738,10 @@ inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { return it != binary_associative_commutative.end(); } -Result Canonicalize(Expression2::BoundWithState bound, - compute::ExecContext* exec_context) { +Result Canonicalize(Expression2 expr, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; - return Canonicalize(std::move(bound), &exec_context); + return Canonicalize(std::move(expr), &exec_context); } // If potentially reconstructing more deeply than a call's immediate arguments @@ -830,7 +760,7 @@ Result Canonicalize(Expression2::BoundWithState bou } AlreadyCanonicalized; return Modify( - std::move(bound), + std::move(expr), [&AlreadyCanonicalized, exec_context](Expression2 expr) -> Result { auto call = expr.call(); if (!call) return expr; @@ -906,10 +836,10 @@ Result Canonicalize(Expression2::BoundWithState bou [](Expression2 expr, ...) { return expr; }); } -Result DirectComparisonSimplification( - Expression2::BoundWithState bound, const Expression2::Call& guarantee) { +Result DirectComparisonSimplification(Expression2 expr, + const Expression2::Call& guarantee) { return Modify( - std::move(bound), [](Expression2 expr) { return expr; }, + std::move(expr), [](Expression2 expr) { return expr; }, [&guarantee](Expression2 expr, ...) 
-> Result { auto call = expr.call(); if (!call) return expr; @@ -964,8 +894,8 @@ Result DirectComparisonSimplification( }); } -Result SimplifyWithGuarantee( - Expression2::BoundWithState bound, const Expression2& guaranteed_true_predicate) { +Result SimplifyWithGuarantee(Expression2 expr, + const Expression2& guaranteed_true_predicate) { if (!guaranteed_true_predicate.IsBound()) { return Status::Invalid("guaranteed_true_predicate was not bound"); } @@ -975,29 +905,29 @@ Result SimplifyWithGuarantee( std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); - ARROW_ASSIGN_OR_RAISE(bound, - ReplaceFieldsWithKnownValues(known_values, std::move(bound))); + ARROW_ASSIGN_OR_RAISE(expr, + ReplaceFieldsWithKnownValues(known_values, std::move(expr))); - auto CanonicalizeAndFoldConstants = [&bound] { - ARROW_ASSIGN_OR_RAISE(bound, Canonicalize(std::move(bound))); - ARROW_ASSIGN_OR_RAISE(bound, FoldConstants(std::move(bound))); + auto CanonicalizeAndFoldConstants = [&expr] { + ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); + ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr))); return Status::OK(); }; RETURN_NOT_OK(CanonicalizeAndFoldConstants()); for (const auto& guarantee : conjunction_members) { if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { - ARROW_ASSIGN_OR_RAISE(auto simplified, DirectComparisonSimplification( - bound, *CallNotNull(guarantee))); + ARROW_ASSIGN_OR_RAISE( + auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); - if (Identical(simplified.first, bound.first)) continue; + if (Identical(simplified, expr)) continue; - bound = std::move(simplified); + expr = std::move(simplified); RETURN_NOT_OK(CanonicalizeAndFoldConstants()); } } - return bound; + return expr; } Result> Serialize(const Expression2& expr) { diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index f452e54b7c2..16c9e8a96f6 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -54,7 +54,9 @@ class ARROW_DS_EXPORT Expression2 { // post-Bind properties: const compute::Kernel* kernel = NULLPTR; - compute::Function::Kind function_kind; + compute::Function::Kind + function_kind; // XXX give Kernel a non-owning pointer to its Function + std::shared_ptr kernel_state; }; std::string ToString() const; @@ -67,10 +69,18 @@ class ARROW_DS_EXPORT Expression2 { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - using BoundWithState = std::pair>; - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, - compute::ExecContext* = NULLPTR) const; + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, + compute::ExecContext* = NULLPTR) const; + + // XXX someday + // Clone all KernelState in this bound expression. If any function referenced by this + // expression has mutable KernelState, it is not safe to execute or apply simplification + // passes to it (or copies of it!) from multiple threads. Cloning state produces new + // KernelStates where necessary to ensure that Expression2s may be manipulated safely + // on multiple threads. 
+  // Result<ExpressionState> CloneState() const;
+  // Status SetState(ExpressionState);
 
   /// Return true if all of an expression's field references have explicit ValueDescr
   /// and all of its functions' kernels are looked up.
@@ -175,27 +185,25 @@ Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldVal
 /// equivalent Expressions may result in different canonicalized expressions.
 /// TODO this could be a strong canonicalization
 ARROW_DS_EXPORT
-Result<Expression2::BoundWithState> Canonicalize(Expression2::BoundWithState,
-                                                 compute::ExecContext* = NULLPTR);
+Result<Expression2> Canonicalize(Expression2, compute::ExecContext* = NULLPTR);
 
 /// Simplify Expressions based on literal arguments (for example, add(null, x) will
 /// always be null so replace the call with a null literal). Includes early evaluation
 /// of all calls whose arguments are entirely literal.
 ARROW_DS_EXPORT
-Result<Expression2::BoundWithState> FoldConstants(Expression2::BoundWithState);
+Result<Expression2> FoldConstants(Expression2);
 
 ARROW_DS_EXPORT
-Result<Expression2::BoundWithState> ReplaceFieldsWithKnownValues(
-    const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values,
-    Expression2::BoundWithState);
+Result<Expression2> ReplaceFieldsWithKnownValues(
-    const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values, Expression2);
 
 /// Simplify an expression by replacing subexpressions based on a guarantee:
 /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
 /// used to remove redundant function calls from a filter expression or to replace a
 /// reference to a constant-value field with a literal.
 ARROW_DS_EXPORT
-Result<Expression2::BoundWithState> SimplifyWithGuarantee(
-    Expression2::BoundWithState, const Expression2& guaranteed_true_predicate);
+Result<Expression2> SimplifyWithGuarantee(Expression2,
+                                          const Expression2& guaranteed_true_predicate);
 
 /// @}
 
@@ -203,8 +211,7 @@ Result<Expression2::BoundWithState> SimplifyWithGuarantee(
 
 /// Execute a scalar expression against the provided state and input Datum. This
 /// expression must be bound.
-Result ExecuteScalarExpression(const Expression2&, ExpressionState*, - const Datum& input, +Result ExecuteScalarExpression(const Expression2&, const Datum& input, compute::ExecContext* = NULLPTR); // Serialization diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 548d36711fb..84b802937e0 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -22,6 +22,7 @@ #include #include "arrow/compute/api_vector.h" +#include "arrow/compute/registry.h" #include "arrow/record_batch.h" #include "arrow/table.h" #include "arrow/util/logging.h" @@ -71,38 +72,6 @@ inline std::vector GetDescriptors(const std::vector& values) return descrs; } -struct ARROW_DS_EXPORT ExpressionState { - std::unordered_map, - Expression2::Hash> - kernel_states; - - compute::KernelState* Get(const Expression2& expr) const { - auto it = kernel_states.find(expr); - if (it == kernel_states.end()) return nullptr; - return it->second.get(); - } - - void Replace(const Expression2& expr, const Expression2& replacement) { - auto it = kernel_states.find(expr); - if (it == kernel_states.end()) return; - - auto kernel_state = std::move(it->second); - kernel_states.erase(it); - kernel_states.emplace(replacement, std::move(kernel_state)); - } - - void Drop(const Expression2& expr) { - auto it = kernel_states.find(expr); - if (it == kernel_states.end()) return; - kernel_states.erase(it); - } - - void MoveFrom(ExpressionState* other) { - std::move(other->kernel_states.begin(), other->kernel_states.end(), - std::inserter(kernel_states, kernel_states.end())); - } -}; - struct FieldPathGetDatumImpl { template ()))> Result operator()(const std::shared_ptr& ptr) { @@ -418,5 +387,15 @@ struct FlattenedAssociativeChain { } }; +inline Result> GetFunction( + const Expression2::Call& call, compute::ExecContext* exec_context) { + if (call.function != "cast") { + return exec_context->func_registry()->GetFunction(call.function); + } + // XXX this special case is strange; why not make "cast" a ScalarFunction? 
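For orientation, a minimal sketch of a call under the new ExecuteScalarExpression signature (a free-function wrapper; the name EvaluateFilter is illustrative, not part of the diff):

// Sketch: evaluate a bound filter against a record batch. The expression
// must already be Bind()-ed to the batch's schema; no ExpressionState
// travels alongside it anymore.
arrow::Result<arrow::Datum> EvaluateFilter(
    const arrow::dataset::Expression2& bound_filter,
    const std::shared_ptr<arrow::RecordBatch>& batch) {
  return arrow::dataset::ExecuteScalarExpression(bound_filter, arrow::Datum(batch));
}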
+ const auto& to_type = checked_cast(*call.options).to_type; + return compute::GetCastFunction(to_type); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 6274147d208..8ea70adf01c 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -64,6 +64,9 @@ TEST(Expression2, Equality) { EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}), call("add", {literal(3), call("index_in", {field_ref("beta")})})); + + // FIXME options are not currently compared, so two cast exprs will compare equal + // regardless of differing target type } TEST(Expression2, Hash) { @@ -183,8 +186,7 @@ TEST(Expression2, BindFieldRef) { { auto expr = field_ref("alpha"); // binding a field_ref looks up that field's type in the input Schema - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), - expr.Bind(Schema({field("alpha", int32())}))); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("alpha", int32())}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); EXPECT_TRUE(expr.IsBound()); } @@ -192,7 +194,7 @@ TEST(Expression2, BindFieldRef) { { // if the field is not found, a null scalar will be emitted auto expr = field_ref("alpha"); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(Schema({}))); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({}))); EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); EXPECT_TRUE(expr.IsBound()); } @@ -208,7 +210,7 @@ TEST(Expression2, BindFieldRef) { { // referencing nested fields is supported auto expr = field_ref("a", "b"); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", struct_({field("b", int32())}))}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); EXPECT_TRUE(expr.IsBound()); @@ -219,7 +221,7 @@ TEST(Expression2, BindCall) { auto expr = call("add", {field_ref("a"), field_ref("b")}); EXPECT_FALSE(expr.IsBound()); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", int32()), field("b", int32())}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); EXPECT_TRUE(expr.IsBound()); @@ -234,7 +236,7 @@ TEST(Expression2, BindDictionaryTransparent) { EXPECT_FALSE(expr.IsBound()); ASSERT_OK_AND_ASSIGN( - std::tie(expr, std::ignore), + expr, expr.Bind(Schema({field("a", utf8()), field("b", dictionary(int32(), utf8()))}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean())); @@ -248,7 +250,7 @@ TEST(Expression2, BindNestedCall) { field_ref("d")})}); EXPECT_FALSE(expr.IsBound()); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", int32()), field("b", int32()), field("c", int32()), field("d", int32())}))); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); @@ -259,9 +261,8 @@ TEST(Expression2, ExecuteFieldRef) { auto AssertRefIs = [](FieldRef ref, Datum in, Datum expected) { auto expr = field_ref(ref); - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(in.descr())); - ASSERT_OK_AND_ASSIGN(Datum actual, - ExecuteScalarExpression(expr, /*state=*/nullptr, in)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); AssertDatumsEqual(actual, expected, /*verbose=*/true); }; @@ -295,7 +296,7 @@ Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& auto call = expr.call(); if (call == nullptr) { // already tested execution of 
field_ref, execution of literal is trivial - return ExecuteScalarExpression(expr, /*state=*/nullptr, input); + return ExecuteScalarExpression(expr, input); } std::vector arguments(call->arguments.size()); @@ -335,10 +336,9 @@ Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& } void AssertExecute(Expression2 expr, Datum in) { - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(expr, state), expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); - ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, state.get(), in)); + ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in)); @@ -400,26 +400,17 @@ TEST(Expression2, ExecuteDictionaryTransparent) { struct { void operator()(Expression2 expr, Expression2 expected) { - this->operator()(expr, expected, - [](Expression2::BoundWithState, Expression2::BoundWithState, - Expression2::BoundWithState) {}); - } + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); - template - void operator()(Expression2 expr, Expression2 unbound_expected, - const ExtraExpectations& expect) { - ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(bound)); + ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(expr)); - EXPECT_EQ(folded.first, expected.first); + EXPECT_EQ(folded, expected); - if (folded.first == expr) { + if (folded == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(folded.first, expr)); + EXPECT_TRUE(Identical(folded, expr)); } - - expect(bound, folded, expected); } } ExpectFoldsTo; @@ -476,15 +467,7 @@ TEST(Expression2, FoldConstants) { call("is_in", {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})}, in_123), - call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123), - [](Expression2::BoundWithState bound, Expression2::BoundWithState folded, - Expression2::BoundWithState) { - const compute::KernelState* state = bound.second->Get(bound.first); - const compute::KernelState* folded_state = folded.second->Get(folded.first); - EXPECT_EQ(folded_state, state) << "The kernel state associated with is_in (the " - "hash table for looking up membership) " - "must be associated with the folded is_in call"; - }); + call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123)); } TEST(Expression2, FoldConstantsBoolean) { @@ -507,8 +490,7 @@ TEST(Expression2, ExtractKnownFieldValues) { struct { void operator()(Expression2 guarantee, std::unordered_map expected) { - ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), - guarantee.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) << " guarantee: " << guarantee.ToString(); @@ -553,11 +535,11 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { ASSERT_OK_AND_ASSIGN(auto replaced, ReplaceFieldsWithKnownValues(known_values, bound)); - EXPECT_EQ(replaced.first, expected.first); + EXPECT_EQ(replaced, expected); - if (replaced.first == expr) { + if (replaced == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(replaced.first, expr)); + EXPECT_TRUE(Identical(replaced, expr)); } }; @@ -598,11 +580,11 @@ struct { 
ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); - EXPECT_EQ(actual.first, expected.first); + EXPECT_EQ(actual, expected); - if (actual.first == expr) { + if (actual == expr) { // no change -> must be identical - EXPECT_TRUE(Identical(actual.first, expr)); + EXPECT_TRUE(Identical(actual, expr)); } } } ExpectCanonicalizesTo; @@ -666,19 +648,17 @@ struct Simplify { void Expect(Expression2 unbound_expected) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(std::tie(guarantee, std::ignore), - guarantee.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); - EXPECT_EQ(simplified.first, expected.first) - << " original: " << expr.ToString() << "\n" - << " guarantee: " << guarantee.ToString() << "\n" - << (simplified.first == bound.first ? " (no change)\n" : ""); + EXPECT_EQ(simplified, expected) << " original: " << expr.ToString() << "\n" + << " guarantee: " << guarantee.ToString() << "\n" + << (simplified == bound ? " (no change)\n" : ""); - if (simplified.first == bound.first) { - EXPECT_TRUE(Identical(simplified.first, bound.first)); + if (simplified == bound) { + EXPECT_TRUE(Identical(simplified, bound)); } } void ExpectUnchanged() { Expect(expr); } @@ -781,10 +761,8 @@ TEST(Expression2, SingleComparisonGuarantees) { StructArray::Make({ArrayFromJSON(int32(), satisfying_i32[guarantee_op])}, {"i32"})); - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum evaluated, - ExecuteScalarExpression(filter, state.get(), input)); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(filter, input)); // ensure that the simplified filter is as simplified as it could be // (this is always possible for single comparisons) @@ -840,7 +818,7 @@ TEST(Expression2, SimplifyThenExecute) { ASSERT_OK_AND_ASSIGN(auto guarantee, equal(field_ref("f32"), literal(0.F)).Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee.first)); + ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee)); auto input = RecordBatchFromJSON(kBoringSchema, R"([ {"i32": 0, "f32": -0.1}, @@ -851,8 +829,7 @@ TEST(Expression2, SimplifyThenExecute) { {"i32": 0, "f32": null}, {"i32": 0, "f32": 1.0} ])"); - ASSERT_OK_AND_ASSIGN(Datum evaluated, - ExecuteScalarExpression(bound.first, bound.second.get(), input)); + ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(bound, input)); } TEST(Expression2, Filter) { @@ -861,9 +838,8 @@ TEST(Expression2, Filter) { auto batch = RecordBatchFromJSON(s, batch_json); auto expected_mask = batch->column(0); - std::shared_ptr state; - ASSERT_OK_AND_ASSIGN(std::tie(filter, state), filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, state.get(), batch)); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, batch)); AssertDatumsEqual(expected_mask, mask); }; diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index f8350a0292c..dfc17a25812 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -137,15 +137,14 @@ 
std::string FileSystemDataset::ToString() const { return repr; } -Result FileSystemDataset::GetFragmentsImpl( - Expression2::BoundWithState predicate) { +Result FileSystemDataset::GetFragmentsImpl(Expression2 predicate) { FragmentVector fragments; for (const auto& fragment : fragments_) { ARROW_ASSIGN_OR_RAISE( auto simplified, SimplifyWithGuarantee(predicate, fragment->partition_expression())); - if (simplified.first.IsSatisfiable()) { + if (simplified.IsSatisfiable()) { fragments.push_back(fragment); } } @@ -323,8 +322,8 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio .Bind(*scanner->schema())); auto batch = std::move(groups.batches[i]); - ARROW_ASSIGN_OR_RAISE( - auto part, write_options.partitioning->Format(partition_expression.first)); + ARROW_ASSIGN_OR_RAISE(auto part, + write_options.partitioning->Format(partition_expression)); WriteQueue* queue; { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 1d0059ec089..f6def56b7de 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -234,7 +234,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2::BoundWithState predicate) override; + Expression2 predicate) override; FileSystemDataset(std::shared_ptr schema, Expression2 root_partition, std::shared_ptr format, diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index f1e9ad48bb0..10c5b4a4bcf 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -298,8 +298,8 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrmetadata() != nullptr) { - ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups( - {options->filter2, context->expression_state})); + ARROW_ASSIGN_OR_RAISE(row_groups, + parquet_fragment->FilterRowGroups(options->filter2)); pre_filtered = true; if (row_groups.empty()) MakeEmpty(); @@ -314,8 +314,8 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrFilterRowGroups( - {options->filter2, context->expression_state})); + ARROW_ASSIGN_OR_RAISE(row_groups, + parquet_fragment->FilterRowGroups(options->filter2)); if (row_groups.empty()) MakeEmpty(); } @@ -457,8 +457,7 @@ Status ParquetFileFragment::SetMetadata( return Status::OK(); } -Result ParquetFileFragment::SplitByRowGroup( - Expression2::BoundWithState predicate) { +Result ParquetFileFragment::SplitByRowGroup(Expression2 predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); @@ -476,8 +475,7 @@ Result ParquetFileFragment::SplitByRowGroup( return fragments; } -Result> ParquetFileFragment::Subset( - Expression2::BoundWithState predicate) { +Result> ParquetFileFragment::Subset(Expression2 predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); return Subset(std::move(row_groups)); @@ -502,19 +500,18 @@ inline void FoldingAnd(Expression2* l, Expression2 r) { } } -Result> ParquetFileFragment::FilterRowGroups( - Expression2::BoundWithState predicate) { +Result> ParquetFileFragment::FilterRowGroups(Expression2 predicate) { auto lock = physical_schema_mutex_.Lock(); DCHECK_NE(metadata_, nullptr); ARROW_ASSIGN_OR_RAISE( predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); - if (!predicate.first.IsSatisfiable()) { + if (!predicate.IsSatisfiable()) { return std::vector{}; } - for (const FieldRef& ref : 
FieldsInExpression(predicate.first)) {
+  for (const FieldRef& ref : FieldsInExpression(predicate)) {
     ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOneOrNone(*physical_schema_));
     if (!path) continue;
@@ -529,7 +526,7 @@ Result<std::vector<int>> ParquetFileFragment::FilterRowGroups(
       if (auto minmax =
               ColumnChunkStatisticsAsExpression(schema_field, *row_group_metadata)) {
         FoldingAnd(&statistics_expressions_[i], std::move(*minmax));
-        ARROW_ASSIGN_OR_RAISE(std::tie(statistics_expressions_[i], std::ignore),
+        ARROW_ASSIGN_OR_RAISE(statistics_expressions_[i],
                               statistics_expressions_[i].Bind(*physical_schema_));
       }
 
@@ -541,7 +538,7 @@ Result<std::vector<int>> ParquetFileFragment::FilterRowGroups(
   for (size_t i = 0; i < row_groups_->size(); ++i) {
     ARROW_ASSIGN_OR_RAISE(auto row_group_predicate,
                           SimplifyWithGuarantee(predicate, statistics_expressions_[i]));
-    if (row_group_predicate.first.IsSatisfiable()) {
+    if (row_group_predicate.IsSatisfiable()) {
       row_groups.push_back(row_groups_->at(i));
     }
   }
diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h
index 9de19186ad3..9eab3fe9449 100644
--- a/cpp/src/arrow/dataset/file_parquet.h
+++ b/cpp/src/arrow/dataset/file_parquet.h
@@ -150,7 +150,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat {
 /// significant performance boost when scanning high latency file systems.
 class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment {
  public:
-  Result<FragmentVector> SplitByRowGroup(Expression2::BoundWithState predicate);
+  Result<FragmentVector> SplitByRowGroup(Expression2 predicate);
 
   /// \brief Return the RowGroups selected by this fragment.
   const std::vector<int>& row_groups() const {
@@ -166,7 +166,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment {
   Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR);
 
   /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups.
-  Result<std::shared_ptr<Fragment>> Subset(Expression2::BoundWithState predicate);
+  Result<std::shared_ptr<Fragment>> Subset(Expression2 predicate);
   Result<std::shared_ptr<Fragment>> Subset(std::vector<int> row_group_ids);
 
  private:
@@ -185,7 +185,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment {
   }
 
   // Return a filtered subset of row group indices.
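The pruning loop above is the core of statistics-based row group elimination; condensed into a sketch for reference. Here statistics_expressions stands in for the per-row-group min/max guarantees assembled from Parquet metadata, and the fragment assumes an enclosing function that can use ARROW_ASSIGN_OR_RAISE:

// Sketch: a predicate survives row group i only if, after substituting that
// row group's statistics guarantee, it can still evaluate to true.
std::vector<int> selected;
for (size_t i = 0; i < statistics_expressions.size(); ++i) {
  ARROW_ASSIGN_OR_RAISE(auto per_group,
                        SimplifyWithGuarantee(predicate, statistics_expressions[i]));
  if (per_group.IsSatisfiable()) {
    selected.push_back(static_cast<int>(i));  // cannot rule this row group out
  }
}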
- Result> FilterRowGroups(Expression2::BoundWithState predicate); + Result> FilterRowGroups(Expression2 predicate); ParquetFileFormat& parquet_format_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 5a6bbddaba3..436826c2971 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -161,8 +161,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { } void SetFilter(Expression2 filter) { - ASSERT_OK_AND_ASSIGN(std::tie(opts_->filter2, ctx_->expression_state), - filter.Bind(*schema_)); + ASSERT_OK_AND_ASSIGN(opts_->filter2, filter.Bind(*schema_)); } std::shared_ptr SingleBatch(Fragment* fragment) { @@ -196,7 +195,6 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { Expression2 filter) { schema_ = opts_->schema(); ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); - ctx_->expression_state = bound.second; auto parquet_fragment = checked_pointer_cast(fragment); ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(bound)) diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index b6178ddeea6..4e7cc062997 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -186,9 +186,7 @@ Result KeyValuePartitioning::Parse(const std::string& path) const { expressions.push_back(std::move(expr)); } - auto parsed = and_(std::move(expressions)); - ARROW_ASSIGN_OR_RAISE(std::tie(parsed, std::ignore), parsed.Bind(*schema_)) - return parsed; + return and_(std::move(expressions)).Bind(*schema_); } Result KeyValuePartitioning::Format(const Expression2& expr) const { diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 37e65cacac4..57a1693b953 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -45,21 +45,20 @@ class TestPartitioning : public ::testing::Test { void AssertParse(const std::string& path, Expression2 expected) { ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path)); - ASSERT_OK_AND_ASSIGN(std::tie(expected, std::ignore), - expected.Bind(*partitioning_->schema())); + ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*partitioning_->schema())); ASSERT_EQ(parsed, expected); } template void AssertFormatError(Expression2 expr) { - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_)); ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } void AssertFormat(Expression2 expr, const std::string& expected) { // formatted partition expressions are bound to the schema of the dataset being // written - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*written_schema_)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_)); ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr)); ASSERT_EQ(formatted, expected); @@ -70,7 +69,7 @@ class TestPartitioning : public ::testing::Test { ASSERT_OK_AND_ASSIGN(auto bound, roundtripped.Bind(*partitioning_->schema())); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, expr)); - ASSERT_EQ(simplified.first, literal(true)); + ASSERT_EQ(simplified, literal(true)); } void AssertInspect(const std::vector& paths, @@ -366,9 +365,7 @@ TEST_F(TestPartitioning, EtlThenHive) { fs::internal::JoinAbstractPath(etl_segments_end, alphabeta_segments_end); ARROW_ASSIGN_OR_RAISE(auto alphabeta_expr, alphabeta_part.Parse(alphabeta_path)); - auto 
expr = and_(etl_expr, alphabeta_expr); - ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), expr.Bind(*schm)); - return expr; + return and_(etl_expr, alphabeta_expr).Bind(*schm); }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", @@ -411,8 +408,7 @@ TEST_F(TestPartitioning, Set) { auto is_in_expr = call("is_in", {field_ref(matches[1])}, compute::SetLookupOptions{ints(set), true}); - ARROW_ASSIGN_OR_RAISE(std::tie(is_in_expr, std::ignore), - is_in_expr.Bind(*schm)); + ARROW_ASSIGN_OR_RAISE(is_in_expr, is_in_expr.Bind(*schm)); subexpressions.push_back(is_in_expr); } @@ -457,9 +453,7 @@ class RangePartitioning : public Partitioning { max_cmp(field_ref(key->name), literal(max)))); } - Expression2 expr; - ARROW_ASSIGN_OR_RAISE(std::tie(expr, std::ignore), and_(ranges).Bind(*schema_)); - return expr; + return and_(ranges).Bind(*schema_); } static Status DoRegex(const std::string& segment, std::smatch* matches) { diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 67df665733b..e3629bfd8cc 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -72,8 +72,7 @@ Result Scanner::GetFragments() { // Transform Datasets in a flat Iterator. This // iterator is lazily constructed, i.e. Dataset::GetFragments is // not invoked until a Fragment is requested. - return GetFragmentsFromDatasets( - {dataset_}, {scan_options_->filter2, scan_context_->expression_state}); + return GetFragmentsFromDatasets({dataset_}, scan_options_->filter2); } Result Scanner::Scan() { @@ -124,8 +123,7 @@ Status ScannerBuilder::Project(std::vector columns) { } Status ScannerBuilder::Filter(const Expression2& filter) { - ARROW_ASSIGN_OR_RAISE(std::tie(scan_options_->filter2, scan_context_->expression_state), - filter.Bind(*schema())); + ARROW_ASSIGN_OR_RAISE(scan_options_->filter2, filter.Bind(*schema())); return Status::OK(); } diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 27c2e3e9ed7..e2f9892af38 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -48,8 +48,6 @@ struct ARROW_DS_EXPORT ScanContext { /// Return a threaded or serial TaskGroup according to use_threads. 
std::shared_ptr TaskGroup() const; - - std::shared_ptr expression_state; }; class ARROW_DS_EXPORT ScanOptions { diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index f01ced8396b..686be1e6dfb 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -31,15 +31,13 @@ namespace arrow { namespace dataset { -inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, - Expression2::BoundWithState filter, +inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression2 filter, MemoryPool* pool) { return MakeMaybeMapIterator( [=](std::shared_ptr in) -> Result> { compute::ExecContext exec_context{pool}; ARROW_ASSIGN_OR_RAISE(Datum mask, - ExecuteScalarExpression(filter.first, filter.second.get(), - Datum(in), &exec_context)); + ExecuteScalarExpression(filter, Datum(in), &exec_context)); if (mask.is_scalar()) { const auto& mask_scalar = mask.scalar_as(); @@ -83,9 +81,8 @@ class FilterAndProjectScanTask : public ScanTask { Result Execute() override { ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute()); - ARROW_ASSIGN_OR_RAISE( - Expression2::BoundWithState simplified_filter, - SimplifyWithGuarantee({filter_, context_->expression_state}, partition_)); + ARROW_ASSIGN_OR_RAISE(Expression2 simplified_filter, + SimplifyWithGuarantee(filter_, partition_)); RecordBatchIterator filter_it = FilterRecordBatch(std::move(it), simplified_filter, context_->pool); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 70c9b7548e8..ea39fef3eec 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -226,8 +226,7 @@ class DatasetFixtureMixin : public ::testing::Test { } void SetFilter(Expression2 filter) { - ASSERT_OK_AND_ASSIGN(std::tie(options_->filter2, ctx_->expression_state), - filter.Bind(*schema_)); + ASSERT_OK_AND_ASSIGN(options_->filter2, filter.Bind(*schema_)); } std::shared_ptr schema_; @@ -381,13 +380,13 @@ struct MakeFileSystemDatasetMixin { continue; } - ASSERT_OK_AND_ASSIGN(std::tie(partitions[i], std::ignore), partitions[i].Bind(*s)); + ASSERT_OK_AND_ASSIGN(partitions[i], partitions[i].Bind(*s)); ASSERT_OK_AND_ASSIGN(auto fragment, format->MakeFragment({info, fs_}, partitions[i])); fragments.push_back(std::move(fragment)); } - ASSERT_OK_AND_ASSIGN(std::tie(root_partition, std::ignore), root_partition.Bind(*s)); + ASSERT_OK_AND_ASSIGN(root_partition, root_partition.Bind(*s)); ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format, fs_, std::move(fragments))); } @@ -438,7 +437,7 @@ void AssertFragmentsHavePartitionExpressions(std::shared_ptr dataset, std::vector expected) { ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); for (auto& expr : expected) { - ASSERT_OK_AND_ASSIGN(std::tie(expr, std::ignore), expr.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*dataset->schema())); } // Ordering is not guaranteed. 
EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(fragment_it))), @@ -591,9 +590,6 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish()); scan_options_ = ScanOptions::Make(source_schema_); - ASSERT_OK_AND_ASSIGN( - std::tie(scan_options_->filter2, scan_context_->expression_state), - literal(true).Bind(*source_schema_)); } void SetWriteOptions(std::shared_ptr file_write_options) { diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index e4bd322ba10..81441a18166 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -69,8 +69,6 @@ class ParquetFileWriter; class ParquetFileWriteOptions; class Expression; - -struct ExpressionState; class Expression2; class Partitioning; From e9dffce7e84337303a51a562d0d038a0ed1d4fdb Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 10 Dec 2020 12:21:11 -0500 Subject: [PATCH 04/38] repair implicit casts --- cpp/src/arrow/compute/api_scalar.h | 4 +- cpp/src/arrow/compute/cast.h | 24 +- cpp/src/arrow/compute/exec.cc | 3 +- .../compute/kernels/scalar_cast_internal.cc | 31 ++- .../compute/kernels/scalar_cast_numeric.cc | 4 +- .../arrow/compute/kernels/scalar_cast_test.cc | 4 + cpp/src/arrow/dataset/dataset_test.cc | 3 +- cpp/src/arrow/dataset/discovery_test.cc | 2 +- cpp/src/arrow/dataset/expression.cc | 236 +++++++++++++----- cpp/src/arrow/dataset/expression_internal.h | 91 ++++--- cpp/src/arrow/dataset/expression_test.cc | 224 ++++++++++++----- cpp/src/arrow/dataset/filter.cc | 3 +- cpp/src/arrow/dataset/partition.cc | 6 +- cpp/src/arrow/dataset/partition_test.cc | 20 +- cpp/src/arrow/dataset/scanner.cc | 3 + cpp/src/arrow/dataset/scanner_test.cc | 8 +- cpp/src/arrow/dataset/test_util.h | 3 + 17 files changed, 457 insertions(+), 212 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 1292fb7f7e9..85c1587327f 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -70,7 +70,7 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { - explicit SetLookupOptions(Datum value_set, bool skip_nulls) + explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) : value_set(std::move(value_set)), skip_nulls(skip_nulls) {} /// The set of values to look up input values into. 
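Since skip_nulls is now defaulted, call sites can drop the explicit flag. A small sketch, with value_set standing in for any Datum of lookup values:

// These two are now equivalent:
compute::SetLookupOptions implicit_opts(value_set);
compute::SetLookupOptions explicit_opts(value_set, /*skip_nulls=*/false);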
@@ -86,7 +86,7 @@ struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { struct ARROW_EXPORT StrptimeOptions : public FunctionOptions { explicit StrptimeOptions(std::string format, TimeUnit::type unit) - : format(format), unit(unit) {} + : format(std::move(format)), unit(unit) {} std::string format; TimeUnit::type unit; diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 6808d1d86f3..759b7c7665b 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -42,15 +42,7 @@ class ExecContext; /// @{ struct ARROW_EXPORT CastOptions : public FunctionOptions { - CastOptions() - : allow_int_overflow(false), - allow_time_truncate(false), - allow_time_overflow(false), - allow_decimal_truncate(false), - allow_float_truncate(false), - allow_invalid_utf8(false) {} - - explicit CastOptions(bool safe) + explicit CastOptions(bool safe = true) : allow_int_overflow(!safe), allow_time_truncate(!safe), allow_time_overflow(!safe), @@ -58,9 +50,17 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions { allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} - static CastOptions Safe() { return CastOptions(true); } - - static CastOptions Unsafe() { return CastOptions(false); } + static CastOptions Safe(std::shared_ptr to_type = NULLPTR) { + CastOptions safe(true); + safe.to_type = std::move(to_type); + return safe; + } + + static CastOptions Unsafe(std::shared_ptr to_type = NULLPTR) { + CastOptions unsafe(false); + unsafe.to_type = std::move(to_type); + return unsafe; + } // Type being casted to. May be passed separate to eager function // compute::Cast diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 790ae9b3997..fbd8229f0c8 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -648,7 +648,8 @@ class ScalarExecutor : public KernelExecutorImpl { // Decide if we need to preallocate memory for this kernel validity_preallocated_ = (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && - kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); + kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL && + output_descr_.type->id() != Type::NA); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index 67f0820402a..7abf1c618da 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -149,6 +149,18 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Dat // ---------------------------------------------------------------------- void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + auto Finish = [&](Result result) { + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + *out = *result; + }; + + if (out->is_scalar()) { + return Finish(batch[0].scalar_as().GetEncodedValue()); + } + DictionaryArray dict_arr(batch[0].array()); const CastOptions& options = checked_cast(*ctx->state()).options; @@ -160,16 +172,15 @@ void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return; } - Result result = Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), - /*options=*/TakeOptions::Defaults(), ctx->exec_context()); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - *out = *result; + return 
Finish(Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), + /*options=*/TakeOptions::Defaults(), ctx->exec_context())); } void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (out->is_scalar()) { + out->scalar()->is_valid = false; + return; + } ArrayData* output = out->mutable_array(); output->buffers = {nullptr}; output->null_count = batch.length; @@ -261,9 +272,9 @@ void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* fun // // XXX: Uses Take and does its own memory allocation for the moment. We can // fix this later. - DCHECK_OK(func->AddKernel( - Type::DICTIONARY, {InputType::Array(Type::DICTIONARY)}, out_ty, UnpackDictionary, - NullHandling::COMPUTED_NO_PREALLOCATE, MemAllocation::NO_PREALLOCATE)); + DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, out_ty, + UnpackDictionary, NullHandling::COMPUTED_NO_PREALLOCATE, + MemAllocation::NO_PREALLOCATE)); } // From extension type to this type diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index f46cf7e7a75..62665d4ea44 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -630,8 +630,8 @@ std::vector> GetNumericCasts() { // Make a cast to null that does not do much. Not sure why we need to be able // to cast from dict -> null but there are unit tests for it auto cast_null = std::make_shared("cast_null", Type::NA); - DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType::Array(Type::DICTIONARY)}, - null(), OutputAllNull)); + DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(), + OutputAllNull)); functions.push_back(cast_null); functions.push_back(GetCastToInteger("cast_int8")); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 566a89f4b47..fda8073beb9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1801,6 +1801,10 @@ TYPED_TEST(TestDictionaryCast, Basic) { // TODO: Should casting dictionary scalars work? this->CheckPass(*dict_arr, *expected, expected->type(), CastOptions::Safe(), /*check_scalar=*/false); + + auto opts = CastOptions::Safe(); + opts.to_type = expected->type(); + CheckScalarUnary("cast", dict_arr, expected, &opts); } } diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index 58a4a4a0b3f..cf1e85ea055 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -786,8 +786,7 @@ TEST(TestDictPartitionColumn, SelectPartitionColumnFilterPhysicalColumn) { // Selects re-ordered virtual column with a filter on a physical column ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_1"_ == 111)); - + ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111)))); ASSERT_OK(scan_builder->Project({"part"})); ASSERT_OK_AND_ASSIGN(auto scanner, scan_builder->Finish()); diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index 219ee070fdc..93d7c3869dd 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -369,7 +369,7 @@ TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) { // column. In such case, the filename should not be used. 
MakeFactory({fs::File("one/file.parquet")}); - ASSERT_OK_AND_ASSIGN(auto expected, equal(field_ref("first"), literal("one")).Bind(*s)); + auto expected = equal(field_ref("first"), literal("one")); ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish()); ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 8f9a3ce5663..21c067e3d64 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -213,7 +213,47 @@ std::string Expression2::ToString() const { for (const auto& arg : call->arguments) { out += arg.ToString() + ","; } - out.back() = ')'; + + if (!call->options) { + out.back() = ')'; + return out; + } + + if (call->options) { + out += " {"; + + if (auto options = GetSetLookupOptions(*call)) { + DCHECK_EQ(options->value_set.kind(), Datum::ARRAY); + out += "value_set:" + options->value_set.make_array()->ToString(); + if (options->skip_nulls) { + out += ",skip_nulls"; + } + } + + if (auto options = GetCastOptions(*call)) { + out += "to_type:" + options->to_type->ToString(); + if (options->allow_int_overflow) out += ",allow_int_overflow"; + if (options->allow_time_truncate) out += ",allow_time_truncate"; + if (options->allow_time_overflow) out += ",allow_time_overflow"; + if (options->allow_decimal_truncate) out += ",allow_decimal_truncate"; + if (options->allow_float_truncate) out += ",allow_float_truncate"; + if (options->allow_invalid_utf8) out += ",allow_invalid_utf8"; + } + + if (auto options = GetStructOptions(*call)) { + for (const auto& field_name : options->field_names) { + out += field_name + ","; + } + out.resize(out.size() - 1); + } + + if (auto options = GetStrptimeOptions(*call)) { + out += "format:" + options->format; + out += ",unit:" + internal::ToString(options->unit); + } + + out += "})"; + } return out; } @@ -246,14 +286,49 @@ bool Expression2::Equals(const Expression2& other) const { return false; } - // FIXME compare FunctionOptions for equality for (size_t i = 0; i < call->arguments.size(); ++i) { if (!call->arguments[i].Equals(other_call->arguments[i])) { return false; } } - return true; + if (call->options == other_call->options) return true; + + if (auto options = GetSetLookupOptions(*call)) { + auto other_options = GetSetLookupOptions(*other_call); + return options->value_set == other_options->value_set && + options->skip_nulls == other_options->skip_nulls; + } + + if (auto options = GetCastOptions(*call)) { + auto other_options = GetCastOptions(*other_call); + for (auto safety_opt : { + &compute::CastOptions::allow_int_overflow, + &compute::CastOptions::allow_time_truncate, + &compute::CastOptions::allow_time_overflow, + &compute::CastOptions::allow_decimal_truncate, + &compute::CastOptions::allow_float_truncate, + &compute::CastOptions::allow_invalid_utf8, + }) { + if (options->*safety_opt != other_options->*safety_opt) return false; + } + return options->to_type->Equals(other_options->to_type); + } + + if (auto options = GetStructOptions(*call)) { + auto other_options = GetStructOptions(*other_call); + return options->field_names == other_options->field_names; + } + + if (auto options = GetStrptimeOptions(*call)) { + auto other_options = GetStrptimeOptions(*other_call); + return options->format == other_options->format && + options->unit == other_options->unit; + } + + ARROW_LOG(WARNING) << "comparing unknown FunctionOptions for function " + << call->function; + return false; } size_t Expression2::hash() const { @@ -378,6 
+453,81 @@ Result> InitKernelState( return std::move(kernel_state); } +Status MaybeInsertCast(std::shared_ptr to_type, Expression2* expr) { + if (expr->descr().type->Equals(to_type)) { + return Status::OK(); + } + + if (auto lit = expr->literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, to_type)); + *expr = literal(std::move(new_lit)); + return Status::OK(); + } + + // FIXME the resulting cast Call must be bound but this is a hack + auto with_cast = call("cast", {literal(MakeNullScalar(expr->descr().type))}, + compute::CastOptions::Safe(to_type)); + + static ValueDescr ignored_descr; + ARROW_ASSIGN_OR_RAISE(with_cast, with_cast.Bind(ignored_descr)); + + auto call_with_cast = *CallNotNull(with_cast); + call_with_cast.arguments[0] = std::move(*expr); + *expr = Expression2(std::make_shared(std::move(call_with_cast)), + ValueDescr{std::move(to_type), expr->descr().shape}); + + return Status::OK(); +} + +Status InsertImplicitCasts(Expression2::Call* call) { + DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression2& argument) { return argument.IsBound(); })); + + if (Comparison::Get(call->function)) { + for (auto&& argument : call->arguments) { + if (auto value_type = GetDictionaryValueType(argument.descr().type)) { + RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); + } + } + + if (call->arguments[0].descr().shape == ValueDescr::SCALAR) { + // argument 0 is scalar so casting is cheap + return MaybeInsertCast(call->arguments[1].descr().type, &call->arguments[0]); + } else { + return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); + } + } + + if (auto options = GetSetLookupOptions(*call)) { + if (auto value_type = GetDictionaryValueType(call->arguments[0].descr().type)) { + // DICTIONARY input is not supported; decode it. + RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &call->arguments[0])); + } + + if (options->value_set.type()->id() == Type::DICTIONARY) { + // DICTIONARY value_set is not supported; decode it. + auto new_options = std::make_shared(*options); + RETURN_NOT_OK(EnsureNotDictionary(&new_options->value_set)); + options = new_options.get(); + call->options = std::move(new_options); + } + + if (!options->value_set.type()->Equals(call->arguments[0].descr().type)) { + // The value_set is assumed smaller than inputs, casting it should be cheaper. 
+ auto new_options = std::make_shared(*options); + ARROW_ASSIGN_OR_RAISE(new_options->value_set, + compute::Cast(std::move(new_options->value_set), + call->arguments[0].descr().type)); + options = new_options.get(); + call->options = std::move(new_options); + } + + return Status::OK(); + } + + return Status::OK(); +} + Result Expression2::Bind(ValueDescr in, compute::ExecContext* exec_context) const { if (exec_context == nullptr) { @@ -402,18 +552,9 @@ Result Expression2::Bind(ValueDescr in, for (auto&& argument : bound_call.arguments) { ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); } - - if (RequriesDictionaryTransparency(bound_call)) { - RETURN_NOT_OK(EnsureNotDictionary(&bound_call)); - } + RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); auto descrs = GetDescriptors(bound_call.arguments); - for (auto&& descr : descrs) { - if (RequriesDictionaryTransparency(bound_call)) { - RETURN_NOT_OK(EnsureNotDictionary(&descr)); - } - } - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); @@ -470,10 +611,6 @@ Result ExecuteScalarExpression(const Expression2& expr, const Datum& inpu for (size_t i = 0; i < arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE( arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context)); - - if (RequriesDictionaryTransparency(*call)) { - RETURN_NOT_OK(EnsureNotDictionary(&arguments[i])); - } } auto executor = compute::detail::KernelExecutor::MakeScalar(); @@ -558,38 +695,6 @@ std::vector FieldsInExpression(const Expression2& expr) { return fields; } -template -Result Modify(Expression2 expr, const PreVisit& pre, - const PostVisitCall& post_call) { - ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); - - auto call = expr.call(); - if (!call) return expr; - - bool at_least_one_modified = false; - auto modified_call = *call; - auto modified_argument = modified_call.arguments.begin(); - - for (const auto& argument : call->arguments) { - ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); - - if (!Identical(*modified_argument, argument)) { - at_least_one_modified = true; - } - ++modified_argument; - } - - if (at_least_one_modified) { - // reconstruct the call expression with the modified arguments - auto modified_expr = Expression2( - std::make_shared(std::move(modified_call)), expr.descr()); - - return post_call(std::move(modified_expr), &expr); - } - - return post_call(std::move(expr), nullptr); -} - Result FoldConstants(Expression2 expr) { return Modify( std::move(expr), [](Expression2 expr) { return expr; }, @@ -625,6 +730,9 @@ Result FoldConstants(Expression2 expr) { // false and x == false if (args.first == literal(false)) return args.first; + + // x and x == x + if (args.first == args.second) return args.first; } return expr; } @@ -636,6 +744,9 @@ Result FoldConstants(Expression2 expr) { // true or x == true if (args.first == literal(true)) return args.first; + + // x or x == x + if (args.first == args.second) return args.first; } return expr; } @@ -658,10 +769,6 @@ inline std::vector GuaranteeConjunctionMembers( Status ExtractKnownFieldValuesImpl( std::vector* conjunction_members, std::unordered_map* known_values) { - for (auto&& member : *conjunction_members) { - ARROW_ASSIGN_OR_RAISE(member, Canonicalize({std::move(member)})); - } - auto unconsumed_end = std::partition(conjunction_members->begin(), conjunction_members->end(), [](const Expression2& expr) { @@ -703,9 +810,6 @@ Status ExtractKnownFieldValuesImpl( Result> 
ExtractKnownFieldValues( const Expression2& guaranteed_true_predicate) { - if (!guaranteed_true_predicate.IsBound()) { - return Status::Invalid("guaranteed_true_predicate was not bound"); - } auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); @@ -715,13 +819,20 @@ Result> ExtractKnownFieldVal Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, Expression2 expr) { + if (!expr.IsBound()) { + return Status::Invalid( + "ReplaceFieldsWithKnownValues called on an unbound Expression2"); + } + return Modify( std::move(expr), - [&known_values](Expression2 expr) { + [&known_values](Expression2 expr) -> Result { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { - return literal(it->second); + ARROW_ASSIGN_OR_RAISE(Datum lit, + compute::Cast(it->second, expr.descr().type)); + return literal(std::move(lit)); } } return expr; @@ -817,11 +928,6 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co exec_context->func_registry()->GetFunction(flipped_call.function)); auto descrs = GetDescriptors(flipped_call.arguments); - for (auto& descr : descrs) { - if (RequriesDictionaryTransparency(flipped_call)) { - RETURN_NOT_OK(EnsureNotDictionary(&descr)); - } - } ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); @@ -896,10 +1002,6 @@ Result DirectComparisonSimplification(Expression2 expr, Result SimplifyWithGuarantee(Expression2 expr, const Expression2& guaranteed_true_predicate) { - if (!guaranteed_true_predicate.IsBound()) { - return Status::Invalid("guaranteed_true_predicate was not bound"); - } - auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 84b802937e0..0324eaedc77 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -140,10 +140,16 @@ struct Comparison { return nullptr; } + // Execute a simple Comparison between scalars, casting the RHS if types disagree static Result Execute(Datum l, Datum r) { if (!l.is_scalar() || !r.is_scalar()) { return Status::Invalid("Cannot Execute Comparison on non-scalars"); } + + if (!l.type()->Equals(r.type())) { + ARROW_ASSIGN_OR_RAISE(r, compute::Cast(r, l.type())); + } + std::vector arguments{std::move(l), std::move(r)}; ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); @@ -201,6 +207,11 @@ struct Comparison { } }; +inline const compute::CastOptions* GetCastOptions(const Expression2::Call& call) { + if (call.function != "cast") return nullptr; + return checked_cast(call.options.get()); +} + inline bool IsSetLookup(const std::string& function) { return function == "is_in" || function == "index_in"; } @@ -211,45 +222,37 @@ inline const compute::SetLookupOptions* GetSetLookupOptions( return checked_cast(call.options.get()); } -inline bool RequriesDictionaryTransparency(const Expression2::Call& call) { - // TODO move this functionality into compute:: - - // Functions which don't provide kernels for dictionary types. Dictionaries will be - // decoded for these functions. 
- if (Comparison::Get(call.function)) return true; +inline const compute::StructOptions* GetStructOptions(const Expression2::Call& call) { + if (call.function != "struct") return nullptr; + return checked_cast(call.options.get()); +} - if (IsSetLookup(call.function)) return true; +inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression2::Call& call) { + if (call.function != "strptime") return nullptr; + return checked_cast(call.options.get()); +} - return false; +inline const std::shared_ptr& GetDictionaryValueType( + const std::shared_ptr& type) { + if (type && type->id() == Type::DICTIONARY) { + return checked_cast(*type).value_type(); + } + static std::shared_ptr null; + return null; } inline Status EnsureNotDictionary(ValueDescr* descr) { - const auto& type = descr->type; - if (type && type->id() == Type::DICTIONARY) { - descr->type = checked_cast(*type).value_type(); + if (auto value_type = GetDictionaryValueType(descr->type)) { + descr->type = std::move(value_type); } return Status::OK(); } inline Status EnsureNotDictionary(Datum* datum) { - if (datum->type()->id() != Type::DICTIONARY) { - return Status::OK(); - } - - if (datum->is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - *datum, - checked_cast(*datum->scalar()).GetEncodedValue()); - return Status::OK(); + if (datum->type()->id() == Type::DICTIONARY) { + const auto& type = checked_cast(*datum->type()).value_type(); + ARROW_ASSIGN_OR_RAISE(*datum, compute::Cast(*datum, type)); } - - DCHECK_EQ(datum->kind(), Datum::ARRAY); - ArrayData indices = *datum->array(); - indices.type = checked_cast(*datum->type()).index_type(); - auto values = std::move(indices.dictionary); - - ARROW_ASSIGN_OR_RAISE( - *datum, compute::Take(values, indices, compute::TakeOptions::NoBoundsCheck())); return Status::OK(); } @@ -397,5 +400,37 @@ inline Result> GetFunction( return compute::GetCastFunction(to_type); } +template +Result Modify(Expression2 expr, const PreVisit& pre, + const PostVisitCall& post_call) { + ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); + + auto call = expr.call(); + if (!call) return expr; + + bool at_least_one_modified = false; + auto modified_call = *call; + auto modified_argument = modified_call.arguments.begin(); + + for (const auto& argument : call->arguments) { + ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); + + if (!Identical(*modified_argument, argument)) { + at_least_one_modified = true; + } + ++modified_argument; + } + + if (at_least_one_modified) { + // reconstruct the call expression with the modified arguments + auto modified_expr = Expression2( + std::make_shared(std::move(modified_call)), expr.descr()); + + return post_call(std::move(modified_expr), &expr); + } + + return post_call(std::move(expr), nullptr); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 8ea70adf01c..44dd322aa63 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -42,6 +42,11 @@ namespace dataset { #define EXPECT_OK ARROW_EXPECT_OK +Expression2 cast(Expression2 argument, std::shared_ptr to_type) { + return call("cast", {std::move(argument)}, + compute::CastOptions::Safe(std::move(to_type))); +} + TEST(Expression2, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "FieldRef(alpha)"); @@ -50,23 +55,59 @@ TEST(Expression2, ToString) { EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), "add(3,FieldRef(beta))"); - 
EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}).ToString(), - "add(3,index_in(FieldRef(beta)))"); + auto in_12 = call("index_in", {field_ref("beta")}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")}); + + EXPECT_EQ(in_12.ToString(), "index_in(FieldRef(beta), {value_set:[\n 1,\n 2\n]})"); + + EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), + "cast(FieldRef(a), {to_type:int32})"); + + EXPECT_EQ(project( + { + field_ref("a"), + field_ref("a"), + literal(3), + in_12, + }, + { + "a", + "renamed_a", + "three", + "b", + }) + .ToString(), + "struct(FieldRef(a),FieldRef(a),3," + in_12.ToString() + + ", {a,renamed_a,three,b})"); } TEST(Expression2, Equality) { EXPECT_EQ(literal(1), literal(1)); + EXPECT_NE(literal(1), literal(2)); EXPECT_EQ(field_ref("a"), field_ref("a")); + EXPECT_NE(field_ref("a"), field_ref("b")); + EXPECT_NE(field_ref("a"), literal(2)); EXPECT_EQ(call("add", {literal(3), field_ref("a")}), call("add", {literal(3), field_ref("a")})); + EXPECT_NE(call("add", {literal(3), field_ref("a")}), + call("add", {literal(2), field_ref("a")})); + EXPECT_NE(call("add", {field_ref("a"), literal(3)}), + call("add", {literal(3), field_ref("a")})); - EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")})}), - call("add", {literal(3), call("index_in", {field_ref("beta")})})); + auto in_123 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]")}; + EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}), + call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)})); - // FIXME options are not currently compared, so two cast exprs will compare equal - // regardless of differing target type + auto in_12 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")}; + EXPECT_NE(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_12)}), + call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)})); + + EXPECT_EQ(cast(field_ref("a"), int32()), cast(field_ref("a"), int32())); + EXPECT_NE(cast(field_ref("a"), int32()), cast(field_ref("a"), int64())); + EXPECT_NE(cast(field_ref("a"), int32()), + call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32()))); } TEST(Expression2, Hash) { @@ -231,13 +272,53 @@ TEST(Expression2, BindCall) { expr.Bind(Schema({field("a", int32()), field("b", int32())}))); } -TEST(Expression2, BindDictionaryTransparent) { - auto expr = call("equal", {field_ref("a"), field_ref("b")}); - EXPECT_FALSE(expr.IsBound()); +TEST(Expression2, BindWithImplicitCasts) { + for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { + // cast arguments to same type + ASSERT_OK_AND_ASSIGN(auto expr, + cmp(field_ref("i32"), field_ref("i64")).Bind(*kBoringSchema)); + + // NB: RHS is cast unless LHS is scalar. 
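The NB above restates the cast-direction rule from InsertImplicitCasts. One more sketch of the scalar case, assuming literal casts fold directly to literals as in MaybeInsertCast (illustrative, not part of the test):

// A scalar LHS is cheap to cast, so it is the side that gets cast; for a
// literal this folds into a literal of the new type rather than a cast call.
ASSERT_OK_AND_ASSIGN(auto scalar_lhs,
                     equal(literal(int64_t{3}), field_ref("i32")).Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto folded_lhs,
                     equal(literal(3), field_ref("i32")).Bind(*kBoringSchema));
EXPECT_EQ(scalar_lhs, folded_lhs);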
+    ASSERT_OK_AND_ASSIGN(
+        auto expected,
+        cmp(field_ref("i32"), cast(field_ref("i64"), int32())).Bind(*kBoringSchema));
+
+    EXPECT_EQ(expr, expected);
+
+    // cast dictionary to value type
+    ASSERT_OK_AND_ASSIGN(
+        expr, cmp(field_ref("dict_str"), field_ref("str")).Bind(*kBoringSchema));
+
+    ASSERT_OK_AND_ASSIGN(
+        expected,
+        cmp(cast(field_ref("dict_str"), utf8()), field_ref("str")).Bind(*kBoringSchema));
+
+    EXPECT_EQ(expr, expected);
+  }
+  // cast value_set to argument type
+  auto Opts = [](std::shared_ptr<DataType> type) {
+    return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")};
+  };
+  ASSERT_OK_AND_ASSIGN(
+      auto expr, call("is_in", {field_ref("str")}, Opts(binary())).Bind(*kBoringSchema));
+  ASSERT_OK_AND_ASSIGN(
+      auto expected,
+      call("is_in", {field_ref("str")}, Opts(utf8())).Bind(*kBoringSchema));
+
+  EXPECT_EQ(expr, expected);
+
+  // dictionary decode set then cast to argument type
+  ASSERT_OK_AND_ASSIGN(
+      expr, call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary())))
+                .Bind(*kBoringSchema));
+
+  EXPECT_EQ(expr, expected);
+}
+
+TEST(Expression2, BindDictionaryTransparent) {
   ASSERT_OK_AND_ASSIGN(
-      expr,
-      expr.Bind(Schema({field("a", utf8()), field("b", dictionary(int32(), utf8()))})));
+      auto expr, equal(field_ref("str"), field_ref("dict_str")).Bind(*kBoringSchema));
 
   EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean()));
   EXPECT_TRUE(expr.IsBound());
@@ -303,46 +384,34 @@ Result<Datum> NaiveExecuteScalarExpression(const Expression2& expr, const Datum&
 
   for (size_t i = 0; i < arguments.size(); ++i) {
     ARROW_ASSIGN_OR_RAISE(arguments[i],
                           NaiveExecuteScalarExpression(call->arguments[i], input));
-
-    if (RequriesDictionaryTransparency(*call)) {
-      RETURN_NOT_OK(EnsureNotDictionary(&arguments[i]));
-    }
   }
 
-  ARROW_ASSIGN_OR_RAISE(auto function,
-                        compute::GetFunctionRegistry()->GetFunction(call->function));
+  compute::ExecContext exec_context;
+  ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(*call, &exec_context));
 
   auto descrs = GetDescriptors(call->arguments);
-  for (size_t i = 0; i < arguments.size(); ++i) {
-    if (RequriesDictionaryTransparency(*call)) {
-      RETURN_NOT_OK(EnsureNotDictionary(&descrs[i]));
-    }
-    EXPECT_EQ(arguments[i].descr(), descrs[i]);
-  }
-
   ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs));
 
   EXPECT_EQ(call->kernel, expected_kernel);
-
-  auto options = call->options;
-  if (RequriesDictionaryTransparency(*call)) {
-    auto non_dict_call = *call;
-    RETURN_NOT_OK(EnsureNotDictionary(&non_dict_call));
-    options = non_dict_call.options;
-  }
-
-  compute::ExecContext exec_context;
   return function->Execute(arguments, call->options.get(), &exec_context);
 }
 
-void AssertExecute(Expression2 expr, Datum in) {
-  ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr()));
+void AssertExecute(Expression2 expr, Datum in, Datum* actual_out = NULLPTR) {
+  if (in.is_value()) {
+    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr()));
+  } else {
+    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.record_batch()->schema()));
+  }
 
   ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in));
 
   ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in));
 
   AssertDatumsEqual(actual, expected, /*verbose=*/true);
+
+  if (actual_out) {
+    *actual_out = actual;
+  }
 }
 
 TEST(Expression2, ExecuteCall) {
@@ -481,16 +550,17 @@ TEST(Expression2, FoldConstantsBoolean) {
 
   ExpectFoldsTo(and_(false_, whatever), false_);
   ExpectFoldsTo(and_(true_, whatever), whatever);
+  ExpectFoldsTo(and_(whatever, whatever), whatever);
 
   ExpectFoldsTo(or_(true_, whatever), true_);
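  // NB: and_/or_ lower to the Kleene-logic kernels ("and_kleene"/"or_kleene"),
  // so folding follows Kleene rules: true is the identity of and_ and absorbs
  // or_, false is the identity of or_ and absorbs and_, and a repeated unknown
  // operand folds idempotently to itself.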
ExpectFoldsTo(or_(false_, whatever), whatever); + ExpectFoldsTo(or_(whatever, whatever), whatever); } TEST(Expression2, ExtractKnownFieldValues) { struct { void operator()(Expression2 guarantee, std::unordered_map expected) { - ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) << " guarantee: " << guarantee.ToString(); @@ -506,22 +576,34 @@ TEST(Expression2, ExtractKnownFieldValues) { ExpectKnown(equal(field_ref("i32"), literal(null_int32)), {{"i32", Datum(null_int32)}}); ExpectKnown( - and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}), + and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(1.5F))}), {{"i32", Datum(3)}, {"f32", Datum(1.5F)}}); + // NB: guarantees are *not* automatically canonicalized ExpectKnown( - and_({equal(field_ref("i32"), literal(3)), equal(literal(2.F), field_ref("f32")), - equal(literal(1), field_ref("i32_req"))}), + and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}), + {{"i32", Datum(3)}}); + + // NB: guarantees are *not* automatically simplified + // (the below could be constant folded to a usable guarantee) + ExpectKnown(or_({equal(field_ref("i32"), literal(3)), literal(false)}), {}); + + // NB: guarantees are unbound; applying them may require casts + ExpectKnown(equal(field_ref("i32"), literal("1234324")), {{"i32", Datum("1234324")}}); + + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(2.F)), + equal(field_ref("i32_req"), literal(1))}), {{"i32", Datum(3)}, {"f32", Datum(2.F)}, {"i32_req", Datum(1)}}); ExpectKnown( and_(or_(equal(field_ref("i32"), literal(3)), equal(field_ref("i32"), literal(4))), - equal(literal(2.F), field_ref("f32"))), + equal(field_ref("f32"), literal(2.F))), {{"f32", Datum(2.F)}}); ExpectKnown(and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), field_ref("f32_req")), - equal(literal(1), field_ref("i32_req"))}), + equal(field_ref("i32_req"), literal(1))}), {{"i32", Datum(3)}, {"i32_req", Datum(1)}}); } @@ -530,10 +612,10 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { [](Expression2 expr, std::unordered_map known_values, Expression2 unbound_expected) { - ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto replaced, - ReplaceFieldsWithKnownValues(known_values, bound)); + ReplaceFieldsWithKnownValues(known_values, expr)); EXPECT_EQ(replaced, expected); @@ -549,6 +631,9 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { ExpectReplacesTo(field_ref("i32"), i32_is_3, literal(3)); + // NB: known_values will be cast + ExpectReplacesTo(field_ref("i32"), {{"i32", Datum("3")}}, literal(3)); + ExpectReplacesTo(field_ref("b"), i32_is_3, field_ref("b")); ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3, @@ -648,7 +733,6 @@ struct Simplify { void Expect(Expression2 unbound_expected) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(guarantee, guarantee.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); @@ -806,30 +890,47 @@ TEST(Expression2, SimplifyWithGuarantee) { .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), less_equal(field_ref("i32"), literal(1)))) 
.Expect(equal(field_ref("i32"), literal(0))); + Simplify{ + or_(equal(field_ref("f32"), literal("0")), equal(field_ref("i32"), literal(3)))} + .WithGuarantee(greater(field_ref("f32"), literal(0.0))) + .Expect(equal(field_ref("i32"), literal(3))); + + // simplification can see through implicit casts + Simplify{or_({equal(field_ref("f32"), literal("0")), + call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ + ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})})} + .WithGuarantee(greater(field_ref("f32"), literal(0.0))) + .Expect(call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); } TEST(Expression2, SimplifyThenExecute) { auto filter = - and_({equal(field_ref("f32"), literal(0.F)), - call("is_in", {field_ref("i32")}, - compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})}); + or_({equal(field_ref("f32"), literal("0")), + call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ + ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})}); - ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto guarantee, - equal(field_ref("f32"), literal(0.F)).Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + auto guarantee = greater(field_ref("f32"), literal(0.0)); - ASSERT_OK_AND_ASSIGN(bound, SimplifyWithGuarantee(bound, guarantee)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee)); auto input = RecordBatchFromJSON(kBoringSchema, R"([ - {"i32": 0, "f32": -0.1}, - {"i32": 0, "f32": 0.3}, - {"i32": 1, "f32": 0.2}, - {"i32": 2, "f32": -0.1}, - {"i32": 0, "f32": 0.1}, - {"i32": 0, "f32": null}, - {"i32": 0, "f32": 1.0} + {"i64": 0, "f32": 0.1}, + {"i64": 0, "f32": 0.3}, + {"i64": 1, "f32": 0.5}, + {"i64": 2, "f32": 0.1}, + {"i64": 0, "f32": 0.1}, + {"i64": 0, "f32": 0.4}, + {"i64": 0, "f32": 1.0} ])"); - ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(bound, input)); + + Datum evaluated, simplified_evaluated; + AssertExecute(filter, input, &evaluated); + AssertExecute(simplified, input, &simplified_evaluated); + AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); } TEST(Expression2, Filter) { @@ -889,10 +990,9 @@ TEST(Expression2, SerializationRoundTrips) { ExpectRoundTrips(not_(field_ref("alpha"))); - compute::CastOptions to_float64; - to_float64.to_type = float64(); ExpectRoundTrips( - call("is_in", {call("cast", {field_ref("version")}, to_float64)}, + call("is_in", + {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))}, compute::SetLookupOptions{ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]"), true})); ExpectRoundTrips(call("is_valid", {field_ref("validity")})); diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index e231737bc5f..71253cfde89 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -983,7 +983,8 @@ struct InsertImplicitCastsImpl { if (!op.type->Equals(set->type())) { // cast the set (which we assume to be small) to match op.type - ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode(*set, op.type, {})); + ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode( + *set, op.type, compute::CastOptions{})); set = encoded_set.make_array(); } diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 4e7cc062997..16a49b988a9 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -186,14 +186,10 @@ 
Result<Expression2> KeyValuePartitioning::Parse(const std::string& path) const {
     expressions.push_back(std::move(expr));
   }
 
-  return and_(std::move(expressions)).Bind(*schema_);
+  return and_(std::move(expressions));
 }
 
 Result<std::string> KeyValuePartitioning::Format(const Expression2& expr) const {
-  if (!expr.IsBound()) {
-    return Status::Invalid("formatted partition expressions must be bound");
-  }
-
   std::vector values{static_cast<size_t>(schema_->num_fields()), nullptr};
 
   ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr));
diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc
index 57a1693b953..7bf7efa3761 100644
--- a/cpp/src/arrow/dataset/partition_test.cc
+++ b/cpp/src/arrow/dataset/partition_test.cc
@@ -45,21 +45,17 @@ class TestPartitioning : public ::testing::Test {
   void AssertParse(const std::string& path, Expression2 expected) {
     ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path));
-    ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*partitioning_->schema()));
     ASSERT_EQ(parsed, expected);
   }
 
   template <StatusCode code = StatusCode::Invalid>
   void AssertFormatError(Expression2 expr) {
-    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_));
     ASSERT_EQ(partitioning_->Format(expr).status().code(), code);
   }
 
   void AssertFormat(Expression2 expr, const std::string& expected) {
     // formatted partition expressions are bound to the schema of the dataset being
     // written
-    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*written_schema_));
-
     ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr));
     ASSERT_EQ(formatted, expected);
 
@@ -67,8 +63,8 @@ class TestPartitioning : public ::testing::Test {
     // expression: roundtripped should be a subset of expr
     ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, partitioning_->Parse(formatted));
 
-    ASSERT_OK_AND_ASSIGN(auto bound, roundtripped.Bind(*partitioning_->schema()));
-    ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, expr));
+    ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*written_schema_));
+    ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(roundtripped, expr));
     ASSERT_EQ(simplified, literal(true));
   }
 
@@ -365,7 +361,7 @@ TEST_F(TestPartitioning, EtlThenHive) {
         fs::internal::JoinAbstractPath(etl_segments_end, alphabeta_segments_end);
     ARROW_ASSIGN_OR_RAISE(auto alphabeta_expr, alphabeta_part.Parse(alphabeta_path));
 
-    return and_(etl_expr, alphabeta_expr).Bind(*schm);
+    return and_(etl_expr, alphabeta_expr);
   });
 
   AssertParse("/1999/12/31/00/alpha=0/beta=3.25",
@@ -405,12 +401,8 @@ TEST_F(TestPartitioning, Set) {
       set.push_back(checked_cast<const Int32Scalar&>(*s).value);
     }
 
-    auto is_in_expr = call("is_in", {field_ref(matches[1])},
-                           compute::SetLookupOptions{ints(set), true});
-
-    ARROW_ASSIGN_OR_RAISE(is_in_expr, is_in_expr.Bind(*schm));
-
-    subexpressions.push_back(is_in_expr);
+    subexpressions.push_back(call("is_in", {field_ref(matches[1])},
+                                  compute::SetLookupOptions{ints(set), true}));
   }
 
   return and_(std::move(subexpressions));
 });
@@ -453,7 +445,7 @@ class RangePartitioning : public Partitioning {
                        max_cmp(field_ref(key->name), literal(max))));
     }
 
-    return and_(ranges).Bind(*schema_);
+    return and_(ranges);
   }
 
   static Status DoRegex(const std::string& segment, std::smatch* matches) {
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index e3629bfd8cc..043f6695222 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -123,6 +123,9 @@ Status ScannerBuilder::Project(std::vector<std::string> columns) {
 }
 
 Status ScannerBuilder::Filter(const Expression2& filter) {
+  for (const auto& ref : FieldsInExpression(filter)) {
+    RETURN_NOT_OK(ref.FindOne(*schema()));
+  }
   ARROW_ASSIGN_OR_RAISE(scan_options_->filter2, filter.Bind(*schema()));
   return Status::OK();
 }
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index d79cc49faac..7261bdb8037 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -191,15 +191,13 @@ TEST_F(TestScannerBuilder, TestFilter) {
                       call("equal", {field_ref("b"), literal(true)}),
                   })));
 
-  ASSERT_RAISES(NotImplemented,
-                builder.Filter(call("equal", {field_ref("i64"), literal(10)})));
+  ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)})));
 
   ASSERT_RAISES(
-      NotImplemented,
-      builder.Filter(call("equal", {field_ref("not_a_column"), literal(10)})));
+      Invalid, builder.Filter(call("equal", {field_ref("not_a_column"), literal(true)})));
 
   ASSERT_RAISES(
-      NotImplemented,
+      Invalid,
       builder.Filter(
           call("or_kleene", {
                                 call("equal", {field_ref("i64"), literal(10)}),
diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h
index ea39fef3eec..44d211260b6 100644
--- a/cpp/src/arrow/dataset/test_util.h
+++ b/cpp/src/arrow/dataset/test_util.h
@@ -53,9 +53,12 @@ namespace dataset {
 const std::shared_ptr<Schema> kBoringSchema = schema({
     field("i32", int32()),
     field("i32_req", int32(), /*nullable=*/false),
+    field("i64", int64()),
     field("f32", float32()),
     field("f32_req", float32(), /*nullable=*/false),
     field("bool", boolean()),
+    field("str", utf8()),
+    field("dict_str", dictionary(int32(), utf8())),
 });
 
 inline namespace string_literals {

From 0e6da5b89af75912bcb2cc41ad975ced098cf60c Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Thu, 10 Dec 2020 13:47:40 -0500
Subject: [PATCH 05/38] delete Expression DSL operators and old Expression
 class

---
 .../arrow/dataset-parquet-scan-example.cc  |    2 +-
 cpp/src/arrow/dataset/CMakeLists.txt       |    2 -
 cpp/src/arrow/dataset/api.h                |    2 +-
 cpp/src/arrow/dataset/dataset.cc           |    1 -
 cpp/src/arrow/dataset/dataset_test.cc      |   10 +-
 cpp/src/arrow/dataset/discovery_test.cc    |    1 -
 cpp/src/arrow/dataset/expression.cc        |  147 +-
 cpp/src/arrow/dataset/expression.h         |   11 -
 cpp/src/arrow/dataset/expression_test.cc   |  115 +-
 cpp/src/arrow/dataset/file_base.cc         |    1 -
 cpp/src/arrow/dataset/file_csv.cc          |    1 -
 cpp/src/arrow/dataset/file_csv_test.cc     |    3 +-
 cpp/src/arrow/dataset/file_ipc_test.cc     |    1 -
 cpp/src/arrow/dataset/file_parquet.cc      |    1 -
 cpp/src/arrow/dataset/file_parquet_test.cc |   58 +-
 cpp/src/arrow/dataset/file_test.cc         |   32 +-
 cpp/src/arrow/dataset/filter.cc            | 1724 +----------------
 cpp/src/arrow/dataset/filter.h             |  600 +-----
 cpp/src/arrow/dataset/filter_test.cc       |  485 +----
 cpp/src/arrow/dataset/partition.cc         |  192 +-
 cpp/src/arrow/dataset/partition.h          |   17 +
 cpp/src/arrow/dataset/partition_test.cc    |  159 +-
 cpp/src/arrow/dataset/scanner.cc           |    1 -
 cpp/src/arrow/dataset/scanner_internal.h   |    1 -
 cpp/src/arrow/dataset/test_util.h          |   31 +-
 cpp/src/arrow/dataset/type_fwd.h           |    1 -
 26 files changed, 440 insertions(+), 3159 deletions(-)

diff --git a/cpp/examples/arrow/dataset-parquet-scan-example.cc b/cpp/examples/arrow/dataset-parquet-scan-example.cc
index 5b933d3ca62..46778cebaa5 100644
--- a/cpp/examples/arrow/dataset-parquet-scan-example.cc
+++ b/cpp/examples/arrow/dataset-parquet-scan-example.cc
@@ -18,9 +18,9 @@
 #include
 #include
 #include
+#include <arrow/dataset/expression.h>
 #include
 #include
-#include <arrow/dataset/filter.h>
 #include
 #include
 #include
diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt
index ae6bb3656dc..aa9cde2dbe9 100644
---
a/cpp/src/arrow/dataset/CMakeLists.txt +++ b/cpp/src/arrow/dataset/CMakeLists.txt @@ -25,7 +25,6 @@ set(ARROW_DATASET_SRCS expression.cc file_base.cc file_ipc.cc - filter.cc partition.cc projector.cc scanner.cc) @@ -110,7 +109,6 @@ add_arrow_dataset_test(discovery_test) add_arrow_dataset_test(expression_test) add_arrow_dataset_test(file_ipc_test) add_arrow_dataset_test(file_test) -add_arrow_dataset_test(filter_test) add_arrow_dataset_test(partition_test) add_arrow_dataset_test(scanner_test) diff --git a/cpp/src/arrow/dataset/api.h b/cpp/src/arrow/dataset/api.h index c4c90088e8c..da9f5ed371e 100644 --- a/cpp/src/arrow/dataset/api.h +++ b/cpp/src/arrow/dataset/api.h @@ -21,9 +21,9 @@ #include "arrow/dataset/dataset.h" #include "arrow/dataset/discovery.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_csv.h" #include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 5da90e48ebc..bedecf1e3a5 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -21,7 +21,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/table.h" #include "arrow/util/bit_util.h" diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index cf1e85ea055..82a5a63c2c2 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -702,8 +702,10 @@ TEST_F(TestSchemaUnification, SelectPhysicalColumnsFilterPartitionColumn) { // when some of the columns may not be materialized ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); ASSERT_OK(scan_builder->Project({"phy_2", "phy_3", "phy_4"})); - ASSERT_OK(scan_builder->Filter(("part_df"_ == 1 && "phy_2"_ == 211) || - ("part_ds"_ == 2 && "phy_4"_ != 422))); + ASSERT_OK(scan_builder->Filter(or_(and_(equal(field_ref("part_df"), literal(1)), + equal(field_ref("phy_2"), literal(211))), + and_(equal(field_ref("part_ds"), literal(2)), + not_equal(field_ref("phy_4"), literal(422)))))); using TupleType = std::tuple; std::vector rows = { @@ -734,7 +736,7 @@ TEST_F(TestSchemaUnification, SelectPartitionColumns) { TEST_F(TestSchemaUnification, SelectPartitionColumnsFilterPhysicalColumn) { // Selects re-ordered virtual columns with a filter on a physical columns ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_1"_ == 111)); + ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111)))); ASSERT_OK(scan_builder->Project({"part_df", "part_ds"})); using TupleType = std::tuple; @@ -748,7 +750,7 @@ TEST_F(TestSchemaUnification, SelectMixedColumnsAndFilter) { // Selects mix of physical/virtual with a different order and uses a filter on // a physical column not selected. 
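  // For reference, the removed DSL spelled this filter "phy_2"_ >= 212; with
  // the operators deleted it is written with the expression factories instead: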
ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_2"_ >= 212)); + ASSERT_OK(scan_builder->Filter(greater_equal(field_ref("phy_2"), literal(212)))); ASSERT_OK(scan_builder->Project({"part_df", "phy_3", "part_ds", "phy_1"})); using TupleType = std::tuple; diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index 93d7c3869dd..a951e827fa4 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -23,7 +23,6 @@ #include #include -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/filesystem/test_util.h" diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 21c067e3d64..f1b17752bf2 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -23,7 +23,6 @@ #include "arrow/chunked_array.h" #include "arrow/compute/exec_internal.h" #include "arrow/dataset/expression_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" @@ -50,144 +49,6 @@ const FieldRef* Expression2::field_ref() const { return util::get_if(impl_.get()); } -Expression2::operator std::shared_ptr() const { - if (auto lit = literal()) { - DCHECK(lit->is_scalar()); - return std::make_shared(lit->scalar()); - } - - if (auto ref = field_ref()) { - DCHECK(ref->name()); - return std::make_shared(*ref->name()); - } - - auto call = CallNotNull(*this); - if (call->function == "invert") { - return std::make_shared(call->arguments[0]); - } - - if (call->function == "cast") { - const auto& options = checked_cast(*call->options); - return std::make_shared(call->arguments[0], options.to_type, options); - } - - if (call->function == "and_kleene") { - return std::make_shared(call->arguments[0], call->arguments[1]); - } - - if (call->function == "or_kleene") { - return std::make_shared(call->arguments[0], call->arguments[1]); - } - - if (auto cmp = Comparison::Get(call->function)) { - compute::CompareOperator op = [&] { - switch (*cmp) { - case Comparison::EQUAL: - return compute::EQUAL; - case Comparison::LESS: - return compute::LESS; - case Comparison::GREATER: - return compute::GREATER; - case Comparison::NOT_EQUAL: - return compute::NOT_EQUAL; - case Comparison::LESS_EQUAL: - return compute::LESS_EQUAL; - case Comparison::GREATER_EQUAL: - return compute::GREATER_EQUAL; - default: - break; - } - return static_cast(-1); - }(); - - return std::make_shared(op, call->arguments[0], - call->arguments[1]); - } - - if (call->function == "is_valid") { - return std::make_shared(call->arguments[0]); - } - - if (call->function == "is_in") { - auto set = checked_cast(*call->options) - .value_set.make_array(); - return std::make_shared(call->arguments[0], std::move(set)); - } - - DCHECK(false) << "untranslatable Expression2: " << ToString(); - return nullptr; -} - -Expression2::Expression2(const Expression& expr) { - switch (expr.type()) { - case ExpressionType::FIELD: - *this = - ::arrow::dataset::field_ref(checked_cast(expr).name()); - return; - - case ExpressionType::SCALAR: - *this = - ::arrow::dataset::literal(checked_cast(expr).value()); - return; - - case ExpressionType::NOT: - *this = ::arrow::dataset::call( - "invert", {checked_cast(expr).operand()}); - return; - - case ExpressionType::CAST: { - const auto& cast_expr = checked_cast(expr); - auto options = cast_expr.options(); - options.to_type = 
cast_expr.to_type(); - *this = ::arrow::dataset::call("cast", {cast_expr.operand()}, std::move(options)); - return; - } - - case ExpressionType::AND: { - const auto& and_expr = checked_cast(expr); - *this = ::arrow::dataset::call("and_kleene", - {and_expr.left_operand(), and_expr.right_operand()}); - return; - } - - case ExpressionType::OR: { - const auto& or_expr = checked_cast(expr); - *this = ::arrow::dataset::call("or_kleene", - {or_expr.left_operand(), or_expr.right_operand()}); - return; - } - - case ExpressionType::COMPARISON: { - const auto& cmp_expr = checked_cast(expr); - static std::array ops = { - "equal", "not_equal", "greater", "greater_equal", "less", "less_equal", - }; - *this = ::arrow::dataset::call(ops[cmp_expr.op()], - {cmp_expr.left_operand(), cmp_expr.right_operand()}); - return; - } - - case ExpressionType::IS_VALID: { - const auto& is_valid_expr = checked_cast(expr); - *this = ::arrow::dataset::call("is_valid", {is_valid_expr.operand()}); - return; - } - - case ExpressionType::IN: { - const auto& in_expr = checked_cast(expr); - *this = ::arrow::dataset::call( - "is_in", {in_expr.operand()}, - compute::SetLookupOptions{in_expr.set(), /*skip_nulls=*/true}); - return; - } - - default: - break; - } - - DCHECK(false) << "untranslatable Expression: " << expr.ToString(); -} - std::string Expression2::ToString() const { if (auto lit = literal()) { if (lit->is_scalar()) { @@ -1032,6 +893,9 @@ Result SimplifyWithGuarantee(Expression2 expr, return expr; } +// Serialization is accomplished by converting expressions to KeyValueMetadata and storing +// this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its +// columns. Finally, the RecordBatch is written to an IPC file. Result> Serialize(const Expression2& expr) { struct { std::shared_ptr metadata_ = std::make_shared(); @@ -1238,10 +1102,5 @@ Expression2 operator||(Expression2 lhs, Expression2 rhs) { return or_(std::move(lhs), std::move(rhs)); } -Result InsertImplicitCasts(Expression2 expr, const Schema& s) { - std::shared_ptr e(expr); - return InsertImplicitCasts(*e, s); -} - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 16c9e8a96f6..e6495433d23 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -103,12 +103,6 @@ class ARROW_DS_EXPORT Expression2 { const Datum* literal() const; const FieldRef* field_ref() const; - // FIXME remove these - operator std::shared_ptr() const; // NOLINT runtime/explicit - Expression2(const Expression& expr); // NOLINT runtime/explicit - Expression2(std::shared_ptr expr) // NOLINT runtime/explicit - : Expression2(*expr) {} - const ValueDescr& descr() const { return descr_; } using Impl = util::Variant; @@ -245,10 +239,5 @@ ARROW_DS_EXPORT Expression2 or_(Expression2 lhs, Expression2 rhs); ARROW_DS_EXPORT Expression2 or_(const std::vector&); ARROW_DS_EXPORT Expression2 not_(Expression2 operand); -// FIXME remove these -ARROW_DS_EXPORT Expression2 operator&&(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Expression2 operator||(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Result InsertImplicitCasts(Expression2, const Schema&); - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 44dd322aa63..f2b6951a8a3 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -216,46 +216,42 @@ TEST(Expression2, 
BindLiteral) { } } +void ExpectBindsTo(Expression2 expr, Expression2 expected, + Expression2* bound_out = nullptr) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + EXPECT_TRUE(bound.IsBound()); + + ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); + EXPECT_EQ(bound, expected) << " unbound: " << expr.ToString(); + + if (bound_out) { + *bound_out = bound; + } +} + TEST(Expression2, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set - { - auto expr = field_ref("alpha"); - EXPECT_EQ(expr.descr(), ValueDescr{}); - EXPECT_FALSE(expr.IsBound()); - } + auto expr = field_ref("alpha"); + EXPECT_EQ(expr.descr(), ValueDescr{}); + EXPECT_FALSE(expr.IsBound()); - { - auto expr = field_ref("alpha"); - // binding a field_ref looks up that field's type in the input Schema - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("alpha", int32())}))); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - EXPECT_TRUE(expr.IsBound()); - } + ExpectBindsTo(field_ref("i32"), field_ref("i32"), &expr); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - { - // if the field is not found, a null scalar will be emitted - auto expr = field_ref("alpha"); - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({}))); - EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); - EXPECT_TRUE(expr.IsBound()); - } + // if the field is not found, a null scalar will be emitted + ExpectBindsTo(field_ref("no such field"), field_ref("no such field"), &expr); + EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); - { - // referencing a field by name is not supported if that name is not unique - // in the input schema - auto expr = field_ref("alpha"); - ASSERT_RAISES( - Invalid, expr.Bind(Schema({field("alpha", int32()), field("alpha", float32())}))); - } + // referencing a field by name is not supported if that name is not unique + // in the input schema + ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema( + {field("alpha", int32()), field("alpha", float32())}))); - { - // referencing nested fields is supported - auto expr = field_ref("a", "b"); - ASSERT_OK_AND_ASSIGN(expr, - expr.Bind(Schema({field("a", struct_({field("b", int32())}))}))); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - EXPECT_TRUE(expr.IsBound()); - } + // referencing nested fields is supported + ASSERT_OK_AND_ASSIGN(expr, field_ref("a", "b").Bind( + Schema({field("a", struct_({field("b", int32())}))}))); + EXPECT_TRUE(expr.IsBound()); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); } TEST(Expression2, BindCall) { @@ -275,53 +271,32 @@ TEST(Expression2, BindCall) { TEST(Expression2, BindWithImplicitCasts) { for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { // cast arguments to same type - ASSERT_OK_AND_ASSIGN(auto expr, - cmp(field_ref("i32"), field_ref("i64")).Bind(*kBoringSchema)); - + ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")), + cmp(field_ref("i32"), cast(field_ref("i64"), int32()))); // NB: RHS is cast unless LHS is scalar. 
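    // e.g. (an illustration, not from this patch): binding
    //   equal(literal(int64_t(3)), field_ref("i32"))
    // would be expected to cast the scalar side instead, roughly
    //   equal(cast(literal(int64_t(3)), int32()), field_ref("i32")),
    // presumably because casting one scalar is cheaper than casting a whole array.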
- ASSERT_OK_AND_ASSIGN( - auto expected, - cmp(field_ref("i32"), cast(field_ref("i64"), int32())).Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); // cast dictionary to value type - ASSERT_OK_AND_ASSIGN( - expr, cmp(field_ref("dict_str"), field_ref("str")).Bind(*kBoringSchema)); - - ASSERT_OK_AND_ASSIGN( - expected, - cmp(cast(field_ref("dict_str"), utf8()), field_ref("str")).Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); + ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")), + cmp(cast(field_ref("dict_str"), utf8()), field_ref("str"))); } + // scalars are directly cast when possible: + ExpectBindsTo( + equal(field_ref("ts_ns"), literal("1990-10-23 10:23:33")), + equal(field_ref("ts_ns"), + literal( + *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO))))); + // cast value_set to argument type auto Opts = [](std::shared_ptr type) { return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")}; }; - ASSERT_OK_AND_ASSIGN( - auto expr, call("is_in", {field_ref("str")}, Opts(binary())).Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN( - auto expected, - call("is_in", {field_ref("str")}, Opts(utf8())).Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); + ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(binary())), + call("is_in", {field_ref("str")}, Opts(utf8()))); // dictionary decode set then cast to argument type - ASSERT_OK_AND_ASSIGN( - expr, call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))) - .Bind(*kBoringSchema)); - - EXPECT_EQ(expr, expected); -} - -TEST(Expression2, BindDictionaryTransparent) { - ASSERT_OK_AND_ASSIGN( - auto expr, equal(field_ref("str"), field_ref("dict_str")).Bind(*kBoringSchema)); - - EXPECT_EQ(expr.descr(), ValueDescr::Array(boolean())); - EXPECT_TRUE(expr.IsBound()); + ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))), + call("is_in", {field_ref("str")}, Opts(utf8()))); } TEST(Expression2, BindNestedCall) { diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index dfc17a25812..86a262a662b 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -24,7 +24,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/filesystem/filesystem.h" diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index 6f1ca9acbbb..bc1a69066f7 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -28,7 +28,6 @@ #include "arrow/csv/reader.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/result.h" diff --git a/cpp/src/arrow/dataset/file_csv_test.cc b/cpp/src/arrow/dataset/file_csv_test.cc index 8e73be0ee8f..eb0e1bf9395 100644 --- a/cpp/src/arrow/dataset/file_csv_test.cc +++ b/cpp/src/arrow/dataset/file_csv_test.cc @@ -23,7 +23,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" @@ -148,7 +147,7 @@ N/A,bar schema_ = schema({field("f64", utf8()), field("str", utf8())}); ScannerBuilder builder(schema_, fragment, ctx_); // filter expression validated against declared schema - ASSERT_OK(builder.Filter("f64"_ == 
"str"_)); + ASSERT_OK(builder.Filter(equal(field_ref("f64"), field_ref("str")))); // project only "str" ASSERT_OK(builder.Project({"str"})); ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index 0fb76803778..f49337b362e 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -24,7 +24,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/discovery.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 10c5b4a4bcf..8ba5f67f6fc 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -24,7 +24,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/filesystem/path_util.h" #include "arrow/table.h" diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 436826c2971..47a563b911f 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -22,7 +22,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" #include "arrow/table.h" @@ -283,7 +282,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - SetFilter(equal(field_ref("i32"), scalar(0))); + SetFilter(equal(field_ref("i32"), literal(0))); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -319,7 +318,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - SetFilter(equal(field_ref("i32"), scalar(0))); + SetFilter(equal(field_ref("i32"), literal(0))); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { @@ -419,28 +418,30 @@ TEST_F(TestParquetFileFormat, PredicatePushdown) { CountRowsAndBatchesInScan(fragment, kTotalNumRows, kNumRowGroups); for (int64_t i = 1; i <= kNumRowGroups; i++) { - SetFilter(("i64"_ == int64_t(i))); + SetFilter(equal(field_ref("i64"), literal(i))); CountRowsAndBatchesInScan(fragment, i, 1); } // Out of bound filters should skip all RowGroups. SetFilter(literal(false)); CountRowsAndBatchesInScan(fragment, 0, 0); - SetFilter(("i64"_ == int64_t(kNumRowGroups + 1))); + SetFilter(equal(field_ref("i64"), literal(kNumRowGroups + 1))); CountRowsAndBatchesInScan(fragment, 0, 0); - SetFilter(("i64"_ == int64_t(-1))); + SetFilter(equal(field_ref("i64"), literal(-1))); CountRowsAndBatchesInScan(fragment, 0, 0); // No rows match 1 and 2. 
- SetFilter(("i64"_ == int64_t(1) and "u8"_ == uint8_t(2))); + SetFilter(and_(equal(field_ref("i64"), literal(1)), + equal(field_ref("u8"), literal(2)))); CountRowsAndBatchesInScan(fragment, 0, 0); - SetFilter(("i64"_ == int64_t(2) or "i64"_ == int64_t(4))); + SetFilter(or_(equal(field_ref("i64"), literal(2)), + equal(field_ref("i64"), literal(4)))); CountRowsAndBatchesInScan(fragment, 2 + 4, 2); - SetFilter(("i64"_ < int64_t(6))); + SetFilter(less(field_ref("i64"), literal(6))); CountRowsAndBatchesInScan(fragment, 5 * (5 + 1) / 2, 5); - SetFilter(("i64"_ >= int64_t(6))); + SetFilter(greater_equal(field_ref("i64"), literal(6))); CountRowsAndBatchesInScan(fragment, kTotalNumRows - (5 * (5 + 1) / 2), kNumRowGroups - 5); } @@ -461,32 +462,39 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragments) { // CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); for (int i = 0; i < kNumRowGroups; ++i) { - CountRowGroupsInFragment(fragment, {i}, "i64"_ == int64_t(i + 1)); + CountRowGroupsInFragment(fragment, {i}, equal(field_ref("i64"), literal(i + 1))); } // Out of bound filters should skip all RowGroups. CountRowGroupsInFragment(fragment, {}, literal(false)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(kNumRowGroups + 1)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(-1)); + CountRowGroupsInFragment(fragment, {}, + equal(field_ref("i64"), literal(kNumRowGroups + 1))); + CountRowGroupsInFragment(fragment, {}, equal(field_ref("i64"), literal(-1))); // No rows match 1 and 2. - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(1) and "u8"_ == uint8_t(2)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(2) and "i64"_ == int64_t(4)); + CountRowGroupsInFragment( + fragment, {}, + and_(equal(field_ref("i64"), literal(1)), equal(field_ref("u8"), literal(2)))); + CountRowGroupsInFragment( + fragment, {}, + and_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4)))); - CountRowGroupsInFragment(fragment, {1, 3}, - "i64"_ == int64_t(2) or "i64"_ == int64_t(4)); + CountRowGroupsInFragment( + fragment, {1, 3}, + or_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4)))); // TODO(bkietz): better Assume support for InExpression // auto set = ArrayFromJSON(int64(), "[2, 4]"); - // CountRowGroupsInFragment(fragment, {1, 3}, "i64"_.In(set)); + // CountRowGroupsInFragment(fragment, {1, 3}, field_ref("i64").In(set)); - CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, "i64"_ < int64_t(6)); + CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, less(field_ref("i64"), literal(6))); CountRowGroupsInFragment(fragment, internal::Iota(5, static_cast(kNumRowGroups)), - "i64"_ >= int64_t(6)); + greater_equal(field_ref("i64"), literal(6))); CountRowGroupsInFragment(fragment, {5, 6}, - "i64"_ >= int64_t(6) and "i64"_ < int64_t(8)); + and_(greater_equal(field_ref("i64"), literal(6)), + less(field_ref("i64"), literal(8)))); } TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragmentsUsingStringColumn) { @@ -503,7 +511,7 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragmentsUsingStringColum opts_ = ScanOptions::Make(reader.schema()); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); - CountRowGroupsInFragment(fragment, {0, 3}, "x"_ == "a"); + CountRowGroupsInFragment(fragment, {0, 3}, equal(field_ref("x"), literal("a"))); } TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { @@ -545,17 +553,17 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { for (int i = 0; i < 
kNumRowGroups; ++i) { // conflicting selection/filter - SetFilter(("i64"_ == int64_t(i))); + SetFilter(equal(field_ref("i64"), literal(i))); CountRowsAndBatchesInScan(row_groups_fragment({i}), 0, 0); } for (int i = 0; i < kNumRowGroups; ++i) { // identical selection/filter - SetFilter(("i64"_ == int64_t(i + 1))); + SetFilter(equal(field_ref("i64"), literal(i + 1))); CountRowsAndBatchesInScan(row_groups_fragment({i}), i + 1, 1); } - SetFilter(("i64"_ > int64_t(3))); + SetFilter(greater(field_ref("i64"), literal(3))); CountRowsAndBatchesInScan(row_groups_fragment({2, 3, 4, 5}), 4 + 5 + 6, 3); EXPECT_RAISES_WITH_MESSAGE_THAT( diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index c390f777ad4..fefd6911ac0 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -159,19 +159,19 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { std::vector partitions = { equal(field_ref("state"), literal("NY")), - equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("New York")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), - equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), equal(field_ref("state"), literal("CA")), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("San Francisco")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), }; MakeDataset( @@ -215,19 +215,19 @@ TEST_F(TestFileSystemDataset, FragmentPartitions) { std::vector partitions = { equal(field_ref("state"), literal("NY")), - equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("New York")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), - equal(field_ref("state"), literal("NY")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), equal(field_ref("state"), literal("CA")), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("San Francisco")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), - equal(field_ref("state"), literal("CA")) and - equal(field_ref("city"), literal("Franklin")), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), }; MakeDataset( diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 71253cfde89..9852f8c0808 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -15,1726 +15,4 @@ // specific language governing permissions and limitations // under the License. 
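The remainder of filter.cc below is deleted outright: the scalar CompareVisitor, comparison-operator inversion, and the Assume()-based simplification are all superseded by SimplifyWithGuarantee over bound Expression2s. A minimal sketch of the pruning idiom that replaces Assume(), using only names from this patch except `dataset_schema`, a hypothetical stand-in for whatever schema the filter is bound against:

    // Sketch: fragment pruning via SimplifyWithGuarantee. A fragment whose
    // partition expression guarantees state == "CA" cannot satisfy a filter
    // requiring state == "NY": the known value is substituted, the comparison
    // folds to a constant, and the fragment can be skipped unread.
    auto filter = equal(field_ref("state"), literal("NY"));
    ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*dataset_schema));
    ASSERT_OK_AND_ASSIGN(
        auto simplified,
        SimplifyWithGuarantee(bound, equal(field_ref("state"), literal("CA"))));
    ASSERT_EQ(simplified, literal(false));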
-#include "arrow/dataset/filter.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "arrow/array/builder_primitive.h" -#include "arrow/buffer.h" -#include "arrow/compute/api.h" -#include "arrow/dataset/dataset.h" -#include "arrow/dataset/expression.h" -#include "arrow/io/memory.h" -#include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/int_util_internal.h" -#include "arrow/util/iterator.h" -#include "arrow/util/logging.h" -#include "arrow/util/string.h" -#include "arrow/visitor_inline.h" - -namespace arrow { - -using compute::CompareOperator; -using compute::ExecContext; - -namespace dataset { - -using arrow::internal::checked_cast; -using arrow::internal::checked_pointer_cast; - -inline std::shared_ptr NullExpression() { - return std::make_shared(std::make_shared()); -} - -inline Datum NullDatum() { return Datum(std::make_shared()); } - -bool IsNullDatum(const Datum& datum) { - if (datum.is_scalar()) { - auto scalar = datum.scalar(); - return !scalar->is_valid; - } - - auto array_data = datum.array(); - return array_data->GetNullCount() == array_data->length; -} - -struct Comparison { - enum type { - LESS, - EQUAL, - GREATER, - NULL_, - }; -}; - -Result> EnsureNotDictionary( - const std::shared_ptr& scalar) { - if (scalar->type->id() == Type::DICTIONARY) { - return checked_cast(*scalar).GetEncodedValue(); - } - return scalar; -} - -Result Compare(const Scalar& lhs, const Scalar& rhs); - -struct CompareVisitor { - template - using ScalarType = typename TypeTraits::ScalarType; - - Status Visit(const NullType&) { - result_ = Comparison::NULL_; - return Status::OK(); - } - - Status Visit(const BooleanType&) { return CompareValues(); } - - template - enable_if_physical_floating_point Visit(const T&) { - return CompareValues(); - } - - template - enable_if_physical_signed_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_physical_unsigned_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_nested Visit(const T&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - template - enable_if_binary_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - template - enable_if_string_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - Status Visit(const Decimal128Type&) { return CompareValues(); } - Status Visit(const Decimal256Type&) { return CompareValues(); } - - // Explicit because it falls under `physical_unsigned_integer`. 
- // TODO(bkietz) whenever we vendor a float16, this can be implemented - Status Visit(const HalfFloatType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const ExtensionType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const DictionaryType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - // defer comparison to ScalarType::value - template - Status CompareValues() { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - return CompareValues(lhs, rhs); - } - - // defer comparison to explicit values - template - Status CompareValues(Value lhs, Value rhs) { - result_ = lhs < rhs ? Comparison::LESS - : lhs == rhs ? Comparison::EQUAL : Comparison::GREATER; - return Status::OK(); - } - - Comparison::type result_; - const Scalar& lhs_; - const Scalar& rhs_; -}; - -// Compare two scalars -// if either is null, return is null -// TODO(bkietz) extract this to the scalar comparison kernels -Result Compare(const Scalar& lhs, const Scalar& rhs) { - if (!lhs.type->Equals(*rhs.type)) { - return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, - " vs ", *rhs.type); - } - if (!lhs.is_valid || !rhs.is_valid) { - return Comparison::NULL_; - } - CompareVisitor vis{Comparison::NULL_, lhs, rhs}; - RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); - return vis.result_; -} - -CompareOperator InvertCompareOperator(CompareOperator op) { - switch (op) { - case CompareOperator::EQUAL: - return CompareOperator::NOT_EQUAL; - - case CompareOperator::NOT_EQUAL: - return CompareOperator::EQUAL; - - case CompareOperator::GREATER: - return CompareOperator::LESS_EQUAL; - - case CompareOperator::GREATER_EQUAL: - return CompareOperator::LESS; - - case CompareOperator::LESS: - return CompareOperator::GREATER_EQUAL; - - case CompareOperator::LESS_EQUAL: - return CompareOperator::GREATER; - - default: - break; - } - - DCHECK(false); - return CompareOperator::EQUAL; -} - -template -std::shared_ptr InvertBoolean(const Boolean& expr) { - auto lhs = Invert(*expr.left_operand()); - auto rhs = Invert(*expr.right_operand()); - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - return nullptr; -} - -std::shared_ptr Invert(const Expression& expr) { - switch (expr.type()) { - case ExpressionType::NOT: - return checked_cast(expr).operand(); - - case ExpressionType::AND: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::OR: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::COMPARISON: { - const auto& comparison = checked_cast(expr); - auto inverted_op = InvertCompareOperator(comparison.op()); - return std::make_shared( - inverted_op, comparison.left_operand(), comparison.right_operand()); - } - - default: - break; - } - return nullptr; -} - -std::shared_ptr Expression::Assume(const Expression& given) const { - std::shared_ptr out; - - DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) { - if (out != nullptr) { - return Status::OK(); - } - - if (given.type() != ExpressionType::COMPARISON) { - return Status::OK(); - } - - const auto& given_cmp = checked_cast(given); - if (given_cmp.op() != CompareOperator::EQUAL) { - return Status::OK(); - } - - if (this->Equals(given_cmp.left_operand())) { - out = given_cmp.right_operand(); - return 
Status::OK(); - } - - if (this->Equals(given_cmp.right_operand())) { - out = given_cmp.left_operand(); - return Status::OK(); - } - - return Status::OK(); - })); - - return out ? out : Copy(); -} - -std::shared_ptr ComparisonExpression::Assume(const Expression& given) const { - switch (given.type()) { - case ExpressionType::COMPARISON: { - return AssumeGivenComparison(checked_cast(given)); - } - - case ExpressionType::NOT: { - const auto& given_not = checked_cast(given); - if (auto inverted = Invert(*given_not.operand())) { - return Assume(*inverted); - } - return Copy(); - } - - case ExpressionType::OR: { - const auto& given_or = checked_cast(given); - - auto left_simplified = Assume(*given_or.left_operand()); - auto right_simplified = Assume(*given_or.right_operand()); - - // The result of simplification against the operands of an OrExpression - // cannot be used unless they are identical - if (left_simplified->Equals(right_simplified)) { - return left_simplified; - } - - return Copy(); - } - - case ExpressionType::AND: { - const auto& given_and = checked_cast(given); - - auto simplified = Copy(); - simplified = simplified->Assume(*given_and.left_operand()); - simplified = simplified->Assume(*given_and.right_operand()); - return simplified; - } - - // TODO(bkietz) we should be able to use ExpressionType::IN here - - default: - break; - } - - return Copy(); -} - -// Try to simplify one comparison against another comparison. -// For example, -// ("x"_ > 3) is a subset of ("x"_ > 2), so ("x"_ > 2).Assume("x"_ > 3) == (true) -// ("x"_ < 0) is disjoint with ("x"_ > 2), so ("x"_ > 2).Assume("x"_ < 0) == (false) -// If simplification to (true) or (false) is not possible, pass e through unchanged. -std::shared_ptr ComparisonExpression::AssumeGivenComparison( - const ComparisonExpression& given) const { - if (!left_operand_->Equals(given.left_operand_)) { - return Copy(); - } - - for (auto rhs : {right_operand_, given.right_operand_}) { - if (rhs->type() != ExpressionType::SCALAR) { - return Copy(); - } - } - - auto this_rhs = - EnsureNotDictionary(checked_cast(*right_operand_).value()) - .ValueOr(nullptr); - auto given_rhs = - EnsureNotDictionary( - checked_cast(*given.right_operand_).value()) - .ValueOr(nullptr); - - if (!this_rhs || !given_rhs) { - return Copy(); - } - - auto cmp = Compare(*this_rhs, *given_rhs).ValueOrDie(); - - if (cmp == Comparison::NULL_) { - // the RHS of e or given was null - return NullExpression(); - } - - static auto always = scalar(true); - static auto never = scalar(false); - - if (cmp == Comparison::GREATER) { - // the rhs of e is greater than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - if (cmp == Comparison::LESS) { - // the rhs of e is less than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - 
return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - DCHECK_EQ(cmp, Comparison::EQUAL); - - // the rhs of the comparisons are equal - switch (op_) { - case CompareOperator::EQUAL: - switch (given.op()) { - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - return never; - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS_EQUAL: - case CompareOperator::LESS: - return never; - case CompareOperator::GREATER: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::LESS: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::GREATER: - return never; - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - return Copy(); -} - -std::shared_ptr AndExpression::Assume(const Expression& given) const { - auto left_operand = left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially false then so is this AND - if (left_operand->Equals(false) || right_operand->Equals(false)) { - return scalar(false); - } - - // if either operand is trivially null then so is this AND - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially true then drop it - if (left_operand->Equals(true)) { - return right_operand; - } - if (right_operand->Equals(true)) { - return left_operand; - } - - // if neither of the operands is trivial, simply construct a new AND - return and_(std::move(left_operand), std::move(right_operand)); -} - -std::shared_ptr OrExpression::Assume(const Expression& given) const { - auto left_operand = left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially true then so is this OR - if (left_operand->Equals(true) || right_operand->Equals(true)) { - return scalar(true); - } - - // if either operand is trivially null then so is this OR - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially false then drop it - if (left_operand->Equals(false)) { - return right_operand; - } - if (right_operand->Equals(false)) { - return left_operand; 
-  }
-
-  // if neither of the operands is trivial, simply construct a new OR
-  return or_(std::move(left_operand), std::move(right_operand));
-}
-
-std::shared_ptr<Expression> NotExpression::Assume(const Expression& given) const {
-  auto operand = operand_->Assume(given);
-
-  if (operand->IsNull()) {
-    return NullExpression();
-  }
-  if (operand->Equals(true)) {
-    return scalar(false);
-  }
-  if (operand->Equals(false)) {
-    return scalar(true);
-  }
-
-  return Copy();
-}
-
-std::shared_ptr<Expression> InExpression::Assume(const Expression& given) const {
-  auto operand = operand_->Assume(given);
-  if (operand->type() != ExpressionType::SCALAR) {
-    return std::make_shared<InExpression>(std::move(operand), set_);
-  }
-
-  if (operand->IsNull()) {
-    return scalar(set_->null_count() > 0);
-  }
-
-  Datum set, value;
-  if (set_->type_id() == Type::DICTIONARY) {
-    const auto& dict_set = checked_cast<const DictionaryArray&>(*set_);
-    auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices());
-    auto maybe_value = checked_cast<const DictionaryScalar&>(
-                           *checked_cast<const ScalarExpression&>(*operand).value())
-                           .GetEncodedValue();
-    if (!maybe_decoded.ok() || !maybe_value.ok()) {
-      return std::make_shared<InExpression>(std::move(operand), set_);
-    }
-    set = *maybe_decoded;
-    value = *maybe_value;
-  } else {
-    set = set_;
-    value = checked_cast<const ScalarExpression&>(*operand).value();
-  }
-
-  compute::CompareOptions eq(CompareOperator::EQUAL);
-  Result<Datum> maybe_out = compute::Compare(set, value, eq);
-  if (!maybe_out.ok()) {
-    return std::make_shared<InExpression>(std::move(operand), set_);
-  }
-
-  Datum out = maybe_out.ValueOrDie();
-
-  DCHECK(out.is_array());
-  DCHECK_EQ(out.type()->id(), Type::BOOL);
-  auto out_array = checked_pointer_cast<BooleanArray>(out.make_array());
-
-  for (int64_t i = 0; i < out_array->length(); ++i) {
-    if (out_array->IsValid(i) && out_array->Value(i)) {
-      return scalar(true);
-    }
-  }
-  return scalar(false);
-}
-
-std::shared_ptr<Expression> IsValidExpression::Assume(const Expression& given) const {
-  auto operand = operand_->Assume(given);
-  if (operand->type() == ExpressionType::SCALAR) {
-    return scalar(!operand->IsNull());
-  }
-
-  return std::make_shared<IsValidExpression>(std::move(operand));
-}
-
-std::shared_ptr<Expression> CastExpression::Assume(const Expression& given) const {
-  auto operand = operand_->Assume(given);
-  if (arrow::util::holds_alternative<std::shared_ptr<DataType>>(to_)) {
-    auto to_type = arrow::util::get<std::shared_ptr<DataType>>(to_);
-    return std::make_shared<CastExpression>(std::move(operand), std::move(to_type),
-                                            options_);
-  }
-  auto like = arrow::util::get<std::shared_ptr<Expression>>(to_)->Assume(given);
-  return std::make_shared<CastExpression>(std::move(operand), std::move(like), options_);
-}
-
-const std::shared_ptr<DataType>& CastExpression::to_type() const {
-  if (arrow::util::holds_alternative<std::shared_ptr<DataType>>(to_)) {
-    return arrow::util::get<std::shared_ptr<DataType>>(to_);
-  }
-  static std::shared_ptr<DataType> null;
-  return null;
-}
-
-const std::shared_ptr<Expression>& CastExpression::like_expr() const {
-  if (arrow::util::holds_alternative<std::shared_ptr<Expression>>(to_)) {
-    return arrow::util::get<std::shared_ptr<Expression>>(to_);
-  }
-  static std::shared_ptr<Expression> null;
-  return null;
-}
-
-std::string FieldExpression::ToString() const { return name_; }
-
-std::string OperatorName(compute::CompareOperator op) {
-  switch (op) {
-    case CompareOperator::EQUAL:
-      return "==";
-    case CompareOperator::NOT_EQUAL:
-      return "!=";
-    case CompareOperator::LESS:
-      return "<";
-    case CompareOperator::LESS_EQUAL:
-      return "<=";
-    case CompareOperator::GREATER:
-      return ">";
-    case CompareOperator::GREATER_EQUAL:
-      return ">=";
-    default:
-      DCHECK(false);
-  }
-  return "";
-}
-
-std::string ScalarExpression::ToString() const {
-  auto type_repr = value_->type->ToString();
-  if (!value_->is_valid) {
-    return "null:" + type_repr;
-  }
-
-  return value_->ToString() + ":" +
type_repr; -} - -using arrow::internal::JoinStrings; - -std::string AndExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " and ", right_operand_->ToString(), ")"}, ""); -} - -std::string OrExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " or ", right_operand_->ToString(), ")"}, ""); -} - -std::string NotExpression::ToString() const { - if (operand_->type() == ExpressionType::IS_VALID) { - const auto& is_valid = checked_cast(*operand_); - return JoinStrings({"(", is_valid.operand()->ToString(), " is null)"}, ""); - } - return JoinStrings({"(not ", operand_->ToString(), ")"}, ""); -} - -std::string IsValidExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is not null)"}, ""); -} - -std::string InExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is in ", set_->ToString(), ")"}, ""); -} - -std::string CastExpression::ToString() const { - std::string to; - if (arrow::util::holds_alternative>(to_)) { - auto to_type = arrow::util::get>(to_); - to = " to " + to_type->ToString(); - } else { - auto like = arrow::util::get>(to_); - to = " like " + like->ToString(); - } - return JoinStrings({"(cast ", operand_->ToString(), std::move(to), ")"}, ""); -} - -std::string ComparisonExpression::ToString() const { - return JoinStrings({"(", left_operand_->ToString(), " ", OperatorName(op()), " ", - right_operand_->ToString(), ")"}, - ""); -} - -bool UnaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - operand_->Equals(checked_cast(other).operand_); -} - -bool BinaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - left_operand_->Equals( - checked_cast(other).left_operand_) && - right_operand_->Equals( - checked_cast(other).right_operand_); -} - -bool ComparisonExpression::Equals(const Expression& other) const { - return BinaryExpression::Equals(other) && - op_ == checked_cast(other).op_; -} - -bool ScalarExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::SCALAR && - value_->Equals(*checked_cast(other).value_); -} - -bool FieldExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::FIELD && - name_ == checked_cast(other).name_; -} - -bool Expression::Equals(const std::shared_ptr& other) const { - if (other == nullptr) { - return false; - } - return Equals(*other); -} - -bool Expression::IsNull() const { - if (type_ != ExpressionType::SCALAR) { - return false; - } - - const auto& scalar = checked_cast(*this).value(); - if (!scalar->is_valid) { - return true; - } - - return false; -} - -InExpression Expression::In(std::shared_ptr set) const { - return InExpression(Copy(), std::move(set)); -} - -IsValidExpression Expression::IsValid() const { return IsValidExpression(Copy()); } - -std::shared_ptr FieldExpression::Copy() const { - return std::make_shared(*this); -} - -std::shared_ptr ScalarExpression::Copy() const { - return std::make_shared(*this); -} - -AndExpression operator&&(const Expression& lhs, const Expression& rhs) { - return AndExpression(lhs.Copy(), rhs.Copy()); -} - -OrExpression operator||(const Expression& lhs, const Expression& rhs) { - return OrExpression(lhs.Copy(), rhs.Copy()); -} - -CastExpression Expression::CastTo(std::shared_ptr type, - compute::CastOptions options) const { - return CastExpression(Copy(), type, std::move(options)); -} - -CastExpression Expression::CastLike(std::shared_ptr expr, - 
compute::CastOptions options) const { - return CastExpression(Copy(), std::move(expr), std::move(options)); -} - -CastExpression Expression::CastLike(const Expression& expr, - compute::CastOptions options) const { - return CastLike(expr.Copy(), std::move(options)); -} - -Result> ComparisonExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto lhs_type, left_operand_->Validate(schema)); - ARROW_ASSIGN_OR_RAISE(auto rhs_type, right_operand_->Validate(schema)); - - if (lhs_type->id() == Type::NA || rhs_type->id() == Type::NA) { - return boolean(); - } - - if (!lhs_type->Equals(rhs_type)) { - return Status::TypeError("cannot compare expressions of differing type, ", *lhs_type, - " vs ", *rhs_type); - } - - return boolean(); -} - -Status EnsureNullOrBool(const std::string& msg_prefix, - const std::shared_ptr& type) { - if (type->id() == Type::BOOL || type->id() == Type::NA) { - return Status::OK(); - } - return Status::TypeError(msg_prefix, *type); -} - -Result> ValidateBoolean( - const std::vector>& operands, const Schema& schema) { - for (const auto& operand : operands) { - ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); - RETURN_NOT_OK( - EnsureNullOrBool("cannot combine expressions including one of type ", type)); - } - return boolean(); -} - -Result> AndExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> OrExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> NotExpression::Validate(const Schema& schema) const { - return ValidateBoolean({operand_}, schema); -} - -Result> InExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - if (operand_type->id() == Type::NA || set_->type()->id() == Type::NA) { - return boolean(); - } - - if (!operand_type->Equals(set_->type())) { - return Status::TypeError("mismatch: set type ", *set_->type(), " vs operand type ", - *operand_type); - } - // TODO(bkietz) check if IsIn supports operand_type - return boolean(); -} - -Result> IsValidExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(std::ignore, operand_->Validate(schema)); - return boolean(); -} - -Result> CastExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - std::shared_ptr to_type; - if (arrow::util::holds_alternative>(to_)) { - to_type = arrow::util::get>(to_); - } else { - auto like = arrow::util::get>(to_); - ARROW_ASSIGN_OR_RAISE(to_type, like->Validate(schema)); - } - - // Until expressions carry a shape, detect scalar and try to cast it. Works - // if the operand is a scalar leaf. 
- if (operand_->type() == ExpressionType::SCALAR) { - auto scalar_expr = checked_pointer_cast(operand_); - ARROW_ASSIGN_OR_RAISE(std::ignore, scalar_expr->value()->CastTo(to_type)); - return to_type; - } - - if (!compute::CanCast(*operand_type, *to_type)) { - return Status::Invalid("Cannot cast to ", to_type->ToString()); - } - - return to_type; -} - -Result> ScalarExpression::Validate(const Schema& schema) const { - return value_->type; -} - -Result> FieldExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(name_).GetOneOrNone(schema)); - if (field != nullptr) { - return field->type(); - } - return null(); -} - -Result CastOrDictionaryEncode(const Datum& arr, - const std::shared_ptr& type, - const compute::CastOptions opts) { - if (type->id() == Type::DICTIONARY) { - const auto& dict_type = checked_cast(*type); - if (dict_type.index_type()->id() != Type::INT32) { - return Status::TypeError("cannot DictionaryEncode to index type ", - *dict_type.index_type()); - } - ARROW_ASSIGN_OR_RAISE(auto dense, compute::Cast(arr, dict_type.value_type(), opts)); - return compute::DictionaryEncode(dense); - } - - return compute::Cast(arr, type, opts); -} - -struct InsertImplicitCastsImpl { - struct ValidatedAndCast { - std::shared_ptr expr; - std::shared_ptr type; - }; - - Result InsertCastsAndValidate(const Expression& expr) { - ValidatedAndCast out; - ARROW_ASSIGN_OR_RAISE(out.expr, InsertImplicitCasts(expr, schema_)); - ARROW_ASSIGN_OR_RAISE(out.type, out.expr->Validate(schema_)); - return std::move(out); - } - - Result> Cast(std::shared_ptr type, - const Expression& expr) { - if (expr.type() != ExpressionType::SCALAR) { - return expr.CastTo(type).Copy(); - } - - // cast the scalar directly - const auto& value = checked_cast(expr).value(); - ARROW_ASSIGN_OR_RAISE(auto cast_value, value->CastTo(std::move(type))); - return scalar(cast_value); - } - - Result> operator()(const InExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); - auto set = expr.set(); - - if (!op.type->Equals(set->type())) { - // cast the set (which we assume to be small) to match op.type - ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode( - *set, op.type, compute::CastOptions{})); - set = encoded_set.make_array(); - } - - return std::make_shared(std::move(op.expr), std::move(set)); - } - - Result> operator()(const NotExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); - - if (op.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(op.expr, Cast(boolean(), *op.expr)); - } - return not_(std::move(op.expr)); - } - - Result> operator()(const AndExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); - } - return and_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const OrExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), 
*rhs.expr)); - } - return or_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const ComparisonExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->Equals(rhs.type)) { - return expr.Copy(); - } - - if (lhs.expr->type() == ExpressionType::SCALAR) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(rhs.type, *lhs.expr)); - } else { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(lhs.type, *rhs.expr)); - } - return std::make_shared(expr.op(), std::move(lhs.expr), - std::move(rhs.expr)); - } - - Result> operator()(const Expression& expr) const { - return expr.Copy(); - } - - const Schema& schema_; -}; - -Result> InsertImplicitCasts(const Expression& expr, - const Schema& schema) { - RETURN_NOT_OK(schema.CanReferenceFieldsByNames(FieldsInExpression(expr))); - return VisitExpression(expr, InsertImplicitCastsImpl{schema}); -} - -Status VisitConjunctionMembers(const Expression& expr, - const std::function& visitor) { - if (expr.type() == ExpressionType::AND) { - const auto& and_ = checked_cast(expr); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor)); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor)); - return Status::OK(); - } - - return visitor(expr); -} - -std::vector FieldsInExpression(const Expression& expr) { - struct { - void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); } - - void operator()(const UnaryExpression& expr) { - VisitExpression(*expr.operand(), *this); - } - - void operator()(const BinaryExpression& expr) { - VisitExpression(*expr.left_operand(), *this); - VisitExpression(*expr.right_operand(), *this); - } - - void operator()(const Expression&) const {} - - std::vector fields; - } visitor; - - VisitExpression(expr, visitor); - return std::move(visitor.fields); -} - -std::vector FieldsInExpression(const std::shared_ptr& expr) { - DCHECK_NE(expr, nullptr); - if (expr == nullptr) { - return {}; - } - - return FieldsInExpression(*expr); -} - -RecordBatchIterator ExpressionEvaluator::FilterBatches(RecordBatchIterator unfiltered, - std::shared_ptr filter, - MemoryPool* pool) { - auto filter_batches = [filter, pool, this](std::shared_ptr unfiltered) { - auto filtered = Evaluate(*filter, *unfiltered, pool).Map([&](Datum selection) { - return Filter(selection, unfiltered, pool); - }); - - if (filtered.ok() && (*filtered)->num_rows() == 0) { - // drop empty batches - return FilterIterator::Reject>(); - } - - return FilterIterator::MaybeAccept(std::move(filtered)); - }; - - return MakeFilterIterator(std::move(filter_batches), std::move(unfiltered)); -} - -std::shared_ptr ExpressionEvaluator::Null() { - struct Impl : ExpressionEvaluator { - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override { - ARROW_ASSIGN_OR_RAISE(auto type, expr.Validate(*batch.schema())); - return Datum(MakeNullScalar(type)); - } - - Result> Filter(const Datum& selection, - const std::shared_ptr& batch, - MemoryPool* pool) const override { - return batch; - } - }; - - return std::make_shared(); -} - -struct TreeEvaluator::Impl { - Result operator()(const ScalarExpression& expr) const { - return Datum(expr.value()); - } - - Result operator()(const FieldExpression& expr) const { - if (auto column = batch_.GetColumnByName(expr.name())) { - return std::move(column); - } - return NullDatum(); - } - - Result operator()(const AndExpression& expr) const { - 
return EvaluateBoolean(expr, compute::KleeneAnd); - } - - Result operator()(const OrExpression& expr) const { - return EvaluateBoolean(expr, compute::KleeneOr); - } - - Result EvaluateBoolean(const BinaryExpression& expr, - Result kernel(const Datum& left, - const Datum& right, - ExecContext* ctx)) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (lhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto lhs_array, - MakeArrayFromScalar(*lhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); - lhs = Datum(std::move(lhs_array)); - } - - if (rhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto rhs_array, - MakeArrayFromScalar(*rhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); - rhs = Datum(std::move(rhs_array)); - } - - return kernel(lhs, rhs, &ctx_); - } - - Result operator()(const NotExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum to_invert, Evaluate(*expr.operand())); - if (IsNullDatum(to_invert)) { - return NullDatum(); - } - - if (to_invert.is_scalar()) { - bool trivial_condition = - checked_cast(*to_invert.scalar()).value; - return Datum(std::make_shared(!trivial_condition)); - } - return compute::Invert(to_invert, &ctx_); - } - - Result operator()(const InExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(expr.set()->null_count() != 0); - } - - DCHECK(operand_values.is_array()); - return compute::IsIn(operand_values, expr.set(), &ctx_); - } - - Result operator()(const IsValidExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(false); - } - - if (operand_values.is_scalar()) { - return Datum(true); - } - - DCHECK(operand_values.is_array()); - if (operand_values.array()->GetNullCount() == 0) { - return Datum(true); - } - - return Datum(std::make_shared(operand_values.array()->length, - operand_values.array()->buffers[0])); - } - - Result operator()(const CastExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto to_type, expr.Validate(*batch_.schema())); - - ARROW_ASSIGN_OR_RAISE(auto to_cast, Evaluate(*expr.operand())); - if (to_cast.is_scalar()) { - return to_cast.scalar()->CastTo(to_type); - } - - DCHECK(to_cast.is_array()); - return CastOrDictionaryEncode(to_cast, to_type, expr.options()); - } - - Result operator()(const ComparisonExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (IsNullDatum(lhs) || IsNullDatum(rhs)) { - return Datum(std::make_shared()); - } - - if (lhs.type()->id() == Type::DICTIONARY && rhs.type()->id() == Type::DICTIONARY) { - if (lhs.is_array() && rhs.is_array()) { - // decode dictionary arrays - for (Datum* arg : {&lhs, &rhs}) { - auto dict = checked_pointer_cast(arg->make_array()); - ARROW_ASSIGN_OR_RAISE(*arg, compute::Take(dict->dictionary(), dict->indices(), - compute::TakeOptions::Defaults())); - } - } else if (lhs.is_array() || rhs.is_array()) { - auto dict = checked_pointer_cast( - (lhs.is_array() ? lhs : rhs).make_array()); - - ARROW_ASSIGN_OR_RAISE(auto scalar, checked_cast( - *(lhs.is_scalar() ? 
lhs : rhs).scalar()) - .GetEncodedValue()); - if (lhs.is_array()) { - lhs = dict->dictionary(); - rhs = std::move(scalar); - } else { - lhs = std::move(scalar); - rhs = dict->dictionary(); - } - ARROW_ASSIGN_OR_RAISE( - Datum out_dict, - compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_)); - - return compute::Take(out_dict, dict->indices(), compute::TakeOptions::Defaults()); - } - } - - return compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_); - } - - Result operator()(const Expression& expr) const { - return Status::NotImplemented("evaluation of ", expr.ToString()); - } - - Result Evaluate(const Expression& expr) const { - return this_->Evaluate(expr, batch_, ctx_.memory_pool()); - } - - const TreeEvaluator* this_; - const RecordBatch& batch_; - mutable compute::ExecContext ctx_; -}; - -Result TreeEvaluator::Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const { - return VisitExpression(expr, Impl{this, batch, compute::ExecContext{pool}}); -} - -Result> TreeEvaluator::Filter( - const Datum& selection, const std::shared_ptr& batch, - MemoryPool* pool) const { - if (selection.is_array()) { - auto selection_array = selection.make_array(); - compute::ExecContext ctx(pool); - ARROW_ASSIGN_OR_RAISE(Datum filtered, - compute::Filter(batch, selection_array, - compute::FilterOptions::Defaults(), &ctx)); - return filtered.record_batch(); - } - - if (!selection.is_scalar() || selection.type()->id() != Type::BOOL) { - return Status::NotImplemented("Filtering batches against DatumKind::", - selection.kind(), " of type ", *selection.type()); - } - - if (BooleanScalar(true).Equals(*selection.scalar())) { - return batch; - } - - return batch->Slice(0, 0); -} - -const std::shared_ptr& scalar(bool value) { - static auto true_ = scalar(MakeScalar(true)); - static auto false_ = scalar(MakeScalar(false)); - return value ? true_ : false_; -} - -// Serialization is accomplished by converting expressions to single element StructArrays -// then writing that to an IPC file. The last field is always an int32 column containing -// ExpressionType, the rest store the Expression's members. 
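// [Editorial sketch; not part of the patch.] A minimal round trip through the scheme
// described above, using only Serialize()/Deserialize() and the expression classes
// declared in filter.h. The expression chosen is arbitrary and error handling is
// elided; illustrative only.

#include "arrow/buffer.h"
#include "arrow/dataset/filter.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/logging.h"

namespace arrow {
namespace dataset {

Status SerializationRoundTripExample() {
  // ("a" > 0 and "b" == 3), built directly from the expression classes
  auto a_gt_0 = std::make_shared<ComparisonExpression>(
      CompareOperator::GREATER, std::make_shared<FieldExpression>("a"), scalar(0));
  auto b_eq_3 = std::make_shared<ComparisonExpression>(
      CompareOperator::EQUAL, std::make_shared<FieldExpression>("b"), scalar(3));
  std::shared_ptr<Expression> expr = (*a_gt_0 && *b_eq_3).Copy();

  // Expression -> unit-length StructArray -> IPC file written into a Buffer
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buf, expr->Serialize());

  // Buffer -> StructArray -> Expression; the trailing int32 tag column drives dispatch
  ARROW_ASSIGN_OR_RAISE(auto roundtripped, Expression::Deserialize(*buf));
  DCHECK(expr->Equals(roundtripped));
  return Status::OK();
}

}  // namespace dataset
}  // namespace arrow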
-struct SerializeImpl { - Result> ToArray(const Expression& expr) const { - return VisitExpression(expr, *this); - } - - Result> TaggedWithChildren(const Expression& expr, - ArrayVector children) const { - children.emplace_back(); - ARROW_ASSIGN_OR_RAISE(children.back(), - MakeArrayFromScalar(Int32Scalar(expr.type()), 1)); - - return StructArray::Make(children, std::vector(children.size(), "")); - } - - Result> operator()(const FieldExpression& expr) const { - // store the field's name in a StringArray - ARROW_ASSIGN_OR_RAISE(auto name, MakeArrayFromScalar(StringScalar(expr.name()), 1)); - return TaggedWithChildren(expr, {name}); - } - - Result> operator()(const ScalarExpression& expr) const { - // store the scalar's value in a single element Array - ARROW_ASSIGN_OR_RAISE(auto value, MakeArrayFromScalar(*expr.value(), 1)); - return TaggedWithChildren(expr, {value}); - } - - Result> operator()(const UnaryExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - return TaggedWithChildren(expr, {operand}); - } - - Result> operator()(const CastExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the cast target and a discriminant - std::shared_ptr is_like_expr, to; - if (const auto& to_type = expr.to_type()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(false), 1)); - ARROW_ASSIGN_OR_RAISE(to, MakeArrayOfNull(to_type, 1)); - } - if (const auto& like_expr = expr.like_expr()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(true), 1)); - ARROW_ASSIGN_OR_RAISE(to, ToArray(*like_expr)); - } - - return TaggedWithChildren(expr, {operand, is_like_expr, to}); - } - - Result> operator()(const BinaryExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - return TaggedWithChildren(expr, {left_operand, right_operand}); - } - - Result> operator()( - const ComparisonExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - // store the CompareOperator in a single element Int32Array - ARROW_ASSIGN_OR_RAISE(auto op, MakeArrayFromScalar(Int32Scalar(expr.op()), 1)); - return TaggedWithChildren(expr, {left_operand, right_operand, op}); - } - - Result> operator()(const InExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the set as a single element ListArray - auto set_type = list(expr.set()->type()); - - ARROW_ASSIGN_OR_RAISE(auto set_offsets, AllocateBuffer(sizeof(int32_t) * 2)); - reinterpret_cast(set_offsets->mutable_data())[0] = 0; - reinterpret_cast(set_offsets->mutable_data())[1] = - static_cast(expr.set()->length()); - - auto set_values = expr.set(); - - auto set = std::make_shared(std::move(set_type), 1, std::move(set_offsets), - std::move(set_values)); - return TaggedWithChildren(expr, {operand, set}); - } - - Result> operator()(const Expression& expr) const { - return Status::NotImplemented("serialization of ", expr.ToString()); - } - - Result> 
ToBuffer(const Expression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto array, SerializeImpl{}.ToArray(expr)); - ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array)); - ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); - ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - RETURN_NOT_OK(writer->Close()); - return stream->Finish(); - } -}; - -Result> Expression::Serialize() const { - return SerializeImpl{}.ToBuffer(*this); -} - -struct DeserializeImpl { - Result> FromArray(const Array& array) const { - if (array.type_id() != Type::STRUCT || array.length() != 1) { - return Status::Invalid("can only deserialize expressions from unit-length", - " StructArray, got ", array); - } - const auto& struct_array = checked_cast(array); - - ARROW_ASSIGN_OR_RAISE(auto expression_type, GetExpressionType(struct_array)); - switch (expression_type) { - case ExpressionType::FIELD: { - ARROW_ASSIGN_OR_RAISE(auto name, GetView(struct_array, 0)); - return std::make_shared(std::string(name)); - } - - case ExpressionType::SCALAR: { - ARROW_ASSIGN_OR_RAISE(auto value, struct_array.field(0)->GetScalar(0)); - return scalar(std::move(value)); - } - - case ExpressionType::NOT: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return not_(std::move(operand)); - } - - case ExpressionType::CAST: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto is_like_expr, GetView(struct_array, 1)); - if (is_like_expr) { - ARROW_ASSIGN_OR_RAISE(auto like_expr, FromArray(*struct_array.field(2))); - return operand->CastLike(std::move(like_expr)).Copy(); - } - return operand->CastTo(struct_array.field(2)->type()).Copy(); - } - - case ExpressionType::AND: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return and_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::OR: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return or_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::COMPARISON: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - ARROW_ASSIGN_OR_RAISE(auto op, GetView(struct_array, 2)); - return std::make_shared(static_cast(op), - std::move(left_operand), - std::move(right_operand)); - } - - case ExpressionType::IS_VALID: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return std::make_shared(std::move(operand)); - } - - case ExpressionType::IN: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - if (struct_array.field(1)->type_id() != Type::LIST) { - return Status::TypeError("expected field 1 of ", struct_array, - " to have list type"); - } - auto set = checked_cast(*struct_array.field(1)).values(); - return std::make_shared(std::move(operand), std::move(set)); - } - - default: - break; - } - - return Status::Invalid("non-deserializable ExpressionType ", expression_type); - } - - template ::ArrayType> - static Result().GetView(0))> GetView(const StructArray& array, - int index) { - if (index >= array.num_fields()) { - return Status::IndexError("expected ", array, " to have a child at 
index ", index); - } - - const auto& child = *array.field(index); - if (child.type_id() != T::type_id) { - return Status::TypeError("expected child ", index, " of ", array, " to have type ", - T::type_id); - } - - return checked_cast(child).GetView(0); - } - - static Result GetExpressionType(const StructArray& array) { - if (array.struct_type()->num_fields() < 1) { - return Status::Invalid("StructArray didn't contain ExpressionType member"); - } - - ARROW_ASSIGN_OR_RAISE(auto expression_type, - GetView(array, array.num_fields() - 1)); - return static_cast(expression_type); - } - - Result> FromBuffer(const Buffer& serialized) { - io::BufferReader stream(serialized); - ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); - ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); - ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray()); - return FromArray(*array); - } -}; - -Result> Expression::Deserialize(const Buffer& serialized) { - return DeserializeImpl{}.FromBuffer(serialized); -} - -// Transform an array of counts to offsets which will divide a ListArray -// into an equal number of slices with corresponding lengths. -inline Result> CountsToOffsets( - std::shared_ptr counts) { - Int32Builder offset_builder; - RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1)); - offset_builder.UnsafeAppend(0); - - for (int64_t i = 0; i < counts->length(); ++i) { - DCHECK_NE(counts->Value(i), 0); - auto next_offset = static_cast(offset_builder[i] + counts->Value(i)); - offset_builder.UnsafeAppend(next_offset); - } - - std::shared_ptr offsets; - RETURN_NOT_OK(offset_builder.Finish(&offsets)); - return offsets; -} - -// Helper for simultaneous dictionary encoding of multiple arrays. -// -// The fused dictionary is the Cartesian product of the individual dictionaries. -// For example given two arrays A, B where A has unique values ["ex", "why"] -// and B has unique values [0, 1] the fused dictionary is the set of tuples -// [["ex", 0], ["ex", 1], ["why", 0], ["ex", 1]]. -// -// TODO(bkietz) this capability belongs in an Action of the hash kernels, where -// it can be used to group aggregates without materializing a grouped batch. -// For the purposes of writing we need the materialized grouped batch anyway -// since no Writers accept a selection vector. 
-class StructDictionary { - public: - struct Encoded { - std::shared_ptr indices; - std::shared_ptr dictionary; - }; - - static Result Encode(const ArrayVector& columns) { - Encoded out{nullptr, std::make_shared()}; - - for (const auto& column : columns) { - if (column->null_count() != 0) { - return Status::NotImplemented("Grouping on a field with nulls"); - } - - RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); - } - - return out; - } - - Result> Decode(std::shared_ptr fused_indices, - FieldVector fields) { - std::vector builders(dictionaries_.size()); - for (Int32Builder& b : builders) { - RETURN_NOT_OK(b.Resize(fused_indices->length())); - } - - std::vector codes(dictionaries_.size()); - for (int64_t i = 0; i < fused_indices->length(); ++i) { - Expand(fused_indices->Value(i), codes.data()); - - auto builder_it = builders.begin(); - for (int32_t index : codes) { - builder_it++->UnsafeAppend(index); - } - } - - ArrayVector columns(dictionaries_.size()); - for (size_t i = 0; i < dictionaries_.size(); ++i) { - std::shared_ptr indices; - RETURN_NOT_OK(builders[i].FinishInternal(&indices)); - - ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices)); - columns[i] = column.make_array(); - } - - return StructArray::Make(std::move(columns), std::move(fields)); - } - - private: - Status AddOne(Datum column, std::shared_ptr* fused_indices) { - ArrayData* encoded; - if (column.type()->id() != Type::DICTIONARY) { - ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column)); - } - encoded = column.mutable_array(); - - auto indices = - std::make_shared(encoded->length, std::move(encoded->buffers[1])); - - dictionaries_.push_back(MakeArray(std::move(encoded->dictionary))); - auto dictionary_size = static_cast(dictionaries_.back()->length()); - - if (*fused_indices == nullptr) { - *fused_indices = std::move(indices); - size_ = dictionary_size; - return Status::OK(); - } - - // It's useful to think about the case where each of dictionaries_ has size 10. - // In this case the decimal digit in the ones place is the code in dictionaries_[0], - // the tens place corresponds to dictionaries_[1], etc. - // The incumbent indices must be shifted to the hundreds place so as not to collide. 
- ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices, - compute::Multiply(indices, MakeScalar(size_))); - - ARROW_ASSIGN_OR_RAISE(new_fused_indices, - compute::Add(new_fused_indices, *fused_indices)); - - *fused_indices = checked_pointer_cast(new_fused_indices.make_array()); - - // XXX should probably cap this at 2**15 or so - ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_)); - return Status::OK(); - } - - // expand a fused code into component dict codes, order is in order of addition - void Expand(int32_t fused_code, int32_t* codes) { - for (size_t i = 0; i < dictionaries_.size(); ++i) { - auto dictionary_size = static_cast(dictionaries_[i]->length()); - codes[i] = fused_code % dictionary_size; - fused_code /= dictionary_size; - } - } - - int32_t size_; - ArrayVector dictionaries_; -}; - -Result> MakeGroupings(const StructArray& by) { - if (by.num_fields() == 0) { - return Status::NotImplemented("Grouping with no criteria"); - } - - ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields())); - - ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices)); - ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices)); - fused.indices = checked_pointer_cast(sorted.make_array()); - - ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values, - compute::ValueCounts(fused.indices)); - fused.indices.reset(); - - auto unique_fused_indices = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("values")); - ARROW_ASSIGN_OR_RAISE( - auto unique_rows, - fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields())); - - auto counts = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("counts")); - ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(counts))); - - ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices, - ListArray::FromArrays(*offsets, *sort_indices)); - - return StructArray::Make( - ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)}, - std::vector{"values", "groupings"}); -} - -Result> ApplyGroupings(const ListArray& groupings, - const Array& array) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(array, groupings.data()->child_data[0])); - - return std::make_shared(list(array.type()), groupings.length(), - groupings.value_offsets(), sorted.make_array()); -} - -Result ApplyGroupings(const ListArray& groupings, - const std::shared_ptr& batch) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(batch, groupings.data()->child_data[0])); - - const auto& sorted_batch = *sorted.record_batch(); - - RecordBatchVector out(static_cast(groupings.length())); - for (size_t i = 0; i < out.size(); ++i) { - out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); - } - - return out; -} - -} // namespace dataset -} // namespace arrow +// FIXME remove diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index 750c67c048d..6dd39e25e1e 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -15,603 +15,5 @@ // specific language governing permissions and limitations // under the License. -// This API is EXPERIMENTAL. 
+// FIXME remove -#pragma once - -#include -#include -#include -#include -#include - -#include "arrow/compute/api_scalar.h" -#include "arrow/compute/cast.h" -#include "arrow/dataset/type_fwd.h" -#include "arrow/dataset/visibility.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/variant.h" - -namespace arrow { -namespace dataset { - -using compute::CastOptions; -using compute::CompareOperator; - -struct ExpressionType { - enum type { - /// a reference to a column within a record batch, will evaluate to an array - FIELD, - - /// a literal singular value encapsulated in a Scalar - SCALAR, - - /// a literal Array - // TODO(bkietz) ARRAY, - - /// an inversion of another expression - NOT, - - /// cast an expression to a given DataType - CAST, - - /// a conjunction of multiple expressions (true if all operands are true) - AND, - - /// a disjunction of multiple expressions (true if any operand is true) - OR, - - /// a comparison of two other expressions - COMPARISON, - - /// replace nulls with other expressions - /// currently only boolean expressions may be coalesced - // TODO(bkietz) COALESCE, - - /// extract validity as a boolean expression - IS_VALID, - - /// check each element for membership in a set - IN, - - /// custom user defined expression - CUSTOM, - }; -}; - -class InExpression; -class CastExpression; -class IsValidExpression; - -/// Represents an expression tree -class ARROW_DS_EXPORT Expression { - public: - explicit Expression(ExpressionType::type type) : type_(type) {} - - virtual ~Expression() = default; - - /// Returns true iff the expressions are identical; does not check for equivalence. - /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to - /// (false). - virtual bool Equals(const Expression& other) const = 0; - - bool Equals(const std::shared_ptr& other) const; - - /// Overload for the common case of checking for equality to a specific scalar. - template ()))> - bool Equals(T&& t) const; - - /// If true, this Expression is a ScalarExpression wrapping a null scalar. - bool IsNull() const; - - /// Validate this expression for execution against a schema. This will check that all - /// reference fields are present (fields not in the schema will be replaced with null) - /// and all subexpressions are executable. Returns the type to which this expression - /// will evaluate. - virtual Result> Validate(const Schema& schema) const = 0; - - /// \brief Simplify to an equivalent Expression given assumed constraints on input. - /// This can be used to do less filtering work using predicate push down. - /// - /// Both expressions must pass validation against a schema before Assume may be used. - /// - /// Two expressions can be considered equivalent for a given subset of possible inputs - /// if they yield identical results. Formally, if given.Evaluate(input).Equals(input) - /// then Assume guarantees that: - /// expr.Assume(given).Evaluate(input).Equals(expr.Evaluate(input)) - /// - /// For example if we are given that all inputs will - /// satisfy ("a"_ == 1) then the expression ("a"_ > 0 and "b"_ > 0) is equivalent to - /// ("b"_ > 0). It is impossible that the comparison ("a"_ > 0) will evaluate false - /// given ("a"_ == 1), so both expressions will yield identical results. 
Thus we can - /// write: - /// ("a"_ > 0 and "b"_ > 0).Assume("a"_ == 1).Equals("b"_ > 0) - /// - /// filter.Assume(partition) is trivial if filter and partition are disjoint or if - /// partition is a subset of filter. FIXME(bkietz) write this better - /// - If the two are disjoint, then (false) may be substituted for filter. - /// - If partition is a subset of filter then (true) may be substituted for filter. - /// - /// filter.Assume(partition) is straightforward if both filter and partition are simple - /// comparisons. - /// - filter may be a superset of partition, in which case the filter is - /// satisfied by all inputs: - /// ("a"_ > 0).Assume("a"_ == 1).Equals(true) - /// - filter may be disjoint with partition, in which case there are no inputs which - /// satisfy filter: - /// ("a"_ < 0).Assume("a"_ == 1).Equals(false) - /// - If neither of these is the case, partition provides no information which can - /// simplify filter: - /// ("a"_ == 1).Assume("a"_ > 0).Equals("a"_ == 1) - /// ("a"_ == 1).Assume("b"_ == 1).Equals("a"_ == 1) - /// - /// If filter is compound, Assume can be distributed across the boolean operator. To - /// prove this is valid, we again demonstrate that the simplified expression will yield - /// identical results. For conjunction of filters lhs and rhs: - /// (lhs.Assume(p) and rhs.Assume(p)).Evaluate(input) - /// == Intersection(lhs.Assume(p).Evaluate(input), rhs.Assume(p).Evaluate(input)) - /// == Intersection(lhs.Evaluate(input), rhs.Evaluate(input)) - /// == (lhs and rhs).Evaluate(input) - /// - The proof for disjunction is symmetric; just replace Intersection with Union. Thus - /// we can write: - /// (lhs and rhs).Assume(p).Equals(lhs.Assume(p) and rhs.Assume(p)) - /// (lhs or rhs).Assume(p).Equals(lhs.Assume(p) or rhs.Assume(p)) - /// - For negation: - /// (not e.Assume(p)).Evaluate(input) - /// == Difference(input, e.Assume(p).Evaluate(input)) - /// == Difference(input, e.Evaluate(input)) - /// == (not e).Evaluate(input) - /// - Thus we can write: - /// (not e).Assume(p).Equals(not e.Assume(p)) - /// - /// If the partition expression is a conjunction then each of its subexpressions is - /// true for all input and can be used independently: - /// filter.Assume(lhs).Assume(rhs).Evaluate(input) - /// == filter.Assume(lhs).Evaluate(input) - /// == filter.Evaluate(input) - /// - Thus we can write: - /// filter.Assume(lhs and rhs).Equals(filter.Assume(lhs).Assume(rhs)) - /// - /// FIXME(bkietz) disjunction proof - /// filter.Assume(lhs or rhs).Equals(filter.Assume(lhs) and filter.Assume(rhs)) - /// - This may not result in a simpler expression so it is only used when - /// filter.Assume(lhs).Equals(filter.Assume(rhs)) - /// - /// If the partition expression is a negation then we can use the above relations by - /// replacing comparisons with their complements and using the properties: - /// (not (a and b)).Equals(not a or not b) - /// (not (a or b)).Equals(not a and not b) - virtual std::shared_ptr Assume(const Expression& given) const; - - std::shared_ptr Assume(const std::shared_ptr& given) const { - return Assume(*given); - } - - /// Indicates if the expression is satisfiable. - /// - /// This is a shortcut to check if the expression is neither null nor false. - bool IsSatisfiable() const { return !IsNull() && !Equals(false); } - - /// Indicates if the expression is satisfiable given an other expression. - /// - /// This behaves like IsSatisfiable, but it simplifies the current expression - /// with the given `other` information. 
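// [Editorial sketch; not part of the patch.] A concrete use of Assume/IsSatisfiableWith
// for partition pruning, written with the "_" field literal used in the tests (the
// values here are hypothetical):
//
//   auto filter = ("a"_ > 2 and "b"_ > 0).Copy();  // the query's filter
//   auto partition = ("a"_ == 1).Copy();           // known true for one fragment
//
//   filter->Assume(*partition)->Equals(false);     // disjoint: simplifies to (false)
//   filter->IsSatisfiableWith(partition);          // false, so skip the fragment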
-  bool IsSatisfiableWith(const Expression& other) const {
-    return Assume(other)->IsSatisfiable();
-  }
-
-  bool IsSatisfiableWith(const std::shared_ptr<Expression>& other) const {
-    return Assume(other)->IsSatisfiable();
-  }
-
-  /// returns a debug string representing this expression
-  virtual std::string ToString() const = 0;
-
-  /// serialize/deserialize an Expression.
-  Result<std::shared_ptr<Buffer>> Serialize() const;
-  static Result<std::shared_ptr<Expression>> Deserialize(const Buffer&);
-
-  /// \brief Return the expression's type identifier
-  ExpressionType::type type() const { return type_; }
-
-  /// Copy this expression into a shared pointer.
-  virtual std::shared_ptr<Expression> Copy() const = 0;
-
-  InExpression In(std::shared_ptr<Array> set) const;
-
-  IsValidExpression IsValid() const;
-
-  CastExpression CastTo(std::shared_ptr<DataType> type,
-                        CastOptions options = CastOptions()) const;
-
-  CastExpression CastLike(const Expression& expr,
-                          CastOptions options = CastOptions()) const;
-
-  CastExpression CastLike(std::shared_ptr<Expression> expr,
-                          CastOptions options = CastOptions()) const;
-
- protected:
-  ExpressionType::type type_;
-};
-
-/// Helper class which implements Copy and forwards construction
-template <typename Base, typename Derived, ExpressionType::type E>
-class ExpressionImpl : public Base {
- public:
-  static constexpr ExpressionType::type expression_type = E;
-
-  template <typename A0, typename... A>
-  explicit ExpressionImpl(A0&& arg0, A&&... args)
-      : Base(expression_type, std::forward<A0>(arg0), std::forward<A>(args)...) {}
-
-  std::shared_ptr<Expression> Copy() const override {
-    return std::make_shared<Derived>(internal::checked_cast<const Derived&>(*this));
-  }
-};
-
-/// Base class for an expression with exactly one operand
-class ARROW_DS_EXPORT UnaryExpression : public Expression {
- public:
-  const std::shared_ptr<Expression>& operand() const { return operand_; }
-
-  bool Equals(const Expression& other) const override;
-
- protected:
-  UnaryExpression(ExpressionType::type type, std::shared_ptr<Expression> operand)
-      : Expression(type), operand_(std::move(operand)) {}
-
-  std::shared_ptr<Expression> operand_;
-};
-
-/// Base class for an expression with exactly two operands
-class ARROW_DS_EXPORT BinaryExpression : public Expression {
- public:
-  const std::shared_ptr<Expression>& left_operand() const { return left_operand_; }
-
-  const std::shared_ptr<Expression>& right_operand() const { return right_operand_; }
-
-  bool Equals(const Expression& other) const override;
-
- protected:
-  BinaryExpression(ExpressionType::type type, std::shared_ptr<Expression> left_operand,
-                   std::shared_ptr<Expression> right_operand)
-      : Expression(type),
-        left_operand_(std::move(left_operand)),
-        right_operand_(std::move(right_operand)) {}
-
-  std::shared_ptr<Expression> left_operand_, right_operand_;
-};
-
-class ARROW_DS_EXPORT ComparisonExpression final
-    : public ExpressionImpl<BinaryExpression, ComparisonExpression,
-                            ExpressionType::COMPARISON> {
- public:
-  ComparisonExpression(CompareOperator op, std::shared_ptr<Expression> left_operand,
-                       std::shared_ptr<Expression> right_operand)
-      : ExpressionImpl(std::move(left_operand), std::move(right_operand)), op_(op) {}
-
-  std::string ToString() const override;
-
-  bool Equals(const Expression& other) const override;
-
-  std::shared_ptr<Expression> Assume(const Expression& given) const override;
-
-  CompareOperator op() const { return op_; }
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-
- private:
-  std::shared_ptr<Expression> AssumeGivenComparison(
-      const ComparisonExpression& given) const;
-
-  CompareOperator op_;
-};
-
-class ARROW_DS_EXPORT AndExpression final
-    : public ExpressionImpl<BinaryExpression, AndExpression, ExpressionType::AND> {
- public:
-  using ExpressionImpl::ExpressionImpl;
-
-  std::string ToString() const override;
-
-  std::shared_ptr<Expression> Assume(const Expression& given) const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-};
-
-class ARROW_DS_EXPORT OrExpression final
-    : public ExpressionImpl<BinaryExpression, OrExpression, ExpressionType::OR> {
- public:
-  using ExpressionImpl::ExpressionImpl;
-
-  std::string ToString() const override;
-
-  std::shared_ptr<Expression> Assume(const Expression& given) const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-};
-
-class ARROW_DS_EXPORT NotExpression final
-    : public ExpressionImpl<UnaryExpression, NotExpression, ExpressionType::NOT> {
- public:
-  using ExpressionImpl::ExpressionImpl;
-
-  std::string ToString() const override;
-
-  std::shared_ptr<Expression> Assume(const Expression& given) const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-};
-
-class ARROW_DS_EXPORT IsValidExpression final
-    : public ExpressionImpl<UnaryExpression, IsValidExpression,
-                            ExpressionType::IS_VALID> {
- public:
-  using ExpressionImpl::ExpressionImpl;
-
-  std::string ToString() const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-
-  std::shared_ptr<Expression> Assume(const Expression& given) const override;
-};
-
-class ARROW_DS_EXPORT InExpression final
-    : public ExpressionImpl<UnaryExpression, InExpression, ExpressionType::IN> {
- public:
-  InExpression(std::shared_ptr<Expression> operand, std::shared_ptr<Array> set)
-      : ExpressionImpl(std::move(operand)), set_(std::move(set)) {}
-
-  std::string ToString() const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-
-  std::shared_ptr<Expression> Assume(const Expression& given) const override;
-
-  /// The set against which the operand will be compared
-  const std::shared_ptr<Array>& set() const { return set_; }
-
- private:
-  std::shared_ptr<Array> set_;
-};
-
-/// Explicitly cast an expression to a different type
-class ARROW_DS_EXPORT CastExpression final
-    : public ExpressionImpl<UnaryExpression, CastExpression, ExpressionType::CAST> {
- public:
-  CastExpression(std::shared_ptr<Expression> operand, std::shared_ptr<DataType> to,
-                 CastOptions options)
-      : ExpressionImpl(std::move(operand)),
-        to_(std::move(to)),
-        options_(std::move(options)) {}
-
-  /// The operand will be cast to whatever type `like` would evaluate to, given the same
-  /// schema.
-  CastExpression(std::shared_ptr<Expression> operand, std::shared_ptr<Expression> like,
-                 CastOptions options)
-      : ExpressionImpl(std::move(operand)),
-        to_(std::move(like)),
-        options_(std::move(options)) {}
-
-  std::string ToString() const override;
-
-  std::shared_ptr<Expression> Assume(const Expression& given) const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-
-  const CastOptions& options() const { return options_; }
-
-  /// Return the target type of this CastTo expression, or nullptr if this is a
-  /// CastLike expression.
-  const std::shared_ptr<DataType>& to_type() const;
-
-  /// Return the target expression of this CastLike expression, or nullptr if
-  /// this is a CastTo expression.
-  const std::shared_ptr<Expression>& like_expr() const;
-
- private:
-  util::Variant<std::shared_ptr<DataType>, std::shared_ptr<Expression>> to_;
-  CastOptions options_;
-};
-
-/// Represents a scalar value; thin wrapper around arrow::Scalar
-class ARROW_DS_EXPORT ScalarExpression final : public Expression {
- public:
-  explicit ScalarExpression(const std::shared_ptr<Scalar>& value)
-      : Expression(ExpressionType::SCALAR), value_(std::move(value)) {}
-
-  const std::shared_ptr<Scalar>& value() const { return value_; }
-
-  std::string ToString() const override;
-
-  bool Equals(const Expression& other) const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-
-  std::shared_ptr<Expression> Copy() const override;
-
- private:
-  std::shared_ptr<Scalar> value_;
-};
-
-/// Represents a reference to a field. Stores only the field's name (type and other
-/// information is known only when a Schema is provided)
-class ARROW_DS_EXPORT FieldExpression final : public Expression {
- public:
-  explicit FieldExpression(std::string name)
-      : Expression(ExpressionType::FIELD), name_(std::move(name)) {}
-
-  std::string name() const { return name_; }
-
-  std::string ToString() const override;
-
-  bool Equals(const Expression& other) const override;
-
-  Result<std::shared_ptr<DataType>> Validate(const Schema& schema) const override;
-
-  std::shared_ptr<Expression> Copy() const override;
-
- private:
-  std::string name_;
-};
-
-class ARROW_DS_EXPORT CustomExpression : public Expression {
- protected:
-  CustomExpression() : Expression(ExpressionType::CUSTOM) {}
-};
-
-inline std::shared_ptr<ScalarExpression> scalar(std::shared_ptr<Scalar> value) {
-  return std::make_shared<ScalarExpression>(std::move(value));
-}
-
-template <typename T>
-auto scalar(T&& value) -> decltype(scalar(MakeScalar(std::forward<T>(value)))) {
-  return scalar(MakeScalar(std::forward<T>(value)));
-}
-
-template <typename T, typename Enable>
-bool Expression::Equals(T&& t) const {
-  if (type_ != ExpressionType::SCALAR) {
-    return false;
-  }
-  auto s = MakeScalar(std::forward<T>(t));
-  return internal::checked_cast<const ScalarExpression&>(*this).value()->Equals(*s);
-}
-
-template <typename Visitor>
-auto VisitExpression(const Expression& expr, Visitor&& visitor)
-    -> decltype(visitor(expr)) {
-  switch (expr.type()) {
-    case ExpressionType::FIELD:
-      return visitor(internal::checked_cast<const FieldExpression&>(expr));
-
-    case ExpressionType::SCALAR:
-      return visitor(internal::checked_cast<const ScalarExpression&>(expr));
-
-    case ExpressionType::IN:
-      return visitor(internal::checked_cast<const InExpression&>(expr));
-
-    case ExpressionType::IS_VALID:
-      return visitor(internal::checked_cast<const IsValidExpression&>(expr));
-
-    case ExpressionType::AND:
-      return visitor(internal::checked_cast<const AndExpression&>(expr));
-
-    case ExpressionType::OR:
-      return visitor(internal::checked_cast<const OrExpression&>(expr));
-
-    case ExpressionType::NOT:
-      return visitor(internal::checked_cast<const NotExpression&>(expr));
-
-    case ExpressionType::CAST:
-      return visitor(internal::checked_cast<const CastExpression&>(expr));
-
-    case ExpressionType::COMPARISON:
-      return visitor(internal::checked_cast<const ComparisonExpression&>(expr));
-
-    case ExpressionType::CUSTOM:
-    default:
-      break;
-  }
-  return visitor(internal::checked_cast<const CustomExpression&>(expr));
-}
-
-/// \brief Visit each subexpression of an arbitrarily nested conjunction.
-///
-/// | given                          | visit                                       |
-/// |--------------------------------|---------------------------------------------|
-/// | a and b                        | visit(a), visit(b)                          |
-/// | c                              | visit(c)                                    |
-/// | (a and b) and ((c or d) and e) | visit(a), visit(b), visit(c or d), visit(e) |
-ARROW_DS_EXPORT Status VisitConjunctionMembers(
-    const Expression& expr, const std::function<Status(const Expression&)>& visitor);
-
-/// \brief Insert CastExpressions where necessary to make a valid expression.
-ARROW_DS_EXPORT Result<std::shared_ptr<Expression>> InsertImplicitCasts(
-    const Expression& expr, const Schema& schema);
-
-/// \brief Returns field names referenced in the expression.
-ARROW_DS_EXPORT std::vector<std::string> FieldsInExpression(const Expression& expr);
-
-ARROW_DS_EXPORT std::vector<std::string> FieldsInExpression(
-    const std::shared_ptr<Expression>& expr);
-
-/// Interface for evaluation of expressions against record batches.
-class ARROW_DS_EXPORT ExpressionEvaluator {
- public:
-  virtual ~ExpressionEvaluator() = default;
-
-  /// Evaluate expr against each row of a RecordBatch.
-  /// Returned Datum will be of either SCALAR or ARRAY kind.
-  /// A return value of ARRAY kind will have length == batch.num_rows()
-  /// A return value of SCALAR kind is equivalent to an array of the same type whose
-  /// slots contain a single repeated value.
-  ///
-  /// expr must be validated against the schema of batch before calling this method.
-  virtual Result<Datum> Evaluate(const Expression& expr, const RecordBatch& batch,
-                                 MemoryPool* pool) const = 0;
-
-  Result<Datum> Evaluate(const Expression& expr, const RecordBatch& batch) const {
-    return Evaluate(expr, batch, default_memory_pool());
-  }
-
-  virtual Result<std::shared_ptr<RecordBatch>> Filter(
-      const Datum& selection, const std::shared_ptr<RecordBatch>& batch,
-      MemoryPool* pool) const = 0;
-
-  Result<std::shared_ptr<RecordBatch>> Filter(
-      const Datum& selection, const std::shared_ptr<RecordBatch>& batch) const {
-    return Filter(selection, batch, default_memory_pool());
-  }
-
-  /// \brief Wrap an iterator of record batches with a filter expression. The resulting
-  /// iterator will yield record batches filtered by the given expression.
-  ///
-  /// \note The ExpressionEvaluator must outlive the returned iterator.
-  RecordBatchIterator FilterBatches(RecordBatchIterator unfiltered,
-                                    std::shared_ptr<Expression> filter,
-                                    MemoryPool* pool = default_memory_pool());
-
-  /// construct an Evaluator which evaluates all expressions to null and does no
-  /// filtering
-  static std::shared_ptr<ExpressionEvaluator> Null();
-};
-
-/// construct an Evaluator which uses compute kernels to evaluate expressions and
-/// filter record batches in depth first order
-class ARROW_DS_EXPORT TreeEvaluator : public ExpressionEvaluator {
- public:
-  Result<Datum> Evaluate(const Expression& expr, const RecordBatch& batch,
-                         MemoryPool* pool) const override;
-
-  Result<std::shared_ptr<RecordBatch>> Filter(const Datum& selection,
-                                              const std::shared_ptr<RecordBatch>& batch,
-                                              MemoryPool* pool) const override;
-
- protected:
-  struct Impl;
-};
-
-/// \brief Assemble lists of indices of identical rows.
-///
-/// \param[in] by A StructArray whose columns will be used as grouping criteria.
-/// \return A StructArray mapping unique rows (in field "values", represented as a
-///         StructArray with the same fields as `by`) to lists of indices where
-///         that row appears (in field "groupings").
-ARROW_DS_EXPORT
-Result<std::shared_ptr<StructArray>> MakeGroupings(const StructArray& by);
-
-/// \brief Produce slices of an Array which correspond to the provided groupings.
-ARROW_DS_EXPORT
-Result<std::shared_ptr<ListArray>> ApplyGroupings(const ListArray& groupings,
-                                                  const Array& array);
-ARROW_DS_EXPORT
-Result<RecordBatchVector> ApplyGroupings(const ListArray& groupings,
-                                         const std::shared_ptr<RecordBatch>& batch);
-
-}  // namespace dataset
-}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc
index fce2d9cfcd9..ef32b8a7bb6 100644
--- a/cpp/src/arrow/dataset/filter_test.cc
+++ b/cpp/src/arrow/dataset/filter_test.cc
@@ -15,487 +15,4 @@
 // specific language governing permissions and limitations
 // under the License.
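// [Editorial sketch; not part of the patch.] How the grouping helpers declared above
// compose: group a RecordBatch on one column, then slice it into per-group batches.
// The column name "key" is hypothetical and assumed present in the batch's schema.

#include "arrow/array.h"
#include "arrow/dataset/filter.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/util/checked_cast.h"

namespace arrow {
namespace dataset {

Result<RecordBatchVector> GroupByKeyExample(const std::shared_ptr<RecordBatch>& batch) {
  // MakeGroupings expects the grouping criteria wrapped in a StructArray.
  ARROW_ASSIGN_OR_RAISE(auto by,
                        StructArray::Make(ArrayVector{batch->GetColumnByName("key")},
                                          std::vector<std::string>{"key"}));

  // "values" holds one row per distinct key; "groupings" lists that key's row indices.
  ARROW_ASSIGN_OR_RAISE(auto grouped, MakeGroupings(*by));
  auto groupings =
      internal::checked_pointer_cast<ListArray>(grouped->GetFieldByName("groupings"));

  // Take() the batch into grouped order and slice one RecordBatch per distinct key.
  return ApplyGroupings(*groupings, batch);
}

}  // namespace dataset
}  // namespace arrow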
-#include "arrow/dataset/filter.h" - -#include -#include -#include -#include - -#include -#include - -#include "arrow/compute/api.h" -#include "arrow/dataset/test_util.h" -#include "arrow/record_batch.h" -#include "arrow/status.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/logging.h" - -namespace arrow { -namespace dataset { - -// clang-format off -using string_literals::operator"" _; -// clang-format on - -using internal::checked_cast; -using internal::checked_pointer_cast; - -// A frozen shared_ptr with behavior expected by GTest -struct TestExpression : util::EqualityComparable, - util::ToStringOstreamable { - // NOLINTNEXTLINE runtime/explicit - TestExpression(std::shared_ptr e) : expression(std::move(e)) {} - - // NOLINTNEXTLINE runtime/explicit - TestExpression(const Expression& e) : expression(e.Copy()) {} - - // NOLINTNEXTLINE runtime/explicit - TestExpression(const Expression2& e) : expression(e) {} - - std::shared_ptr expression; - - using util::EqualityComparable::operator==; - bool Equals(const TestExpression& other) const { - return expression->Equals(other.expression); - } - - std::string ToString() const { return expression->ToString(); } - - friend bool operator==(const std::shared_ptr& lhs, - const TestExpression& rhs) { - return TestExpression(lhs) == rhs; - } - - friend void PrintTo(const TestExpression& expr, std::ostream* os) { - *os << expr.ToString(); - } -}; - -using E = TestExpression; - -class ExpressionsTest : public ::testing::Test { - public: - void AssertSimplifiesTo(E expr, E given, E expected) { - ASSERT_OK_AND_ASSIGN(auto expr_type, expr.expression->Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto given_type, given.expression->Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto expected_type, expected.expression->Validate(*schema_)); - - EXPECT_TRUE(expr_type->Equals(expected_type)); - EXPECT_TRUE(given_type->Equals(boolean())); - - auto simplified = expr.expression->Assume(given.expression); - ASSERT_EQ(simplified, expected) - << " simplification of: " << expr.ToString() << std::endl - << " given: " << given.ToString() << std::endl; - } - - std::shared_ptr ns = timestamp(TimeUnit::NANO); - std::shared_ptr schema_ = - schema({field("a", int32()), field("b", int32()), field("f", float64()), - field("s", utf8()), field("ts", ns), - field("dict_b", dictionary(int32(), int32()))}); - std::shared_ptr always = scalar(true); - std::shared_ptr never = scalar(false); -}; - -TEST_F(ExpressionsTest, Equality) { - ASSERT_EQ(E{"a"_}, E{"a"_}); - ASSERT_NE(E{"a"_}, E{"b"_}); - - ASSERT_EQ(E{"b"_ == 3}, E{"b"_ == 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_ < 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_}); - - // ordering matters - ASSERT_EQ(E{"b"_ == 3}, E{"b"_ == 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_ < 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_}); - - ASSERT_EQ(E("b"_ > 2 and "b"_ < 3), E("b"_ > 2 and "b"_ < 3)); - ASSERT_NE(E("b"_ > 2 and "b"_ < 3), E("b"_ < 3 and "b"_ > 2)); -} - -TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); - - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 6, *never); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ == 3, *always); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); - - AssertSimplifiesTo("f"_ > 0.5 and "f"_ 
< 1.5, not_("f"_ < 0.0 or "f"_ > 1.0), - "f"_ > 0.5); - - AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); - - AssertSimplifiesTo("a"_ == 3 or "b"_ == 4, "a"_ == 0, "b"_ == 4); - - auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 3, "a"_ == 3); - AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 0, *never); - - AssertSimplifiesTo("a"_ == 0 or not_("b"_.IsValid()), "b"_ == 3, "a"_ == 0); -} - -TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { - AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); - AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); - - auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 3, *always); - AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 5, *never); - - auto dict_set_123 = - DictArrayFromJSON(dictionary(int32(), int32()), R"([1,2,0])", R"([1,2,3])"); - ASSERT_OK_AND_ASSIGN(auto b_dict, dict_set_123->GetScalar(0)); - AssertSimplifiesTo("b_dict"_.In(dict_set_123), "a"_ == 3 and "b_dict"_ == b_dict, - *always); -} - -class FilterTest : public ::testing::Test { - public: - FilterTest() { evaluator_ = std::make_shared(); } - - Result DoFilter(const Expression& expr, - std::vector> fields, - std::string batch_json, - std::shared_ptr* expected_mask = nullptr) { - // expected filter result is in the "in" field - fields.push_back(field("in", boolean())); - auto batch = RecordBatchFromJSON(schema(fields), batch_json); - if (expected_mask) { - *expected_mask = checked_pointer_cast(batch->GetColumnByName("in")); - } - - ARROW_ASSIGN_OR_RAISE(auto expr_type, expr.Validate(*batch->schema())); - EXPECT_TRUE(expr_type->Equals(boolean())); - - return evaluator_->Evaluate(expr, *batch); - } - - void AssertFilter(const std::shared_ptr& expr, - std::vector> fields, - const std::string& batch_json) { - AssertFilter(*expr, std::move(fields), batch_json); - } - - void AssertFilter(const Expression& expr, std::vector> fields, - const std::string& batch_json) { - std::shared_ptr expected_mask; - - ASSERT_OK_AND_ASSIGN(Datum mask, DoFilter(expr, std::move(fields), - std::move(batch_json), &expected_mask)); - ASSERT_TRUE(mask.type()->Equals(null()) || mask.type()->Equals(boolean())); - - if (mask.is_array()) { - AssertArraysEqual(*expected_mask, *mask.make_array(), /*verbose=*/true); - return; - } - - ASSERT_TRUE(mask.is_scalar()); - auto mask_scalar = mask.scalar(); - if (!mask_scalar->is_valid) { - ASSERT_EQ(expected_mask->null_count(), expected_mask->length()); - return; - } - - TypedBufferBuilder builder; - ASSERT_OK(builder.Append(expected_mask->length(), - checked_cast(*mask_scalar).value)); - - std::shared_ptr values; - ASSERT_OK(builder.Finish(&values)); - - ASSERT_ARRAYS_EQUAL(*expected_mask, BooleanArray(expected_mask->length(), values)); - } - - std::shared_ptr evaluator_; -}; - -TEST_F(FilterTest, Trivial) { - // Note that we should expect these trivial expressions will never be evaluated against - // record batches; since they're trivial, evaluation is not necessary. 
- AssertFilter(scalar(true), {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": 1}, - {"a": 0, "b": 1.0, "in": 1} - ])"); - - AssertFilter(scalar(false), {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 0, "b": 1.0, "in": 0} - ])"); - - AssertFilter(*scalar(std::shared_ptr(new BooleanScalar)), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": null}, - {"a": 0, "b": 0.3, "in": null}, - {"a": 1, "b": 0.2, "in": null}, - {"a": 2, "b": -0.1, "in": null}, - {"a": 0, "b": 0.1, "in": null}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": null} - ])"); -} - -TEST_F(FilterTest, Basics) { - AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0, - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": 0} - ])"); - - AssertFilter("a"_ != 0 and "b"_ > 0.1, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 0, "b": 1.0, "in": 0} - ])"); -} - -TEST_F(FilterTest, InExpression) { - auto hello_world = ArrayFromJSON(utf8(), R"(["hello", "world"])"); - - AssertFilter("s"_.In(hello_world), {field("s", utf8())}, R"([ - {"s": "hello", "in": 1}, - {"s": "world", "in": 1}, - {"s": "", "in": 0}, - {"s": null, "in": null}, - {"s": "foo", "in": 0}, - {"s": "hello", "in": 1}, - {"s": "bar", "in": 0} - ])"); -} - -TEST_F(FilterTest, IsValidExpression) { - AssertFilter("s"_.IsValid(), {field("s", utf8())}, R"([ - {"s": "hello", "in": 1}, - {"s": null, "in": 0}, - {"s": "", "in": 1}, - {"s": null, "in": 0}, - {"s": "foo", "in": 1}, - {"s": "hello", "in": 1}, - {"s": null, "in": 0} - ])"); -} - -TEST_F(FilterTest, Cast) { - ASSERT_RAISES(TypeError, ("a"_ == double(1.0)).Validate(Schema({field("a", int32())}))); - - AssertFilter("a"_.CastTo(float64()) == double(1.0), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 1, "b": 1.0, "in": 1} - ])"); - - AssertFilter("a"_ == scalar(0.6)->CastLike("a"_), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": 1}, - {"a": 1, "b": 1.0, "in": 0} - ])"); - - AssertFilter("a"_.CastLike("b"_) == "b"_, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.0, "in": 1}, - {"a": 1, "b": 1.0, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 2, "b": null, "in": null}, - {"a": 1, "b": 1.0, "in": 1} - ])"); -} - -TEST_F(ExpressionsTest, ImplicitCast) { - ASSERT_OK_AND_ASSIGN(auto filter, InsertImplicitCasts("a"_ 
== 0.0, *schema_)); - ASSERT_EQ(E{filter}, E{"a"_ == 0}); - - auto ns = timestamp(TimeUnit::NANO); - auto date = "1990-10-23 10:23:33"; - ASSERT_OK_AND_ASSIGN(filter, InsertImplicitCasts("ts"_ == date, *schema_)); - ASSERT_EQ(E{filter}, E{"ts"_ == *MakeScalar(date)->CastTo(ns)}); - - ASSERT_OK_AND_ASSIGN(filter, - InsertImplicitCasts("ts"_ == date and "b"_ == "3", *schema_)); - ASSERT_EQ(E{filter}, E{"ts"_ == *MakeScalar(date)->CastTo(ns) and "b"_ == 3}); - AssertSimplifiesTo(*filter, "b"_ == 2, *never); - AssertSimplifiesTo(*filter, "b"_ == 3, "ts"_ == *MakeScalar(date)->CastTo(ns)); - - // set is double but "a"_ is int32 - auto set_double = ArrayFromJSON(float64(), R"([1, 2, 3])"); - ASSERT_OK_AND_ASSIGN(filter, InsertImplicitCasts("a"_.In(set_double), *schema_)); - auto set_int32 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - ASSERT_EQ(E{filter}, E{"a"_.In(set_int32)}); - - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - testing::HasSubstr("Field named 'nope' not found"), - InsertImplicitCasts("nope"_ == 0.0, *schema_)); -} - -TEST_F(ExpressionsTest, ImplicitCastToDict) { - auto dict_type = dictionary(int32(), float64()); - ASSERT_OK_AND_ASSIGN(auto filter, - InsertImplicitCasts("a"_ == 1.5, Schema({field("a", dict_type)}))); - - auto encoded_scalar = std::make_shared( - DictionaryScalar::ValueType{MakeScalar(0), - ArrayFromJSON(float64(), "[1.5]")}, - dict_type); - - ASSERT_EQ(E{filter}, E{"a"_ == encoded_scalar}); - - for (int32_t i = 0; i < 5; ++i) { - auto partition_scalar = std::make_shared( - DictionaryScalar::ValueType{ - MakeScalar(i), ArrayFromJSON(float64(), "[0.0, 0.5, 1.0, 1.5, 2.0]")}, - dict_type); - ASSERT_EQ(E{filter->Assume("a"_ == partition_scalar)}, E{scalar(i == 3)}); - } - - auto set_f64 = ArrayFromJSON(float64(), "[0.0, 0.5, 1.0, 1.5, 2.0]"); - ASSERT_OK_AND_ASSIGN( - filter, InsertImplicitCasts("a"_.In(set_f64), Schema({field("a", dict_type)}))); -} - -TEST_F(FilterTest, ImplicitCast) { - ASSERT_OK_AND_ASSIGN(auto filter, - InsertImplicitCasts("a"_ >= "1", Schema({field("a", int32())}))); - - AssertFilter(*filter, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.0, "in": 0}, - {"a": 1, "b": 1.0, "in": 1}, - {"a": 2, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 2, "b": null, "in": 1}, - {"a": 1, "b": 1.0, "in": 1} - ])"); -} - -TEST_F(FilterTest, ConditionOnAbsentColumn) { - AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0 and "absent"_ == 0, - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": false}, - {"a": 0, "b": 0.3, "in": null}, - {"a": 1, "b": 0.2, "in": false}, - {"a": 2, "b": -0.1, "in": false}, - {"a": 0, "b": 0.1, "in": null}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": false} - ])"); -} - -TEST_F(FilterTest, KleeneTruthTables) { - AssertFilter("a"_ and "b"_, {field("a", boolean()), field("b", boolean())}, R"([ - {"a":null, "b":null, "in":null}, - {"a":null, "b":true, "in":null}, - {"a":null, "b":false, "in":false}, - - {"a":true, "b":true, "in":true}, - {"a":true, "b":false, "in":false}, - - {"a":false, "b":false, "in":false} - ])"); - - AssertFilter("a"_ or "b"_, {field("a", boolean()), field("b", boolean())}, R"([ - {"a":null, "b":null, "in":null}, - {"a":null, "b":true, "in":true}, - {"a":null, "b":false, "in":null}, - - {"a":true, "b":true, "in":true}, - {"a":true, "b":false, "in":true}, - - {"a":false, "b":false, "in":false} - ])"); -} - -void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, - const std::string& 
expected_json) {
-  FieldVector fields_with_ids = by_fields;
-  fields_with_ids.push_back(field("ids", list(int32())));
-  auto expected = ArrayFromJSON(struct_(fields_with_ids), expected_json);
-
-  FieldVector fields_with_id = by_fields;
-  fields_with_id.push_back(field("id", int32()));
-  auto batch = RecordBatchFromJSON(schema(fields_with_id), batch_json);
-
-  ASSERT_OK_AND_ASSIGN(auto by, batch->RemoveColumn(batch->num_columns() - 1)
-                                    .Map([](std::shared_ptr<RecordBatch> by) {
-                                      return by->ToStructArray();
-                                    }));
-
-  ASSERT_OK_AND_ASSIGN(auto groupings_and_values, MakeGroupings(*by));
-
-  auto groupings = checked_pointer_cast<ListArray>(
-      groupings_and_values->GetFieldByName("groupings"));
-
-  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> grouped_ids,
-                       ApplyGroupings(*groupings, *batch->GetColumnByName("id")));
-
-  ArrayVector columns =
-      checked_cast<const StructArray&>(*groupings_and_values->GetFieldByName("values"))
-          .fields();
-  columns.push_back(grouped_ids);
-
-  ASSERT_OK_AND_ASSIGN(auto actual, StructArray::Make(columns, fields_with_ids));
-
-  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
-}
-
-TEST(GroupTest, Basics) {
-  AssertGrouping({field("a", utf8()), field("b", int32())}, R"([
-    {"a": "ex", "b": 0, "id": 0},
-    {"a": "ex", "b": 0, "id": 1},
-    {"a": "why", "b": 0, "id": 2},
-    {"a": "ex", "b": 1, "id": 3},
-    {"a": "why", "b": 0, "id": 4},
-    {"a": "ex", "b": 1, "id": 5},
-    {"a": "ex", "b": 0, "id": 6},
-    {"a": "why", "b": 1, "id": 7}
-  ])",
-                 R"([
-    {"a": "ex", "b": 0, "ids": [0, 1, 6]},
-    {"a": "why", "b": 0, "ids": [2, 4]},
-    {"a": "ex", "b": 1, "ids": [3, 5]},
-    {"a": "why", "b": 1, "ids": [7]}
-  ])");
-}
-
-}  // namespace dataset
-}  // namespace arrow
+// FIXME delete this
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index 16a49b988a9..6fd61096392 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -26,10 +26,11 @@
 #include "arrow/array/array_nested.h"
 #include "arrow/array/builder_dict.h"
 #include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
 #include "arrow/dataset/dataset_internal.h"
-#include "arrow/dataset/filter.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/scalar.h"
+#include "arrow/util/int_util_internal.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/make_unique.h"
 #include "arrow/util/string_view.h"
@@ -84,7 +85,7 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr,
   return Status::OK();
 }
 
-inline std::shared_ptr<Expression> ConjunctionFromGroupingRow(Scalar* row) {
+inline Expression2 ConjunctionFromGroupingRow(Scalar* row) {
   ScalarVector* values = &checked_cast<StructScalar*>(row)->value;
   std::vector<Expression2> equality_expressions(values->size());
   for (size_t i = 0; i < values->size(); ++i) {
@@ -529,5 +530,192 @@ Result<std::shared_ptr<Schema>> PartitioningOrFactory::GetOrInferSchema(
   return factory()->Inspect(paths);
 }
 
+// Transform an array of counts to offsets which will divide a ListArray
+// into an equal number of slices with corresponding lengths.
+inline Result<std::shared_ptr<Int32Array>> CountsToOffsets(
+    std::shared_ptr<Int64Array> counts) {
+  Int32Builder offset_builder;
+  RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1));
+  offset_builder.UnsafeAppend(0);
+
+  for (int64_t i = 0; i < counts->length(); ++i) {
+    DCHECK_NE(counts->Value(i), 0);
+    auto next_offset = static_cast<int32_t>(offset_builder[i] + counts->Value(i));
+    offset_builder.UnsafeAppend(next_offset);
+  }
+
+  std::shared_ptr<Int32Array> offsets;
+  RETURN_NOT_OK(offset_builder.Finish(&offsets));
+  return offsets;
+}
+
+// Helper for simultaneous dictionary encoding of multiple arrays.
+//
+// The fused dictionary is the Cartesian product of the individual dictionaries.
+// For example, given two arrays A, B where A has unique values ["ex", "why"]
+// and B has unique values [0, 1] the fused dictionary is the set of tuples
+// [["ex", 0], ["ex", 1], ["why", 0], ["why", 1]].
+//
+// TODO(bkietz) this capability belongs in an Action of the hash kernels, where
+// it can be used to group aggregates without materializing a grouped batch.
+// For the purposes of writing we need the materialized grouped batch anyway
+// since no Writers accept a selection vector.
+class StructDictionary {
+ public:
+  struct Encoded {
+    std::shared_ptr<Int32Array> indices;
+    std::shared_ptr<StructDictionary> dictionary;
+  };
+
+  static Result<Encoded> Encode(const ArrayVector& columns) {
+    Encoded out{nullptr, std::make_shared<StructDictionary>()};
+
+    for (const auto& column : columns) {
+      if (column->null_count() != 0) {
+        return Status::NotImplemented("Grouping on a field with nulls");
+      }
+
+      RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices));
+    }
+
+    return out;
+  }
+
+  Result<std::shared_ptr<StructArray>> Decode(std::shared_ptr<Int32Array> fused_indices,
+                                              FieldVector fields) {
+    std::vector<Int32Builder> builders(dictionaries_.size());
+    for (Int32Builder& b : builders) {
+      RETURN_NOT_OK(b.Resize(fused_indices->length()));
+    }
+
+    std::vector<int32_t> codes(dictionaries_.size());
+    for (int64_t i = 0; i < fused_indices->length(); ++i) {
+      Expand(fused_indices->Value(i), codes.data());
+
+      auto builder_it = builders.begin();
+      for (int32_t index : codes) {
+        builder_it++->UnsafeAppend(index);
+      }
+    }
+
+    ArrayVector columns(dictionaries_.size());
+    for (size_t i = 0; i < dictionaries_.size(); ++i) {
+      std::shared_ptr<ArrayData> indices;
+      RETURN_NOT_OK(builders[i].FinishInternal(&indices));
+
+      ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices));
+      columns[i] = column.make_array();
+    }
+
+    return StructArray::Make(std::move(columns), std::move(fields));
+  }
+
+ private:
+  Status AddOne(Datum column, std::shared_ptr<Int32Array>* fused_indices) {
+    ArrayData* encoded;
+    if (column.type()->id() != Type::DICTIONARY) {
+      ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column));
+    }
+    encoded = column.mutable_array();
+
+    auto indices =
+        std::make_shared<Int32Array>(encoded->length, std::move(encoded->buffers[1]));
+
+    dictionaries_.push_back(MakeArray(std::move(encoded->dictionary)));
+    auto dictionary_size = static_cast<int32_t>(dictionaries_.back()->length());
+
+    if (*fused_indices == nullptr) {
+      *fused_indices = std::move(indices);
+      size_ = dictionary_size;
+      return Status::OK();
+    }
+
+    // It's useful to think about the case where each of dictionaries_ has size 10.
+    // In this case the decimal digit in the ones place is the code in dictionaries_[0],
+    // the tens place corresponds to dictionaries_[1], etc.
+    // The incumbent indices must be shifted to the hundreds place so as not to collide.
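+    // A worked illustration (editorial sketch, not upstream text): suppose
+    // dictionaries_ holds ["ex", "why"], so the incumbent size_ is 2, and we
+    // fuse a second column whose dictionary is [0, 1]. For a row ("why", 1)
+    // the incumbent fused index is 1 and the new column's code is 1, so the
+    // fused index becomes 1 * 2 + 1 == 3 and size_ grows to 2 * 2 == 4.
+    // Expand(3) below recovers {3 % 2, (3 / 2) % 2} == {1, 1}, i.e. ("why", 1).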
+    ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices,
+                          compute::Multiply(indices, MakeScalar(size_)));
+
+    ARROW_ASSIGN_OR_RAISE(new_fused_indices,
+                          compute::Add(new_fused_indices, *fused_indices));
+
+    *fused_indices = checked_pointer_cast<Int32Array>(new_fused_indices.make_array());
+
+    // XXX should probably cap this at 2**15 or so
+    ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_));
+    return Status::OK();
+  }
+
+  // Expand a fused code into component dict codes; codes are in order of addition.
+  void Expand(int32_t fused_code, int32_t* codes) {
+    for (size_t i = 0; i < dictionaries_.size(); ++i) {
+      auto dictionary_size = static_cast<int32_t>(dictionaries_[i]->length());
+      codes[i] = fused_code % dictionary_size;
+      fused_code /= dictionary_size;
+    }
+  }
+
+  int32_t size_;
+  ArrayVector dictionaries_;
+};
+
+Result<std::shared_ptr<StructArray>> MakeGroupings(const StructArray& by) {
+  if (by.num_fields() == 0) {
+    return Status::NotImplemented("Grouping with no criteria");
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields()));
+
+  ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices));
+  ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices));
+  fused.indices = checked_pointer_cast<Int32Array>(sorted.make_array());
+
+  ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values,
+                        compute::ValueCounts(fused.indices));
+  fused.indices.reset();
+
+  auto unique_fused_indices = checked_pointer_cast<Int32Array>(
+      fused_counts_and_values->GetFieldByName("values"));
+  ARROW_ASSIGN_OR_RAISE(
+      auto unique_rows,
+      fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields()));
+
+  auto counts = checked_pointer_cast<Int64Array>(
+      fused_counts_and_values->GetFieldByName("counts"));
+  ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(counts)));
+
+  ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices,
+                        ListArray::FromArrays(*offsets, *sort_indices));
+
+  return StructArray::Make(
+      ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)},
+      std::vector<std::string>{"values", "groupings"});
+}
+
+Result<std::shared_ptr<ListArray>> ApplyGroupings(const ListArray& groupings,
+                                                  const Array& array) {
+  ARROW_ASSIGN_OR_RAISE(Datum sorted,
+                        compute::Take(array, groupings.data()->child_data[0]));
+
+  return std::make_shared<ListArray>(list(array.type()), groupings.length(),
+                                     groupings.value_offsets(), sorted.make_array());
+}
+
+Result<RecordBatchVector> ApplyGroupings(const ListArray& groupings,
+                                         const std::shared_ptr<RecordBatch>& batch) {
+  ARROW_ASSIGN_OR_RAISE(Datum sorted,
+                        compute::Take(batch, groupings.data()->child_data[0]));
+
+  const auto& sorted_batch = *sorted.record_batch();
+
+  RecordBatchVector out(static_cast<size_t>(groupings.length()));
+  for (size_t i = 0; i < out.size(); ++i) {
+    out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i));
+  }
+
+  return out;
+}
+
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h
index 1bad8956b65..a9f8c87bfc9 100644
--- a/cpp/src/arrow/dataset/partition.h
+++ b/cpp/src/arrow/dataset/partition.h
@@ -285,5 +285,22 @@ class ARROW_DS_EXPORT PartitioningOrFactory {
   std::shared_ptr<Partitioning> partitioning_;
 };
 
+/// \brief Assemble lists of indices of identical rows.
+///
+/// \param[in] by A StructArray whose columns will be used as grouping criteria.
+/// \return A StructArray mapping unique rows (in field "values", represented as a
+///         StructArray with the same fields as `by`) to lists of indices where
+///         that row appears (in field "groupings").
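+///
+/// For example (an editorial sketch mirroring the GroupTest cases in
+/// partition_test.cc): grouping on a single utf8 column a = ["ex", "why", "ex"]
+/// yields values [{"a": "ex"}, {"a": "why"}] and groupings [[0, 2], [1]].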
+ARROW_DS_EXPORT +Result> MakeGroupings(const StructArray& by); + +/// \brief Produce slices of an Array which correspond to the provided groupings. +ARROW_DS_EXPORT +Result> ApplyGroupings(const ListArray& groupings, + const Array& array); +ARROW_DS_EXPORT +Result ApplyGroupings(const ListArray& groupings, + const std::shared_ptr& batch); + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 7bf7efa3761..0cdb5eb50b6 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -105,17 +105,20 @@ TEST_F(TestPartitioning, DirectoryPartitioning) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); - AssertParse("/0/hello", "alpha"_ == int32_t(0) and "beta"_ == "hello"); - AssertParse("/3", "alpha"_ == int32_t(3)); + AssertParse("/0/hello", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello")))); + AssertParse("/3", equal(field_ref("alpha"), literal(3))); AssertParseError("/world/0"); // reversed order AssertParseError("/0.0/foo"); // invalid alpha AssertParseError("/3.25"); // invalid alpha with missing beta AssertParse("", literal(true)); // no segments to parse // gotcha someday: - AssertParse("/0/dat.parquet", "alpha"_ == int32_t(0) and "beta"_ == "dat.parquet"); + AssertParse("/0/dat.parquet", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("dat.parquet")))); - AssertParse("/0/foo/ignored=2341", "alpha"_ == int32_t(0) and "beta"_ == "foo"); + AssertParse("/0/foo/ignored=2341", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("foo")))); } TEST_F(TestPartitioning, DirectoryPartitioningFormat) { @@ -124,20 +127,28 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormat) { written_schema_ = partitioning_->schema(); - AssertFormat("alpha"_ == int32_t(0) and "beta"_ == "hello", "0/hello"); - AssertFormat("beta"_ == "hello" and "alpha"_ == int32_t(0), "0/hello"); - AssertFormat("alpha"_ == int32_t(0), "0"); - AssertFormatError("beta"_ == "hello"); + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello"))), + "0/hello"); + AssertFormat(and_(equal(field_ref("beta"), literal("hello")), + equal(field_ref("alpha"), literal(0))), + "0/hello"); + AssertFormat(equal(field_ref("alpha"), literal(0)), "0"); + AssertFormatError(equal(field_ref("beta"), literal("hello"))); AssertFormat(literal(true), ""); ASSERT_OK_AND_ASSIGN(written_schema_, written_schema_->AddField(0, field("gamma", utf8()))); - AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == "hello", + AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), + equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello"))}), "0/hello"); // written_schema_ is incompatible with partitioning_'s schema written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); - AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); + AssertFormatError( + and_(equal(field_ref("alpha"), literal("0.0")), + equal(field_ref("beta"), literal("hello")))); } TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { @@ -147,7 +158,9 @@ TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { ASSERT_OK_AND_ASSIGN(auto day, StringScalar("2020-06-08").CastTo(temporal)); AssertParse("/2020/06/2020-06-08", - "year"_ == int32_t(2020) and "month"_ == int8_t(6) and "day"_ == day); + 
and_({equal(field_ref("year"), literal(2020)), + equal(field_ref("month"), literal(6)), + equal(field_ref("day"), literal(day))})); } } @@ -207,7 +220,7 @@ TEST_F(TestPartitioning, DictionaryHasUniqueValues) { std::make_shared(index_and_dictionary, alpha->type()); auto path = "/" + expected_dictionary->GetString(i); - AssertParse(path, "alpha"_ == dictionary_scalar); + AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar))); } AssertParseError("/yosemite"); // not in inspected dictionary @@ -223,17 +236,21 @@ TEST_F(TestPartitioning, HivePartitioning) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", float32())})); - AssertParse("/alpha=0/beta=3.25", "alpha"_ == int32_t(0) and "beta"_ == 3.25f); - AssertParse("/beta=3.25/alpha=0", "beta"_ == 3.25f and "alpha"_ == int32_t(0)); - AssertParse("/alpha=0", "alpha"_ == int32_t(0)); - AssertParse("/beta=3.25", "beta"_ == 3.25f); + AssertParse("/alpha=0/beta=3.25", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); + AssertParse("/beta=3.25/alpha=0", and_(equal(field_ref("beta"), literal(3.25f)), + equal(field_ref("alpha"), literal(0)))); + AssertParse("/alpha=0", equal(field_ref("alpha"), literal(0))); + AssertParse("/beta=3.25", equal(field_ref("beta"), literal(3.25f))); AssertParse("", literal(true)); AssertParse("/alpha=0/unexpected/beta=3.25", - "alpha"_ == int32_t(0) and "beta"_ == 3.25f); + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); AssertParse("/alpha=0/beta=3.25/ignored=2341", - "alpha"_ == int32_t(0) and "beta"_ == 3.25f); + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); AssertParse("/ignored=2341", literal(true)); @@ -246,20 +263,28 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { written_schema_ = partitioning_->schema(); - AssertFormat("alpha"_ == int32_t(0) and "beta"_ == 3.25f, "alpha=0/beta=3.25"); - AssertFormat("beta"_ == 3.25f and "alpha"_ == int32_t(0), "alpha=0/beta=3.25"); - AssertFormat("alpha"_ == int32_t(0), "alpha=0"); - AssertFormat("beta"_ == 3.25f, "alpha/beta=3.25"); + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f))), + "alpha=0/beta=3.25"); + AssertFormat(and_(equal(field_ref("beta"), literal(3.25f)), + equal(field_ref("alpha"), literal(0))), + "alpha=0/beta=3.25"); + AssertFormat(equal(field_ref("alpha"), literal(0)), "alpha=0"); + AssertFormat(equal(field_ref("beta"), literal(3.25f)), "alpha/beta=3.25"); AssertFormat(literal(true), ""); ASSERT_OK_AND_ASSIGN(written_schema_, written_schema_->AddField(0, field("gamma", utf8()))); - AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, + AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), + equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f))}), "alpha=0/beta=3.25"); // written_schema_ is incompatible with partitioning_'s schema written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); - AssertFormatError("alpha"_ == "0.0" and "beta"_ == "hello"); + AssertFormatError( + and_(equal(field_ref("alpha"), literal("0.0")), + equal(field_ref("beta"), literal("hello")))); } TEST_F(TestPartitioning, DiscoverHiveSchema) { @@ -326,7 +351,7 @@ TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) { std::make_shared(index_and_dictionary, alpha->type()); auto path = "/alpha=" + expected_dictionary->GetString(i); - AssertParse(path, "alpha"_ == dictionary_scalar); + 
AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar))); } AssertParseError("/alpha=yosemite"); // not in inspected dictionary @@ -365,9 +390,12 @@ TEST_F(TestPartitioning, EtlThenHive) { }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", - "year"_ == int16_t(1999) and "month"_ == int8_t(12) and - "day"_ == int8_t(31) and "hour"_ == int8_t(0) and - ("alpha"_ == int32_t(0) and "beta"_ == 3.25f)); + and_({equal(field_ref("year"), literal(1999)), + equal(field_ref("month"), literal(12)), + equal(field_ref("day"), literal(31)), + equal(field_ref("hour"), literal(0)), + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))})); AssertParseError("/20X6/03/21/05/alpha=0/beta=3.25"); } @@ -382,7 +410,7 @@ TEST_F(TestPartitioning, Set) { auto schm = schema({field("x", int32())}); // An adhoc partitioning which parses segments like "/x in [1 4 5]" - // into ("x"_ == 1 or "x"_ == 4 or "x"_ == 5) + // into (field_ref("x") == 1 or field_ref("x") == 4 or field_ref("x") == 5) partitioning_ = std::make_shared( schm, [&](const std::string& path) -> Result { std::vector subexpressions; @@ -402,18 +430,21 @@ TEST_F(TestPartitioning, Set) { } subexpressions.push_back(call("is_in", {field_ref(matches[1])}, - compute::SetLookupOptions{ints(set), true})); + compute::SetLookupOptions{ints(set)})); } return and_(std::move(subexpressions)); }); - AssertParse("/x in [1]", "x"_.In(ints({1}))); - AssertParse("/x in [1 4 5]", "x"_.In(ints({1, 4, 5}))); - AssertParse("/x in []", "x"_.In(ints({}))); + auto x_in = [&](std::vector set) { + return call("is_in", {field_ref("x")}, compute::SetLookupOptions{ints(set)}); + }; + AssertParse("/x in [1]", x_in({1})); + AssertParse("/x in [1 4 5]", x_in({1, 4, 5})); + AssertParse("/x in []", x_in({})); } // An adhoc partitioning which parses segments like "/x=[-3.25, 0.0)" -// into ("x"_ >= -3.25 and "x" < 0.0) +// into (field_ref("x") >= -3.25 and "x" < 0.0) class RangePartitioning : public Partitioning { public: explicit RangePartitioning(std::shared_ptr s) : Partitioning(std::move(s)) {} @@ -477,8 +508,12 @@ TEST_F(TestPartitioning, Range) { schema({field("x", float64()), field("y", float64()), field("z", float64())})); AssertParse("/x=[-1.5 0.0)/y=[0.0 1.5)/z=(1.5 3.0]", - ("x"_ >= -1.5 and "x"_ < 0.0) and ("y"_ >= 0.0 and "y"_ < 1.5) and - ("z"_ > 1.5 and "z"_ <= 3.0)); + and_({and_(greater_equal(field_ref("x"), literal(-1.5)), + less(field_ref("x"), literal(0.0))), + and_(greater_equal(field_ref("y"), literal(0.0)), + less(field_ref("y"), literal(1.5))), + and_(greater(field_ref("z"), literal(1.5)), + less_equal(field_ref("z"), literal(3.0)))})); } TEST(TestStripPrefixAndFilename, Basic) { @@ -497,5 +532,57 @@ TEST(TestStripPrefixAndFilename, Basic) { "year=2019/month=12/day=01")); } +void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, + const std::string& expected_json) { + FieldVector fields_with_ids = by_fields; + fields_with_ids.push_back(field("ids", list(int32()))); + auto expected = ArrayFromJSON(struct_(fields_with_ids), expected_json); + + FieldVector fields_with_id = by_fields; + fields_with_id.push_back(field("id", int32())); + auto batch = RecordBatchFromJSON(schema(fields_with_id), batch_json); + + ASSERT_OK_AND_ASSIGN(auto by, batch->RemoveColumn(batch->num_columns() - 1) + .Map([](std::shared_ptr by) { + return by->ToStructArray(); + })); + + ASSERT_OK_AND_ASSIGN(auto groupings_and_values, MakeGroupings(*by)); + + auto groupings = + 
checked_pointer_cast(groupings_and_values->GetFieldByName("groupings")); + + ASSERT_OK_AND_ASSIGN(std::shared_ptr grouped_ids, + ApplyGroupings(*groupings, *batch->GetColumnByName("id"))); + + ArrayVector columns = + checked_cast(*groupings_and_values->GetFieldByName("values")) + .fields(); + columns.push_back(grouped_ids); + + ASSERT_OK_AND_ASSIGN(auto actual, StructArray::Make(columns, fields_with_ids)); + + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +TEST(GroupTest, Basics) { + AssertGrouping({field("a", utf8()), field("b", int32())}, R"([ + {"a": "ex", "b": 0, "id": 0}, + {"a": "ex", "b": 0, "id": 1}, + {"a": "why", "b": 0, "id": 2}, + {"a": "ex", "b": 1, "id": 3}, + {"a": "why", "b": 0, "id": 4}, + {"a": "ex", "b": 1, "id": 5}, + {"a": "ex", "b": 0, "id": 6}, + {"a": "why", "b": 1, "id": 7} + ])", + R"([ + {"a": "ex", "b": 0, "ids": [0, 1, 6]}, + {"a": "why", "b": 0, "ids": [2, 4]}, + {"a": "ex", "b": 1, "ids": [3, 5]}, + {"a": "why", "b": 1, "ids": [7]} + ])"); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 043f6695222..809d83b7e2b 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -23,7 +23,6 @@ #include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" #include "arrow/util/iterator.h" diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index 686be1e6dfb..eb11f326570 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -24,7 +24,6 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/scanner.h" diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 44d211260b6..76015fa365c 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -33,7 +33,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/discovery.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/path_util.h" @@ -59,37 +58,9 @@ const std::shared_ptr kBoringSchema = schema({ field("bool", boolean()), field("str", utf8()), field("dict_str", dictionary(int32(), utf8())), + field("ts_ns", timestamp(TimeUnit::NANO)), }); -inline namespace string_literals { -// clang-format off -inline FieldExpression operator"" _(const char* name, size_t name_length) { - // clang-format on - return FieldExpression({name, name_length}); -} -} // namespace string_literals - -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - template ::type>::value>::type> \ - ComparisonExpression operator OP(const Expression& lhs, T&& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), \ - scalar(std::forward(rhs))); \ - } \ - \ - inline ComparisonExpression operator OP(const Expression& lhs, \ - const Expression& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), rhs.Copy()); \ - } - -COMPARISON_FACTORY(EQUAL, equal, ==) -COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) -COMPARISON_FACTORY(GREATER, greater, >) -COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) -COMPARISON_FACTORY(LESS, less, <) 
-COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) -#undef COMPARISON_FACTORY - using fs::internal::GetAbstractPathExtension; using internal::checked_cast; using internal::checked_pointer_cast; diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 81441a18166..a80b333f0c8 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -68,7 +68,6 @@ class ParquetFileFragment; class ParquetFileWriter; class ParquetFileWriteOptions; -class Expression; class Expression2; class Partitioning; From ec6d9a0b3fe376cb6d1f6d9793fddb7312e21d8f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 10 Dec 2020 14:05:32 -0500 Subject: [PATCH 06/38] rename Expression2 -> Expression --- .../arrow/dataset-parquet-scan-example.cc | 4 +- cpp/src/arrow/dataset/dataset.cc | 14 +- cpp/src/arrow/dataset/dataset.h | 24 +-- cpp/src/arrow/dataset/dataset_internal.h | 2 +- cpp/src/arrow/dataset/discovery.h | 6 +- cpp/src/arrow/dataset/expression.cc | 180 +++++++++--------- cpp/src/arrow/dataset/expression.h | 94 ++++----- cpp/src/arrow/dataset/expression_internal.h | 42 ++-- cpp/src/arrow/dataset/expression_test.cc | 92 ++++----- cpp/src/arrow/dataset/file_base.cc | 10 +- cpp/src/arrow/dataset/file_base.h | 12 +- cpp/src/arrow/dataset/file_parquet.cc | 16 +- cpp/src/arrow/dataset/file_parquet.h | 14 +- cpp/src/arrow/dataset/file_parquet_test.cc | 4 +- cpp/src/arrow/dataset/file_test.cc | 8 +- cpp/src/arrow/dataset/partition.cc | 18 +- cpp/src/arrow/dataset/partition.h | 22 +-- cpp/src/arrow/dataset/partition_test.cc | 20 +- cpp/src/arrow/dataset/scanner.cc | 2 +- cpp/src/arrow/dataset/scanner.h | 4 +- cpp/src/arrow/dataset/scanner_internal.h | 10 +- cpp/src/arrow/dataset/test_util.h | 12 +- cpp/src/arrow/dataset/type_fwd.h | 2 +- 23 files changed, 306 insertions(+), 306 deletions(-) diff --git a/cpp/examples/arrow/dataset-parquet-scan-example.cc b/cpp/examples/arrow/dataset-parquet-scan-example.cc index 46778cebaa5..197ca5aa4c6 100644 --- a/cpp/examples/arrow/dataset-parquet-scan-example.cc +++ b/cpp/examples/arrow/dataset-parquet-scan-example.cc @@ -60,7 +60,7 @@ struct Configuration { // Indicates the filter by which rows will be filtered. This optimization can // make use of partition information and/or file metadata if possible. 
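  // (Illustration: a fragment whose partition expression guarantees
  // total_amount <= 1000.0f can be skipped without reading any of its
  // rows, since the filter below can never be satisfied there.)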
- ds::Expression2 filter = + ds::Expression filter = ds::greater(ds::field_ref("total_amount"), ds::literal(1000.0f)); ds::InspectOptions inspect_options{}; @@ -146,7 +146,7 @@ std::shared_ptr GetDatasetFromPath( std::shared_ptr GetScannerFromDataset(std::shared_ptr dataset, std::vector columns, - ds::Expression2 filter, + ds::Expression filter, bool use_threads) { auto scanner_builder = dataset->NewScan().ValueOrDie(); diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index bedecf1e3a5..e2386dddec7 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -31,7 +31,7 @@ namespace arrow { namespace dataset { -Fragment::Fragment(Expression2 partition_expression, +Fragment::Fragment(Expression partition_expression, std::shared_ptr physical_schema) : partition_expression_(std::move(partition_expression)), physical_schema_(std::move(physical_schema)) {} @@ -58,14 +58,14 @@ Result> InMemoryFragment::ReadPhysicalSchemaImpl() { InMemoryFragment::InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - Expression2 partition_expression) + Expression partition_expression) : Fragment(std::move(partition_expression), std::move(schema)), record_batches_(std::move(record_batches)) { DCHECK_NE(physical_schema_, nullptr); } InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches, - Expression2 partition_expression) + Expression partition_expression) : InMemoryFragment(record_batches.empty() ? schema({}) : record_batches[0]->schema(), std::move(record_batches), std::move(partition_expression)) {} @@ -92,7 +92,7 @@ Result InMemoryFragment::Scan(std::shared_ptr opt return MakeMapIterator(fn, std::move(batches_it)); } -Dataset::Dataset(std::shared_ptr schema, Expression2 partition_expression) +Dataset::Dataset(std::shared_ptr schema, Expression partition_expression) : schema_(std::move(schema)), partition_expression_(std::move(partition_expression)) {} @@ -110,7 +110,7 @@ Result Dataset::GetFragments() { return GetFragments(std::move(predicate)); } -Result Dataset::GetFragments(Expression2 predicate) { +Result Dataset::GetFragments(Expression predicate) { ARROW_ASSIGN_OR_RAISE( predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); return predicate.IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) @@ -154,7 +154,7 @@ Result> InMemoryDataset::ReplaceSchema( return std::make_shared(std::move(schema), get_batches_); } -Result InMemoryDataset::GetFragmentsImpl(Expression2) { +Result InMemoryDataset::GetFragmentsImpl(Expression) { auto schema = this->schema(); auto create_fragment = @@ -195,7 +195,7 @@ Result> UnionDataset::ReplaceSchema( new UnionDataset(std::move(schema), std::move(children))); } -Result UnionDataset::GetFragmentsImpl(Expression2 predicate) { +Result UnionDataset::GetFragmentsImpl(Expression predicate) { return GetFragmentsFromDatasets(children_, predicate); } diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 88782831670..4ad6ecbe74a 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -72,19 +72,19 @@ class ARROW_DS_EXPORT Fragment { /// \brief An expression which evaluates to true for all data viewed by this /// Fragment. 
- const Expression2& partition_expression() const { return partition_expression_; } + const Expression& partition_expression() const { return partition_expression_; } virtual ~Fragment() = default; protected: Fragment() = default; - explicit Fragment(Expression2 partition_expression, + explicit Fragment(Expression partition_expression, std::shared_ptr physical_schema); virtual Result> ReadPhysicalSchemaImpl() = 0; util::Mutex physical_schema_mutex_; - Expression2 partition_expression_ = literal(true); + Expression partition_expression_ = literal(true); std::shared_ptr physical_schema_; }; @@ -93,9 +93,9 @@ class ARROW_DS_EXPORT Fragment { class ARROW_DS_EXPORT InMemoryFragment : public Fragment { public: InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - Expression2 = literal(true)); + Expression = literal(true)); explicit InMemoryFragment(RecordBatchVector record_batches, - Expression2 = literal(true)); + Expression = literal(true)); Result Scan(std::shared_ptr options, std::shared_ptr context) override; @@ -122,14 +122,14 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. - Result GetFragments(Expression2 predicate); + Result GetFragments(Expression predicate); Result GetFragments(); const std::shared_ptr& schema() const { return schema_; } /// \brief An expression which evaluates to true for all data viewed by this Dataset. /// May be null, which indicates no information is available. - const Expression2& partition_expression() const { return partition_expression_; } + const Expression& partition_expression() const { return partition_expression_; } /// \brief The name identifying the kind of Dataset virtual std::string type_name() const = 0; @@ -146,12 +146,12 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { protected: explicit Dataset(std::shared_ptr schema) : schema_(std::move(schema)) {} - Dataset(std::shared_ptr schema, Expression2 partition_expression); + Dataset(std::shared_ptr schema, Expression partition_expression); - virtual Result GetFragmentsImpl(Expression2 predicate) = 0; + virtual Result GetFragmentsImpl(Expression predicate) = 0; std::shared_ptr schema_; - Expression2 partition_expression_ = literal(true); + Expression partition_expression_ = literal(true); }; /// \brief A Source which yields fragments wrapping a stream of record batches. @@ -181,7 +181,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2 predicate) override; + Expression predicate) override; std::shared_ptr get_batches_; }; @@ -206,7 +206,7 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2 predicate) override; + Expression predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 9e070192cd0..cb6d406fb70 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -36,7 +36,7 @@ namespace dataset { /// \brief GetFragmentsFromDatasets transforms a vector into a /// flattened FragmentIterator. 
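 ///
 /// (Sketch of the behavior: each child dataset is queried with the same
 /// predicate, so fragments whose partition expressions contradict the
 /// predicate are pruned before iteration.)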
inline Result GetFragmentsFromDatasets(const DatasetVector& datasets, - Expression2 predicate) { + Expression predicate) { // Iterator auto datasets_it = MakeVectorIterator(datasets); diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 0b1c94f5adc..b7786cd305f 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -90,8 +90,8 @@ class ARROW_DS_EXPORT DatasetFactory { virtual Result> Finish(FinishOptions options) = 0; /// \brief Optional root partition for the resulting Dataset. - const Expression2& root_partition() const { return root_partition_; } - Status SetRootPartition(Expression2 partition) { + const Expression& root_partition() const { return root_partition_; } + Status SetRootPartition(Expression partition) { root_partition_ = std::move(partition); return Status::OK(); } @@ -101,7 +101,7 @@ class ARROW_DS_EXPORT DatasetFactory { protected: DatasetFactory(); - Expression2 root_partition_; + Expression root_partition_; }; /// \brief DatasetFactory provides a way to inspect/discover a Dataset's diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index f1b17752bf2..6ff53d4190c 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -39,17 +39,17 @@ using internal::checked_pointer_cast; namespace dataset { -const Expression2::Call* Expression2::call() const { +const Expression::Call* Expression::call() const { return util::get_if(impl_.get()); } -const Datum* Expression2::literal() const { return util::get_if(impl_.get()); } +const Datum* Expression::literal() const { return util::get_if(impl_.get()); } -const FieldRef* Expression2::field_ref() const { +const FieldRef* Expression::field_ref() const { return util::get_if(impl_.get()); } -std::string Expression2::ToString() const { +std::string Expression::ToString() const { if (auto lit = literal()) { if (lit->is_scalar()) { return lit->scalar()->ToString(); @@ -118,14 +118,14 @@ std::string Expression2::ToString() const { return out; } -void PrintTo(const Expression2& expr, std::ostream* os) { +void PrintTo(const Expression& expr, std::ostream* os) { *os << expr.ToString(); if (expr.IsBound()) { *os << "[bound]"; } } -bool Expression2::Equals(const Expression2& other) const { +bool Expression::Equals(const Expression& other) const { if (Identical(*this, other)) return true; if (impl_->index() != other.impl_->index()) { @@ -192,7 +192,7 @@ bool Expression2::Equals(const Expression2& other) const { return false; } -size_t Expression2::hash() const { +size_t Expression::hash() const { if (auto lit = literal()) { if (lit->is_scalar()) { return Scalar::Hash::hash(*lit->scalar()); @@ -213,7 +213,7 @@ size_t Expression2::hash() const { return out; } -bool Expression2::IsBound() const { +bool Expression::IsBound() const { if (descr_.type == nullptr) return false; if (auto lit = literal()) return true; @@ -222,14 +222,14 @@ bool Expression2::IsBound() const { auto call = CallNotNull(*this); - for (const Expression2& arg : call->arguments) { + for (const Expression& arg : call->arguments) { if (!arg.IsBound()) return false; } return call->kernel != nullptr; } -bool Expression2::IsScalarExpression() const { +bool Expression::IsScalarExpression() const { if (auto lit = literal()) { return lit->is_scalar(); } @@ -239,7 +239,7 @@ bool Expression2::IsScalarExpression() const { auto call = CallNotNull(*this); - for (const Expression2& arg : call->arguments) { + for (const Expression& arg : call->arguments) { 
if (!arg.IsScalarExpression()) return false; } @@ -259,7 +259,7 @@ bool Expression2::IsScalarExpression() const { return call->function_kind == compute::Function::SCALAR; } -bool Expression2::IsNullLiteral() const { +bool Expression::IsNullLiteral() const { if (auto lit = literal()) { if (lit->null_count() == lit->length()) { return true; @@ -269,7 +269,7 @@ bool Expression2::IsNullLiteral() const { return false; } -bool Expression2::IsSatisfiable() const { +bool Expression::IsSatisfiable() const { if (descr_.type && descr_.type->id() == Type::NA) { return false; } @@ -303,7 +303,7 @@ inline bool KernelStateIsImmutable(const std::string& function) { } Result> InitKernelState( - const Expression2::Call& call, compute::ExecContext* exec_context) { + const Expression::Call& call, compute::ExecContext* exec_context) { if (!call.kernel->init) return nullptr; compute::KernelContext kernel_context(exec_context); @@ -314,7 +314,7 @@ Result> InitKernelState( return std::move(kernel_state); } -Status MaybeInsertCast(std::shared_ptr to_type, Expression2* expr) { +Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { if (expr->descr().type->Equals(to_type)) { return Status::OK(); } @@ -334,15 +334,15 @@ Status MaybeInsertCast(std::shared_ptr to_type, Expression2* expr) { auto call_with_cast = *CallNotNull(with_cast); call_with_cast.arguments[0] = std::move(*expr); - *expr = Expression2(std::make_shared(std::move(call_with_cast)), + *expr = Expression(std::make_shared(std::move(call_with_cast)), ValueDescr{std::move(to_type), expr->descr().shape}); return Status::OK(); } -Status InsertImplicitCasts(Expression2::Call* call) { +Status InsertImplicitCasts(Expression::Call* call) { DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), - [](const Expression2& argument) { return argument.IsBound(); })); + [](const Expression& argument) { return argument.IsBound(); })); if (Comparison::Get(call->function)) { for (auto&& argument : call->arguments) { @@ -389,7 +389,7 @@ Status InsertImplicitCasts(Expression2::Call* call) { return Status::OK(); } -Result Expression2::Bind(ValueDescr in, +Result Expression::Bind(ValueDescr in, compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; @@ -426,15 +426,15 @@ Result Expression2::Bind(ValueDescr in, ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( &kernel_context, descrs)); - return Expression2(std::make_shared(std::move(bound_call)), std::move(descr)); + return Expression(std::make_shared(std::move(bound_call)), std::move(descr)); } -Result Expression2::Bind(const Schema& in_schema, +Result Expression::Bind(const Schema& in_schema, compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } -Result ExecuteScalarExpression(const Expression2& expr, const Datum& input, +Result ExecuteScalarExpression(const Expression& expr, const Datum& input, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; @@ -489,12 +489,12 @@ Result ExecuteScalarExpression(const Expression2& expr, const Datum& inpu return executor->WrapResults(arguments, listener->values()); } -std::array, 2> -ArgumentsAndFlippedArguments(const Expression2::Call& call) { +std::array, 2> +ArgumentsAndFlippedArguments(const Expression::Call& call) { DCHECK_EQ(call.arguments.size(), 2); - return {std::pair{call.arguments[0], + return {std::pair{call.arguments[0], call.arguments[1]}, - 
std::pair{call.arguments[1], + std::pair{call.arguments[1], call.arguments[0]}}; } @@ -511,14 +511,14 @@ util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { } util::optional GetNullHandling( - const Expression2::Call& call) { + const Expression::Call& call) { if (call.function_kind == compute::Function::SCALAR) { return static_cast(call.kernel)->null_handling; } return util::nullopt; } -bool DefinitelyNotNull(const Expression2& expr) { +bool DefinitelyNotNull(const Expression& expr) { DCHECK(expr.IsBound()); if (expr.literal()) { @@ -541,7 +541,7 @@ bool DefinitelyNotNull(const Expression2& expr) { return false; } -std::vector FieldsInExpression(const Expression2& expr) { +std::vector FieldsInExpression(const Expression& expr) { if (auto lit = expr.literal()) return {}; if (auto ref = expr.field_ref()) { @@ -549,20 +549,20 @@ std::vector FieldsInExpression(const Expression2& expr) { } std::vector fields; - for (const Expression2& arg : CallNotNull(expr)->arguments) { + for (const Expression& arg : CallNotNull(expr)->arguments) { auto argument_fields = FieldsInExpression(arg); std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields)); } return fields; } -Result FoldConstants(Expression2 expr) { +Result FoldConstants(Expression expr) { return Modify( - std::move(expr), [](Expression2 expr) { return expr; }, - [](Expression2 expr, ...) -> Result { + std::move(expr), [](Expression expr) { return expr; }, + [](Expression expr, ...) -> Result { auto call = CallNotNull(expr); if (std::all_of(call->arguments.begin(), call->arguments.end(), - [](const Expression2& argument) { return argument.literal(); })) { + [](const Expression& argument) { return argument.literal(); })) { // all arguments are literal; we can evaluate this subexpression *now* static const Datum ignored_input; ARROW_ASSIGN_OR_RAISE(Datum constant, @@ -616,8 +616,8 @@ Result FoldConstants(Expression2 expr) { }); } -inline std::vector GuaranteeConjunctionMembers( - const Expression2& guaranteed_true_predicate) { +inline std::vector GuaranteeConjunctionMembers( + const Expression& guaranteed_true_predicate) { auto guarantee = guaranteed_true_predicate.call(); if (!guarantee || guarantee->function != "and_kleene") { return {guaranteed_true_predicate}; @@ -628,11 +628,11 @@ inline std::vector GuaranteeConjunctionMembers( // Conjunction members which are represented in known_values are erased from // conjunction_members Status ExtractKnownFieldValuesImpl( - std::vector* conjunction_members, + std::vector* conjunction_members, std::unordered_map* known_values) { auto unconsumed_end = std::partition(conjunction_members->begin(), conjunction_members->end(), - [](const Expression2& expr) { + [](const Expression& expr) { // search for an equality conditions between a field and a literal auto call = expr.call(); if (!call) return true; @@ -670,24 +670,24 @@ Status ExtractKnownFieldValuesImpl( } Result> ExtractKnownFieldValues( - const Expression2& guaranteed_true_predicate) { + const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); return known_values; } -Result ReplaceFieldsWithKnownValues( +Result ReplaceFieldsWithKnownValues( const std::unordered_map& known_values, - Expression2 expr) { + Expression expr) { if (!expr.IsBound()) { return Status::Invalid( - "ReplaceFieldsWithKnownValues called on an unbound 
Expression2"); + "ReplaceFieldsWithKnownValues called on an unbound Expression"); } return Modify( std::move(expr), - [&known_values](Expression2 expr) -> Result { + [&known_values](Expression expr) -> Result { if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { @@ -698,10 +698,10 @@ Result ReplaceFieldsWithKnownValues( } return expr; }, - [](Expression2 expr, ...) { return expr; }); + [](Expression expr, ...) { return expr; }); } -inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { +inline bool IsBinaryAssociativeCommutative(const Expression::Call& call) { static std::unordered_set binary_associative_commutative{ "and", "or", "and_kleene", "or_kleene", "xor", "multiply", "add", "multiply_checked", "add_checked"}; @@ -710,7 +710,7 @@ inline bool IsBinaryAssociativeCommutative(const Expression2::Call& call) { return it != binary_associative_commutative.end(); } -Result Canonicalize(Expression2 expr, compute::ExecContext* exec_context) { +Result Canonicalize(Expression expr, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; return Canonicalize(std::move(expr), &exec_context); @@ -720,20 +720,20 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co // (for example, when reorganizing an associative chain), add expressions to this set to // avoid unnecessary work struct { - std::unordered_set set_; + std::unordered_set set_; - bool operator()(const Expression2& expr) const { + bool operator()(const Expression& expr) const { return set_.find(expr) != set_.end(); } - void Add(std::vector exprs) { + void Add(std::vector exprs) { std::move(exprs.begin(), exprs.end(), std::inserter(set_, set_.end())); } } AlreadyCanonicalized; return Modify( std::move(expr), - [&AlreadyCanonicalized, exec_context](Expression2 expr) -> Result { + [&AlreadyCanonicalized, exec_context](Expression expr) -> Result { auto call = expr.call(); if (!call) return expr; @@ -741,13 +741,13 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co if (IsBinaryAssociativeCommutative(*call)) { struct { - int Priority(const Expression2& operand) const { + int Priority(const Expression& operand) const { // order literals first, starting with nulls if (operand.IsNullLiteral()) return 0; if (operand.literal()) return 1; return 2; } - bool operator()(const Expression2& l, const Expression2& r) const { + bool operator()(const Expression& l, const Expression& r) const { return Priority(l) < Priority(r); } } CanonicalOrdering; @@ -766,10 +766,10 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co const auto& descr = expr.descr(); auto folded = FoldLeft( chain.fringe.begin(), chain.fringe.end(), - [call, &descr, &AlreadyCanonicalized](Expression2 l, Expression2 r) { + [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { auto ret = *call; ret.arguments = {std::move(l), std::move(r)}; - Expression2 expr(std::make_shared(std::move(ret)), + Expression expr(std::make_shared(std::move(ret)), descr); AlreadyCanonicalized.Add({expr}); return expr; @@ -792,22 +792,22 @@ Result Canonicalize(Expression2 expr, compute::ExecContext* exec_co ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); - return Expression2( - std::make_shared(std::move(flipped_call)), + return Expression( + std::make_shared(std::move(flipped_call)), expr.descr()); } } return expr; }, - [](Expression2 
expr, ...) { return expr; }); + [](Expression expr, ...) { return expr; }); } -Result DirectComparisonSimplification(Expression2 expr, - const Expression2::Call& guarantee) { +Result DirectComparisonSimplification(Expression expr, + const Expression::Call& guarantee) { return Modify( - std::move(expr), [](Expression2 expr) { return expr; }, - [&guarantee](Expression2 expr, ...) -> Result { + std::move(expr), [](Expression expr) { return expr; }, + [&guarantee](Expression expr, ...) -> Result { auto call = expr.call(); if (!call) return expr; @@ -861,8 +861,8 @@ Result DirectComparisonSimplification(Expression2 expr, }); } -Result SimplifyWithGuarantee(Expression2 expr, - const Expression2& guaranteed_true_predicate) { +Result SimplifyWithGuarantee(Expression expr, + const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; @@ -896,7 +896,7 @@ Result SimplifyWithGuarantee(Expression2 expr, // Serialization is accomplished by converting expressions to KeyValueMetadata and storing // this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its // columns. Finally, the RecordBatch is written to an IPC file. -Result> Serialize(const Expression2& expr) { +Result> Serialize(const Expression& expr) { struct { std::shared_ptr metadata_ = std::make_shared(); ArrayVector columns_; @@ -908,7 +908,7 @@ Result> Serialize(const Expression2& expr) { return std::to_string(ret); } - Status Visit(const Expression2& expr) { + Status Visit(const Expression& expr) { if (auto lit = expr.literal()) { if (!lit->is_scalar()) { return Status::NotImplemented("Serialization of non-scalar literals"); @@ -943,7 +943,7 @@ Result> Serialize(const Expression2& expr) { return Status::OK(); } - Result> operator()(const Expression2& expr) { + Result> operator()(const Expression& expr) { RETURN_NOT_OK(Visit(expr)); FieldVector fields(columns_.size()); for (size_t i = 0; i < fields.size(); ++i) { @@ -962,16 +962,16 @@ Result> Serialize(const Expression2& expr) { return stream->Finish(); } -Result Deserialize(const Buffer& buffer) { +Result Deserialize(const Buffer& buffer) { io::BufferReader stream(buffer); ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); if (batch->schema()->metadata() == nullptr) { - return Status::Invalid("serialized Expression2's batch repr had null metadata"); + return Status::Invalid("serialized Expression's batch repr had null metadata"); } if (batch->num_rows() != 1) { return Status::Invalid( - "serialized Expression2's batch repr was not a single row - had ", + "serialized Expression's batch repr was not a single row - had ", batch->num_rows()); } @@ -992,9 +992,9 @@ Result Deserialize(const Buffer& buffer) { return batch_.column(column_index)->GetScalar(0); } - Result GetOne() { + Result GetOne() { if (index_ >= metadata().size()) { - return Status::Invalid("unterminated serialized Expression2"); + return Status::Invalid("unterminated serialized Expression"); } const std::string& key = metadata().key(index_); @@ -1011,17 +1011,17 @@ Result Deserialize(const Buffer& buffer) { } if (key != "call") { - return Status::Invalid("Unrecognized serialized Expression2 key ", key); + return Status::Invalid("Unrecognized serialized Expression key ", key); } - std::vector arguments; + std::vector arguments; while (metadata().key(index_) != "end") { if (metadata().key(index_) == "options") { 
ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_))); auto expr = call(value, std::move(arguments)); RETURN_NOT_OK(FunctionOptionsFromStructScalar( checked_cast(options_scalar.get()), - const_cast(expr.call()))); + const_cast(expr.call()))); index_ += 2; return expr; } @@ -1038,40 +1038,40 @@ Result Deserialize(const Buffer& buffer) { return FromRecordBatch{*batch, 0}.GetOne(); } -Expression2 project(std::vector values, std::vector names) { +Expression project(std::vector values, std::vector names) { return call("struct", std::move(values), compute::StructOptions{std::move(names)}); } -Expression2 equal(Expression2 lhs, Expression2 rhs) { +Expression equal(Expression lhs, Expression rhs) { return call("equal", {std::move(lhs), std::move(rhs)}); } -Expression2 not_equal(Expression2 lhs, Expression2 rhs) { +Expression not_equal(Expression lhs, Expression rhs) { return call("not_equal", {std::move(lhs), std::move(rhs)}); } -Expression2 less(Expression2 lhs, Expression2 rhs) { +Expression less(Expression lhs, Expression rhs) { return call("less", {std::move(lhs), std::move(rhs)}); } -Expression2 less_equal(Expression2 lhs, Expression2 rhs) { +Expression less_equal(Expression lhs, Expression rhs) { return call("less_equal", {std::move(lhs), std::move(rhs)}); } -Expression2 greater(Expression2 lhs, Expression2 rhs) { +Expression greater(Expression lhs, Expression rhs) { return call("greater", {std::move(lhs), std::move(rhs)}); } -Expression2 greater_equal(Expression2 lhs, Expression2 rhs) { +Expression greater_equal(Expression lhs, Expression rhs) { return call("greater_equal", {std::move(lhs), std::move(rhs)}); } -Expression2 and_(Expression2 lhs, Expression2 rhs) { +Expression and_(Expression lhs, Expression rhs) { return call("and_kleene", {std::move(lhs), std::move(rhs)}); } -Expression2 and_(const std::vector& operands) { - auto folded = FoldLeft(operands.begin(), +Expression and_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), operands.end(), and_); if (folded) { return std::move(*folded); @@ -1079,12 +1079,12 @@ Expression2 and_(const std::vector& operands) { return literal(true); } -Expression2 or_(Expression2 lhs, Expression2 rhs) { +Expression or_(Expression lhs, Expression rhs) { return call("or_kleene", {std::move(lhs), std::move(rhs)}); } -Expression2 or_(const std::vector& operands) { - auto folded = FoldLeft(operands.begin(), +Expression or_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), operands.end(), or_); if (folded) { return std::move(*folded); @@ -1092,13 +1092,13 @@ Expression2 or_(const std::vector& operands) { return literal(false); } -Expression2 not_(Expression2 operand) { return call("invert", {std::move(operand)}); } +Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } -Expression2 operator&&(Expression2 lhs, Expression2 rhs) { +Expression operator&&(Expression lhs, Expression rhs) { return and_(std::move(lhs), std::move(rhs)); } -Expression2 operator||(Expression2 lhs, Expression2 rhs) { +Expression operator||(Expression lhs, Expression rhs) { return or_(std::move(lhs), std::move(rhs)); } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index e6495433d23..a9ec5403068 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -45,11 +45,11 @@ namespace dataset { /// - A literal Datum. /// - A reference to a single (potentially nested) field of the input Datum. 
/// - A call to a compute function, with arguments specified by other Expressions. -class ARROW_DS_EXPORT Expression2 { +class ARROW_DS_EXPORT Expression { public: struct Call { std::string function; - std::vector arguments; + std::vector arguments; std::shared_ptr options; // post-Bind properties: @@ -60,24 +60,24 @@ class ARROW_DS_EXPORT Expression2 { }; std::string ToString() const; - bool Equals(const Expression2& other) const; + bool Equals(const Expression& other) const; size_t hash() const; struct Hash { - size_t operator()(const Expression2& expr) const { return expr.hash(); } + size_t operator()(const Expression& expr) const { return expr.hash(); } }; /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, compute::ExecContext* = NULLPTR) const; // XXX someday // Clone all KernelState in this bound expression. If any function referenced by this // expression has mutable KernelState, it is not safe to execute or apply simplification // passes to it (or copies of it!) from multiple threads. Cloning state produces new - // KernelStates where necessary to ensure that Expression2s may be manipulated safely + // KernelStates where necessary to ensure that Expressions may be manipulated safely // on multiple threads. // Result CloneState() const; // Status SetState(ExpressionState); @@ -107,10 +107,10 @@ class ARROW_DS_EXPORT Expression2 { using Impl = util::Variant; - explicit Expression2(std::shared_ptr impl, ValueDescr descr = {}) + explicit Expression(std::shared_ptr impl, ValueDescr descr = {}) : impl_(std::move(impl)), descr_(std::move(descr)) {} - Expression2() = default; + Expression() = default; private: std::shared_ptr impl_; @@ -118,86 +118,86 @@ class ARROW_DS_EXPORT Expression2 { // XXX someday // NullGeneralization::type evaluates_to_null_; - ARROW_EXPORT friend bool Identical(const Expression2& l, const Expression2& r); + ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r); - ARROW_EXPORT friend void PrintTo(const Expression2&, std::ostream*); + ARROW_EXPORT friend void PrintTo(const Expression&, std::ostream*); }; -inline bool operator==(const Expression2& l, const Expression2& r) { return l.Equals(r); } -inline bool operator!=(const Expression2& l, const Expression2& r) { +inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); } +inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); } // Factories -inline Expression2 call(std::string function, std::vector arguments, +inline Expression call(std::string function, std::vector arguments, std::shared_ptr options = NULLPTR) { - Expression2::Call call; + Expression::Call call; call.function = std::move(function); call.arguments = std::move(arguments); call.options = std::move(options); - return Expression2(std::make_shared(std::move(call))); + return Expression(std::make_shared(std::move(call))); } template ::value>::type> -Expression2 call(std::string function, std::vector arguments, +Expression call(std::string function, std::vector arguments, Options options) { return call(std::move(function), std::move(arguments), 
std::make_shared(std::move(options))); } template -Expression2 field_ref(Args&&... args) { - return Expression2( - std::make_shared(FieldRef(std::forward(args)...))); +Expression field_ref(Args&&... args) { + return Expression( + std::make_shared(FieldRef(std::forward(args)...))); } template -Expression2 literal(Arg&& arg) { +Expression literal(Arg&& arg) { Datum lit(std::forward(arg)); ValueDescr descr = lit.descr(); - return Expression2(std::make_shared(std::move(lit)), + return Expression(std::make_shared(std::move(lit)), std::move(descr)); } ARROW_DS_EXPORT -std::vector FieldsInExpression(const Expression2&); +std::vector FieldsInExpression(const Expression&); ARROW_DS_EXPORT Result> ExtractKnownFieldValues( - const Expression2& guaranteed_true_predicate); + const Expression& guaranteed_true_predicate); -/// \defgroup expression-passes Functions for modification of Expression2s +/// \defgroup expression-passes Functions for modification of Expressions /// /// @{ /// /// These operate on a bound expression and its bound state simultaneously, -/// ensuring that Call Expression2s' KernelState can be utilized or reassociated. +/// ensuring that Call Expressions' KernelState can be utilized or reassociated. /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. /// TODO this could be a strong canonicalization ARROW_DS_EXPORT -Result Canonicalize(Expression2, compute::ExecContext* = NULLPTR); +Result Canonicalize(Expression, compute::ExecContext* = NULLPTR); /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all /// calls whose arguments are entirely literal. ARROW_DS_EXPORT -Result FoldConstants(Expression2); +Result FoldConstants(Expression); ARROW_DS_EXPORT -Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, Expression2); +Result ReplaceFieldsWithKnownValues( + const std::unordered_map& known_values, Expression); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is /// used to remove redundant function calls from a filter expression or to replace a /// reference to a constant-value field with a literal. ARROW_DS_EXPORT -Result SimplifyWithGuarantee(Expression2, - const Expression2& guaranteed_true_predicate); +Result SimplifyWithGuarantee(Expression, + const Expression& guaranteed_true_predicate); /// @} @@ -205,39 +205,39 @@ Result SimplifyWithGuarantee(Expression2, /// Execute a scalar expression against the provided state and input Datum. This /// expression must be bound. 
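/// A minimal usage sketch (not taken from this patch; the field name and
/// schema below are illustrative): bind a filter expression against a schema,
/// then evaluate it over a Datum.
///
///   auto filter = greater(field_ref("x"), literal(1));
///   ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(Schema({field("x", int32())})));
///   // `input` wraps a struct array or record batch with an int32 column "x";
///   // the result is a boolean Datum usable as a selection mask.
///   ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(filter, input));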
-Result ExecuteScalarExpression(const Expression2&, const Datum& input, +Result ExecuteScalarExpression(const Expression&, const Datum& input, compute::ExecContext* = NULLPTR); // Serialization ARROW_DS_EXPORT -Result> Serialize(const Expression2&); +Result> Serialize(const Expression&); ARROW_DS_EXPORT -Result Deserialize(const Buffer&); +Result Deserialize(const Buffer&); // Convenience aliases for factories -ARROW_DS_EXPORT Expression2 project(std::vector values, +ARROW_DS_EXPORT Expression project(std::vector values, std::vector names); -ARROW_DS_EXPORT Expression2 equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 not_equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression not_equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 less(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression less(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 less_equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression less_equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 greater(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression greater(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 greater_equal(Expression2 lhs, Expression2 rhs); +ARROW_DS_EXPORT Expression greater_equal(Expression lhs, Expression rhs); -ARROW_DS_EXPORT Expression2 and_(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Expression2 and_(const std::vector&); -ARROW_DS_EXPORT Expression2 or_(Expression2 lhs, Expression2 rhs); -ARROW_DS_EXPORT Expression2 or_(const std::vector&); -ARROW_DS_EXPORT Expression2 not_(Expression2 operand); +ARROW_DS_EXPORT Expression and_(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression and_(const std::vector&); +ARROW_DS_EXPORT Expression or_(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression or_(const std::vector&); +ARROW_DS_EXPORT Expression not_(Expression operand); } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 0324eaedc77..8c26e037f16 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -33,15 +33,15 @@ using internal::checked_cast; namespace dataset { -bool Identical(const Expression2& l, const Expression2& r) { return l.impl_ == r.impl_; } +bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; } -const Expression2::Call* CallNotNull(const Expression2& expr) { +const Expression::Call* CallNotNull(const Expression& expr) { auto call = expr.call(); DCHECK_NE(call, nullptr); return call; } -inline void GetAllFieldRefs(const Expression2& expr, +inline void GetAllFieldRefs(const Expression& expr, std::unordered_set* refs) { if (auto lit = expr.literal()) return; @@ -50,12 +50,12 @@ inline void GetAllFieldRefs(const Expression2& expr, return; } - for (const Expression2& arg : CallNotNull(expr)->arguments) { + for (const Expression& arg : CallNotNull(expr)->arguments) { GetAllFieldRefs(arg, refs); } } -inline std::vector GetDescriptors(const std::vector& exprs) { +inline std::vector GetDescriptors(const std::vector& exprs) { std::vector descrs(exprs.size()); for (size_t i = 0; i < exprs.size(); ++i) { DCHECK(exprs[i].IsBound()); @@ -133,7 +133,7 @@ struct Comparison { return it != flipped_comparisons.end() ? 
&it->second : nullptr; } - static const type* Get(const Expression2& expr) { + static const type* Get(const Expression& expr) { if (auto call = expr.call()) { return Comparison::Get(call->function); } @@ -207,7 +207,7 @@ struct Comparison { } }; -inline const compute::CastOptions* GetCastOptions(const Expression2::Call& call) { +inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) { if (call.function != "cast") return nullptr; return checked_cast(call.options.get()); } @@ -217,17 +217,17 @@ inline bool IsSetLookup(const std::string& function) { } inline const compute::SetLookupOptions* GetSetLookupOptions( - const Expression2::Call& call) { + const Expression::Call& call) { if (!IsSetLookup(call.function)) return nullptr; return checked_cast(call.options.get()); } -inline const compute::StructOptions* GetStructOptions(const Expression2::Call& call) { +inline const compute::StructOptions* GetStructOptions(const Expression::Call& call) { if (call.function != "struct") return nullptr; return checked_cast(call.options.get()); } -inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression2::Call& call) { +inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call& call) { if (call.function != "strptime") return nullptr; return checked_cast(call.options.get()); } @@ -256,7 +256,7 @@ inline Status EnsureNotDictionary(Datum* datum) { return Status::OK(); } -inline Status EnsureNotDictionary(Expression2::Call* call) { +inline Status EnsureNotDictionary(Expression::Call* call) { if (auto options = GetSetLookupOptions(*call)) { auto new_options = *options; RETURN_NOT_OK(EnsureNotDictionary(&new_options.value_set)); @@ -266,7 +266,7 @@ inline Status EnsureNotDictionary(Expression2::Call* call) { } inline Result> FunctionOptionsToStructScalar( - const Expression2::Call& call) { + const Expression::Call& call) { if (call.options == nullptr) { return nullptr; } @@ -318,7 +318,7 @@ inline Result> FunctionOptionsToStructScalar( } inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, - Expression2::Call* call) { + Expression::Call* call) { if (repr == nullptr) { call->options = nullptr; return Status::OK(); @@ -359,9 +359,9 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, struct FlattenedAssociativeChain { bool was_left_folded = true; - std::vector exprs, fringe; + std::vector exprs, fringe; - explicit FlattenedAssociativeChain(Expression2 expr) : exprs{std::move(expr)} { + explicit FlattenedAssociativeChain(Expression expr) : exprs{std::move(expr)} { auto call = CallNotNull(exprs.back()); fringe = call->arguments; @@ -384,14 +384,14 @@ struct FlattenedAssociativeChain { // NB: no increment so we hit sub_call's first argument next iteration } - DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression2& expr) { + DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression& expr) { return CallNotNull(expr)->options == nullptr; })); } }; inline Result> GetFunction( - const Expression2::Call& call, compute::ExecContext* exec_context) { + const Expression::Call& call, compute::ExecContext* exec_context) { if (call.function != "cast") { return exec_context->func_registry()->GetFunction(call.function); } @@ -401,9 +401,9 @@ inline Result> GetFunction( } template -Result Modify(Expression2 expr, const PreVisit& pre, +Result Modify(Expression expr, const PreVisit& pre, const PostVisitCall& post_call) { - ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); + ARROW_ASSIGN_OR_RAISE(expr, 
Result(pre(std::move(expr)))); auto call = expr.call(); if (!call) return expr; @@ -423,8 +423,8 @@ Result Modify(Expression2 expr, const PreVisit& pre, if (at_least_one_modified) { // reconstruct the call expression with the modified arguments - auto modified_expr = Expression2( - std::make_shared(std::move(modified_call)), expr.descr()); + auto modified_expr = Expression( + std::make_shared(std::move(modified_call)), expr.descr()); return post_call(std::move(modified_expr), &expr); } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index f2b6951a8a3..4812d06b8c8 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -42,12 +42,12 @@ namespace dataset { #define EXPECT_OK ARROW_EXPECT_OK -Expression2 cast(Expression2 argument, std::shared_ptr to_type) { +Expression cast(Expression argument, std::shared_ptr to_type) { return call("cast", {std::move(argument)}, compute::CastOptions::Safe(std::move(to_type))); } -TEST(Expression2, ToString) { +TEST(Expression, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "FieldRef(alpha)"); EXPECT_EQ(literal(3).ToString(), "3"); @@ -81,7 +81,7 @@ TEST(Expression2, ToString) { ", {a,renamed_a,three,b})"); } -TEST(Expression2, Equality) { +TEST(Expression, Equality) { EXPECT_EQ(literal(1), literal(1)); EXPECT_NE(literal(1), literal(2)); @@ -110,8 +110,8 @@ TEST(Expression2, Equality) { call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32()))); } -TEST(Expression2, Hash) { - std::unordered_set set; +TEST(Expression, Hash) { + std::unordered_set set; EXPECT_TRUE(set.emplace(field_ref("alpha")).second); EXPECT_TRUE(set.emplace(field_ref("beta")).second); @@ -131,7 +131,7 @@ TEST(Expression2, Hash) { EXPECT_EQ(set.size(), 6); } -TEST(Expression2, IsScalarExpression) { +TEST(Expression, IsScalarExpression) { EXPECT_TRUE(literal(true).IsScalarExpression()); auto arr = ArrayFromJSON(int8(), "[]"); @@ -150,7 +150,7 @@ TEST(Expression2, IsScalarExpression) { EXPECT_FALSE(call("take", {field_ref("a"), literal(arr)}).IsScalarExpression()); } -TEST(Expression2, IsSatisfiable) { +TEST(Expression, IsSatisfiable) { EXPECT_TRUE(literal(true).IsSatisfiable()); EXPECT_FALSE(literal(false).IsSatisfiable()); @@ -164,7 +164,7 @@ TEST(Expression2, IsSatisfiable) { // NB: no constant folding here EXPECT_TRUE(equal(literal(0), literal(1)).IsSatisfiable()); - // When a top level conjunction contains an Expression2 which is certain to evaluate to + // When a top level conjunction contains an Expression which is certain to evaluate to // null, it can only evaluate to null or false. 
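// (Kleene semantics make this concrete: and_kleene(null, false) is false and
// and_kleene(null, true) is null, so such a conjunction can never be true.)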
auto null_or_false = and_(literal(null), field_ref("a")); // This may appear in satisfiable filters if coalesced @@ -176,8 +176,8 @@ TEST(Expression2, IsSatisfiable) { // EXPECT_FALSE(null_or_false).IsSatisfiable()); } -TEST(Expression2, FieldsInExpression) { - auto ExpectFieldsAre = [](Expression2 expr, std::vector expected) { +TEST(Expression, FieldsInExpression) { + auto ExpectFieldsAre = [](Expression expr, std::vector expected) { EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected)); }; @@ -203,7 +203,7 @@ TEST(Expression2, FieldsInExpression) { {"a", "b", "c"}); } -TEST(Expression2, BindLiteral) { +TEST(Expression, BindLiteral) { for (Datum dat : { Datum(3), Datum(3.5), @@ -216,8 +216,8 @@ TEST(Expression2, BindLiteral) { } } -void ExpectBindsTo(Expression2 expr, Expression2 expected, - Expression2* bound_out = nullptr) { +void ExpectBindsTo(Expression expr, Expression expected, + Expression* bound_out = nullptr) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); EXPECT_TRUE(bound.IsBound()); @@ -229,7 +229,7 @@ void ExpectBindsTo(Expression2 expr, Expression2 expected, } } -TEST(Expression2, BindFieldRef) { +TEST(Expression, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set auto expr = field_ref("alpha"); EXPECT_EQ(expr.descr(), ValueDescr{}); @@ -254,7 +254,7 @@ TEST(Expression2, BindFieldRef) { EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); } -TEST(Expression2, BindCall) { +TEST(Expression, BindCall) { auto expr = call("add", {field_ref("a"), field_ref("b")}); EXPECT_FALSE(expr.IsBound()); @@ -268,7 +268,7 @@ TEST(Expression2, BindCall) { expr.Bind(Schema({field("a", int32()), field("b", int32())}))); } -TEST(Expression2, BindWithImplicitCasts) { +TEST(Expression, BindWithImplicitCasts) { for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { // cast arguments to same type ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")), @@ -299,7 +299,7 @@ TEST(Expression2, BindWithImplicitCasts) { call("is_in", {field_ref("str")}, Opts(utf8()))); } -TEST(Expression2, BindNestedCall) { +TEST(Expression, BindNestedCall) { auto expr = call("add", {field_ref("a"), call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}), @@ -313,7 +313,7 @@ TEST(Expression2, BindNestedCall) { EXPECT_TRUE(expr.IsBound()); } -TEST(Expression2, ExecuteFieldRef) { +TEST(Expression, ExecuteFieldRef) { auto AssertRefIs = [](FieldRef ref, Datum in, Datum expected) { auto expr = field_ref(ref); @@ -348,7 +348,7 @@ TEST(Expression2, ExecuteFieldRef) { MakeNullScalar(null())); } -Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& input) { +Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) { auto call = expr.call(); if (call == nullptr) { // already tested execution of field_ref, execution of literal is trivial @@ -371,7 +371,7 @@ Result NaiveExecuteScalarExpression(const Expression2& expr, const Datum& return function->Execute(arguments, call->options.get(), &exec_context); } -void AssertExecute(Expression2 expr, Datum in, Datum* actual_out = NULLPTR) { +void AssertExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) { if (in.is_value()) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); } else { @@ -389,7 +389,7 @@ void AssertExecute(Expression2 expr, Datum in, Datum* actual_out = NULLPTR) { } } -TEST(Expression2, ExecuteCall) { +TEST(Expression, ExecuteCall) { AssertExecute(call("add", {field_ref("a"), literal(3.5)}), ArrayFromJSON(struct_({field("a", 
float64())}), R"([ {"a": 6.125}, @@ -421,7 +421,7 @@ TEST(Expression2, ExecuteCall) { ])")); } -TEST(Expression2, ExecuteDictionaryTransparent) { +TEST(Expression, ExecuteDictionaryTransparent) { AssertExecute( equal(field_ref("a"), field_ref("b")), ArrayFromJSON( @@ -443,7 +443,7 @@ TEST(Expression2, ExecuteDictionaryTransparent) { } struct { - void operator()(Expression2 expr, Expression2 expected) { + void operator()(Expression expr, Expression expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); @@ -459,7 +459,7 @@ struct { } ExpectFoldsTo; -TEST(Expression2, FoldConstants) { +TEST(Expression, FoldConstants) { // literals are unchanged ExpectFoldsTo(literal(3), literal(3)); @@ -514,7 +514,7 @@ TEST(Expression2, FoldConstants) { call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123)); } -TEST(Expression2, FoldConstantsBoolean) { +TEST(Expression, FoldConstantsBoolean) { // test and_kleene/or_kleene-specific optimizations auto one = literal(1); auto two = literal(2); @@ -532,9 +532,9 @@ TEST(Expression2, FoldConstantsBoolean) { ExpectFoldsTo(or_(whatever, whatever), whatever); } -TEST(Expression2, ExtractKnownFieldValues) { +TEST(Expression, ExtractKnownFieldValues) { struct { - void operator()(Expression2 guarantee, + void operator()(Expression guarantee, std::unordered_map expected) { ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) @@ -582,11 +582,11 @@ TEST(Expression2, ExtractKnownFieldValues) { {{"i32", Datum(3)}, {"i32_req", Datum(1)}}); } -TEST(Expression2, ReplaceFieldsWithKnownValues) { +TEST(Expression, ReplaceFieldsWithKnownValues) { auto ExpectReplacesTo = - [](Expression2 expr, + [](Expression expr, std::unordered_map known_values, - Expression2 unbound_expected) { + Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto replaced, @@ -635,7 +635,7 @@ TEST(Expression2, ReplaceFieldsWithKnownValues) { } struct { - void operator()(Expression2 expr, Expression2 unbound_expected) const { + void operator()(Expression expr, Expression unbound_expected) const { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); @@ -649,7 +649,7 @@ struct { } } ExpectCanonicalizesTo; -TEST(Expression2, CanonicalizeTrivial) { +TEST(Expression, CanonicalizeTrivial) { ExpectCanonicalizesTo(literal(1), literal(1)); ExpectCanonicalizesTo(field_ref("b"), field_ref("b")); @@ -658,7 +658,7 @@ TEST(Expression2, CanonicalizeTrivial) { equal(field_ref("i32"), field_ref("i32_req"))); } -TEST(Expression2, CanonicalizeAnd) { +TEST(Expression, CanonicalizeAnd) { // some aliases for brevity: auto true_ = literal(true); auto null_ = literal(std::make_shared()); @@ -686,7 +686,7 @@ TEST(Expression2, CanonicalizeAnd) { call("is_valid", {and_(true_, b)})); } -TEST(Expression2, CanonicalizeComparison) { +TEST(Expression, CanonicalizeComparison) { ExpectCanonicalizesTo(equal(literal(1), field_ref("i32")), equal(field_ref("i32"), literal(1))); @@ -701,12 +701,12 @@ TEST(Expression2, CanonicalizeComparison) { } struct Simplify { - Expression2 expr; + Expression expr; struct Expectable { - Expression2 expr, guarantee; + Expression expr, guarantee; - void Expect(Expression2 
unbound_expected) { + void Expect(Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); @@ -724,10 +724,10 @@ struct Simplify { void Expect(bool constant) { Expect(literal(constant)); } }; - Expectable WithGuarantee(Expression2 guarantee) { return {expr, guarantee}; } + Expectable WithGuarantee(Expression guarantee) { return {expr, guarantee}; } }; -TEST(Expression2, SingleComparisonGuarantees) { +TEST(Expression, SingleComparisonGuarantees) { auto i32 = field_ref("i32"); // i32 is guaranteed equal to 3, so the projection can just materialize that constant @@ -840,7 +840,7 @@ TEST(Expression2, SingleComparisonGuarantees) { } } -TEST(Expression2, SimplifyWithGuarantee) { +TEST(Expression, SimplifyWithGuarantee) { // drop both members of a conjunctive filter Simplify{ and_(equal(field_ref("i32"), literal(2)), equal(field_ref("f32"), literal(3.5F)))} @@ -880,7 +880,7 @@ TEST(Expression2, SimplifyWithGuarantee) { compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); } -TEST(Expression2, SimplifyThenExecute) { +TEST(Expression, SimplifyThenExecute) { auto filter = or_({equal(field_ref("f32"), literal("0")), call("is_in", {field_ref("i64")}, @@ -908,8 +908,8 @@ TEST(Expression2, SimplifyThenExecute) { AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); } -TEST(Expression2, Filter) { - auto ExpectFilter = [](Expression2 filter, std::string batch_json) { +TEST(Expression, Filter) { + auto ExpectFilter = [](Expression filter, std::string batch_json) { ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean()))); auto batch = RecordBatchFromJSON(s, batch_json); auto expected_mask = batch->column(0); @@ -931,10 +931,10 @@ TEST(Expression2, Filter) { ])"); } -TEST(Expression2, SerializationRoundTrips) { - auto ExpectRoundTrips = [](const Expression2& expr) { +TEST(Expression, SerializationRoundTrips) { + auto ExpectRoundTrips = [](const Expression& expr) { ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(expr)); - ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, Deserialize(*serialized)); + ASSERT_OK_AND_ASSIGN(Expression roundtripped, Deserialize(*serialized)); EXPECT_EQ(expr, roundtripped); }; diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 86a262a662b..2c437ce8eec 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -60,12 +60,12 @@ Result> FileFormat::MakeFragment( } Result> FileFormat::MakeFragment( - FileSource source, Expression2 partition_expression) { + FileSource source, Expression partition_expression) { return MakeFragment(std::move(source), std::move(partition_expression), nullptr); } Result> FileFormat::MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr( new FileFragment(std::move(source), shared_from_this(), @@ -82,7 +82,7 @@ Result FileFragment::Scan(std::shared_ptr options } FileSystemDataset::FileSystemDataset(std::shared_ptr schema, - Expression2 root_partition, + Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) @@ -92,7 +92,7 @@ FileSystemDataset::FileSystemDataset(std::shared_ptr schema, fragments_(std::move(fragments)) {} Result> FileSystemDataset::Make( - std::shared_ptr schema, Expression2 root_partition, + std::shared_ptr schema, Expression 
root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) { return std::shared_ptr(new FileSystemDataset( @@ -136,7 +136,7 @@ std::string FileSystemDataset::ToString() const { return repr; } -Result FileSystemDataset::GetFragmentsImpl(Expression2 predicate) { +Result FileSystemDataset::GetFragmentsImpl(Expression predicate) { FragmentVector fragments; for (const auto& fragment : fragments_) { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index f6def56b7de..7b67e6f9bf5 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -143,11 +143,11 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this> MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema); Result> MakeFragment(FileSource source, - Expression2 partition_expression); + Expression partition_expression); Result> MakeFragment( FileSource source, std::shared_ptr physical_schema = NULLPTR); @@ -173,7 +173,7 @@ class ARROW_DS_EXPORT FileFragment : public Fragment { protected: FileFragment(FileSource source, std::shared_ptr format, - Expression2 partition_expression, std::shared_ptr physical_schema) + Expression partition_expression, std::shared_ptr physical_schema) : Fragment(std::move(partition_expression), std::move(physical_schema)), source_(std::move(source)), format_(std::move(format)) {} @@ -206,7 +206,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { /// /// \return A constructed dataset. static Result> Make( - std::shared_ptr schema, Expression2 root_partition, + std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); @@ -234,9 +234,9 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { protected: Result GetFragmentsImpl( - Expression2 predicate) override; + Expression predicate) override; - FileSystemDataset(std::shared_ptr schema, Expression2 root_partition, + FileSystemDataset(std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 8ba5f67f6fc..94d2115358f 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -124,7 +124,7 @@ static Result> GetSchemaManifest( return manifest; } -static util::optional ColumnChunkStatisticsAsExpression( +static util::optional ColumnChunkStatisticsAsExpression( const SchemaField& schema_field, const parquet::RowGroupMetaData& metadata) { // For the remainder of this function, failures to extract/parse statistics // are ignored by returning nullptr. The goal is twofold.
First @@ -331,7 +331,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptr> ParquetFileFormat::MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema, std::vector row_groups) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -339,7 +339,7 @@ Result> ParquetFileFormat::MakeFragment( } Result> ParquetFileFormat::MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -394,7 +394,7 @@ Status ParquetFileWriter::Finish() { return parquet_writer_->Close(); } ParquetFileFragment::ParquetFileFragment(FileSource source, std::shared_ptr format, - Expression2 partition_expression, + Expression partition_expression, std::shared_ptr physical_schema, util::optional> row_groups) : FileFragment(std::move(source), std::move(format), std::move(partition_expression), @@ -456,7 +456,7 @@ Status ParquetFileFragment::SetMetadata( return Status::OK(); } -Result ParquetFileFragment::SplitByRowGroup(Expression2 predicate) { +Result ParquetFileFragment::SplitByRowGroup(Expression predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); @@ -474,7 +474,7 @@ Result ParquetFileFragment::SplitByRowGroup(Expression2 predicat return fragments; } -Result> ParquetFileFragment::Subset(Expression2 predicate) { +Result> ParquetFileFragment::Subset(Expression predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); return Subset(std::move(row_groups)); @@ -491,7 +491,7 @@ Result> ParquetFileFragment::Subset( return new_fragment; } -inline void FoldingAnd(Expression2* l, Expression2 r) { +inline void FoldingAnd(Expression* l, Expression r) { if (*l == literal(true)) { *l = std::move(r); } else { @@ -499,7 +499,7 @@ inline void FoldingAnd(Expression2* l, Expression2 r) { } } -Result> ParquetFileFragment::FilterRowGroups(Expression2 predicate) { +Result> ParquetFileFragment::FilterRowGroups(Expression predicate) { auto lock = physical_schema_mutex_.Lock(); DCHECK_NE(metadata_, nullptr); diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index 9eab3fe9449..ae0337994a0 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -117,12 +117,12 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// \brief Create a Fragment targeting all RowGroups. Result> MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) override; /// \brief Create a Fragment, restricted to the specified row groups. Result> MakeFragment( - FileSource source, Expression2 partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema, std::vector row_groups); /// \brief Return a FileReader on the given source. @@ -150,7 +150,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// significant performance boost when scanning high latency file systems. 
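/// A sketch of the intended use (names are illustrative; the predicate must be
/// a bound Expression over this fragment's physical schema): a fragment can be
/// split into one fragment per row group surviving the statistics filter.
///
///   ARROW_ASSIGN_OR_RAISE(auto per_row_group,
///                         parquet_fragment->SplitByRowGroup(bound_predicate));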
class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { public: - Result SplitByRowGroup(Expression2 predicate); + Result SplitByRowGroup(Expression predicate); /// \brief Return the RowGroups selected by this fragment. const std::vector& row_groups() const { @@ -166,12 +166,12 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR); /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups. - Result> Subset(Expression2 predicate); + Result> Subset(Expression predicate); Result> Subset(std::vector row_group_ids); private: ParquetFileFragment(FileSource source, std::shared_ptr format, - Expression2 partition_expression, + Expression partition_expression, std::shared_ptr physical_schema, util::optional> row_groups); @@ -185,7 +185,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { } // Return a filtered subset of row group indices. - Result> FilterRowGroups(Expression2 predicate); + Result> FilterRowGroups(Expression predicate); ParquetFileFormat& parquet_format_; @@ -193,7 +193,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { // or util::nullopt if all row groups are selected. util::optional> row_groups_; - std::vector statistics_expressions_; + std::vector statistics_expressions_; std::vector statistics_expressions_complete_; std::shared_ptr metadata_; std::shared_ptr manifest_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 47a563b911f..e9182ddf696 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -159,7 +159,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { return Batches(std::move(scan_task_it)); } - void SetFilter(Expression2 filter) { + void SetFilter(Expression filter) { ASSERT_OK_AND_ASSIGN(opts_->filter2, filter.Bind(*schema_)); } @@ -191,7 +191,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { void CountRowGroupsInFragment(const std::shared_ptr& fragment, std::vector expected_row_groups, - Expression2 filter) { + Expression filter) { schema_ = opts_->schema(); ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); auto parquet_fragment = checked_pointer_cast(fragment); diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index fefd6911ac0..f0799e07a3a 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -122,7 +122,7 @@ TEST_F(TestFileSystemDataset, RootPartitionPruning) { auto root_partition = equal(field_ref("i32"), literal(5)); MakeDataset({fs::File("a"), fs::File("b")}, root_partition); - auto GetFragments = [&](Expression2 filter) { + auto GetFragments = [&](Expression filter) { return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); }; @@ -156,7 +156,7 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - std::vector partitions = { + std::vector partitions = { equal(field_ref("state"), literal("NY")), and_(equal(field_ref("state"), literal("NY")), @@ -186,7 +186,7 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { // Default filter should always return all data. 
AssertFragmentsAreFromPath(*dataset_->GetFragments(), all_cities); - auto GetFragments = [&](Expression2 filter) { + auto GetFragments = [&](Expression filter) { return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); }; @@ -212,7 +212,7 @@ TEST_F(TestFileSystemDataset, FragmentPartitions) { fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - std::vector partitions = { + std::vector partitions = { equal(field_ref("state"), literal("NY")), and_(equal(field_ref("state"), literal("NY")), diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 6fd61096392..ba4f0c46b8f 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -50,11 +50,11 @@ std::shared_ptr Partitioning::Default() { std::string type_name() const override { return "default"; } - Result Parse(const std::string& path) const override { + Result Parse(const std::string& path) const override { return literal(true); } - Result Format(const Expression2& expr) const override { + Result Format(const Expression& expr) const override { return Status::NotImplemented("formatting paths from ", type_name(), " Partitioning"); } @@ -68,7 +68,7 @@ std::shared_ptr Partitioning::Default() { return std::make_shared(); } -Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr, +Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr, RecordBatchProjector* projector) { ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); for (const auto& ref_value : known_values) { @@ -85,9 +85,9 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression2& expr, return Status::OK(); } -inline Expression2 ConjunctionFromGroupingRow(Scalar* row) { +inline Expression ConjunctionFromGroupingRow(Scalar* row) { ScalarVector* values = &checked_cast(row)->value; - std::vector equality_expressions(values->size()); + std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { const std::string& name = row->type->field(static_cast(i))->name(); equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); @@ -135,7 +135,7 @@ Result KeyValuePartitioning::Partition( return out; } -Result KeyValuePartitioning::ConvertKey(const Key& key) const { +Result KeyValuePartitioning::ConvertKey(const Key& key) const { ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_)); if (!match) { return literal(true); @@ -178,8 +178,8 @@ Result KeyValuePartitioning::ConvertKey(const Key& key) const { return equal(field_ref(field->name()), literal(std::move(converted))); } -Result KeyValuePartitioning::Parse(const std::string& path) const { - std::vector expressions; +Result KeyValuePartitioning::Parse(const std::string& path) const { + std::vector expressions; for (const Key& key : ParseKeys(path)) { ARROW_ASSIGN_OR_RAISE(auto expr, ConvertKey(key)); @@ -190,7 +190,7 @@ Result KeyValuePartitioning::Parse(const std::string& path) const { return and_(std::move(expressions)); } -Result KeyValuePartitioning::Format(const Expression2& expr) const { +Result KeyValuePartitioning::Format(const Expression& expr) const { std::vector values{static_cast(schema_->num_fields()), nullptr}; ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index a9f8c87bfc9..8975f565b19 100644 --- a/cpp/src/arrow/dataset/partition.h +++ 
b/cpp/src/arrow/dataset/partition.h @@ -63,15 +63,15 @@ class ARROW_DS_EXPORT Partitioning { /// produce sub-batches which satisfy mutually exclusive Expressions. struct PartitionedBatches { RecordBatchVector batches; - std::vector expressions; + std::vector expressions; }; virtual Result Partition( const std::shared_ptr& batch) const = 0; /// \brief Parse a path into a partition expression - virtual Result Parse(const std::string& path) const = 0; + virtual Result Parse(const std::string& path) const = 0; - virtual Result Format(const Expression2& expr) const = 0; + virtual Result Format(const Expression& expr) const = 0; /// \brief A default Partitioning which always yields scalar(true) static std::shared_ptr Default(); @@ -122,15 +122,15 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { std::string name, value; }; - static Status SetDefaultValuesFromKeys(const Expression2& expr, + static Status SetDefaultValuesFromKeys(const Expression& expr, RecordBatchProjector* projector); Result Partition( const std::shared_ptr& batch) const override; - Result Parse(const std::string& path) const override; + Result Parse(const std::string& path) const override; - Result Format(const Expression2& expr) const override; + Result Format(const Expression& expr) const override; protected: KeyValuePartitioning(std::shared_ptr schema, ArrayVector dictionaries) @@ -145,7 +145,7 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { virtual Result FormatValues(const std::vector& values) const = 0; /// Convert a Key to a full expression. - Result ConvertKey(const Key& key) const; + Result ConvertKey(const Key& key) const; ArrayVector dictionaries_; }; @@ -207,9 +207,9 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { /// \brief Implementation provided by lambda or other callable class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { public: - using ParseImpl = std::function(const std::string&)>; + using ParseImpl = std::function(const std::string&)>; - using FormatImpl = std::function(const Expression2&)>; + using FormatImpl = std::function(const Expression&)>; FunctionPartitioning(std::shared_ptr schema, ParseImpl parse_impl, FormatImpl format_impl = NULLPTR, std::string name = "function") @@ -220,11 +220,11 @@ class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { std::string type_name() const override { return name_; } - Result Parse(const std::string& path) const override { + Result Parse(const std::string& path) const override { return parse_impl_(path); } - Result Format(const Expression2& expr) const override { + Result Format(const Expression& expr) const override { if (format_impl_) { return format_impl_(expr); } diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 0cdb5eb50b6..840e2f0303f 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -43,17 +43,17 @@ class TestPartitioning : public ::testing::Test { ASSERT_RAISES(Invalid, partitioning_->Parse(path)); } - void AssertParse(const std::string& path, Expression2 expected) { + void AssertParse(const std::string& path, Expression expected) { ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path)); ASSERT_EQ(parsed, expected); } template - void AssertFormatError(Expression2 expr) { + void AssertFormatError(Expression expr) { ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } - void AssertFormat(Expression2 expr, const std::string& expected) { + void 
AssertFormat(Expression expr, const std::string& expected) { // formatted partition expressions are bound to the schema of the dataset being // written ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(expr)); @@ -61,7 +61,7 @@ class TestPartitioning : public ::testing::Test { // ensure the formatted path round trips the relevant components of the partition // expression: roundtripped should be a subset of expr - ASSERT_OK_AND_ASSIGN(Expression2 roundtripped, partitioning_->Parse(formatted)); + ASSERT_OK_AND_ASSIGN(Expression roundtripped, partitioning_->Parse(formatted)); ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*written_schema_)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(roundtripped, expr)); @@ -370,7 +370,7 @@ TEST_F(TestPartitioning, EtlThenHive) { field("hour", int8()), field("alpha", int32()), field("beta", float32())}); partitioning_ = std::make_shared( - schm, [&](const std::string& path) -> Result { + schm, [&](const std::string& path) -> Result { auto segments = fs::internal::SplitAbstractPath(path); if (segments.size() < etl_fields.size() + alphabeta_fields.size()) { return Status::Invalid("path ", path, " can't be parsed"); } @@ -412,8 +412,8 @@ TEST_F(TestPartitioning, Set) { // An ad hoc partitioning which parses segments like "/x in [1 4 5]" // into (field_ref("x") == 1 or field_ref("x") == 4 or field_ref("x") == 5) partitioning_ = std::make_shared( - schm, [&](const std::string& path) -> Result { - std::vector subexpressions; + schm, [&](const std::string& path) -> Result { + std::vector subexpressions; for (auto segment : fs::internal::SplitAbstractPath(path)) { std::smatch matches; @@ -451,8 +451,8 @@ class RangePartitioning : public Partitioning { std::string type_name() const override { return "range"; } - Result Parse(const std::string& path) const override { - std::vector ranges; + Result Parse(const std::string& path) const override { + std::vector ranges; for (auto segment : fs::internal::SplitAbstractPath(path)) { auto key = HivePartitioning::ParseKey(segment); @@ -496,7 +496,7 @@ class RangePartitioning : public Partitioning { return Status::OK(); } - Result Format(const Expression2&) const override { return ""; } + Result Format(const Expression&) const override { return ""; } Result Partition( const std::shared_ptr&) const override { return Status::OK(); diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 809d83b7e2b..0c501c9f5b3 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -121,7 +121,7 @@ Status ScannerBuilder::Project(std::vector columns) { return Status::OK(); } -Status ScannerBuilder::Filter(const Expression2& filter) { +Status ScannerBuilder::Filter(const Expression& filter) { for (const auto& ref : FieldsInExpression(filter)) { RETURN_NOT_OK(ref.FindOne(*schema())); } diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index e2f9892af38..5902f759ec3 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -63,7 +63,7 @@ class ARROW_DS_EXPORT ScanOptions { std::shared_ptr ReplaceSchema(std::shared_ptr schema) const; // Filter - Expression2 filter2 = literal(true); + Expression filter2 = literal(true); // Schema to which record batches will be reconciled const std::shared_ptr& schema() const { return projector.schema(); } @@ -220,7 +220,7 @@ class ARROW_DS_EXPORT ScannerBuilder { /// /// \return Failure if any referenced column does not exist in the dataset's /// Schema.
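/// For example (the column name is illustrative), given a ScannerBuilder
/// `builder` for a dataset whose schema contains an int32 column "x":
///
///   RETURN_NOT_OK(builder.Filter(greater(field_ref("x"), literal(1))));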
- Status Filter(const Expression2& filter); + Status Filter(const Expression& filter); /// \brief Indicate if the Scanner should make use of the available /// ThreadPool found in ScanContext; diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index eb11f326570..cd8fffd2a71 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -30,7 +30,7 @@ namespace arrow { namespace dataset { -inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression2 filter, +inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression filter, MemoryPool* pool) { return MakeMaybeMapIterator( [=](std::shared_ptr in) -> Result> { @@ -70,7 +70,7 @@ inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, class FilterAndProjectScanTask : public ScanTask { public: - explicit FilterAndProjectScanTask(std::shared_ptr task, Expression2 partition) + explicit FilterAndProjectScanTask(std::shared_ptr task, Expression partition) : ScanTask(task->options(), task->context()), task_(std::move(task)), partition_(std::move(partition)), @@ -80,7 +80,7 @@ class FilterAndProjectScanTask : public ScanTask { Result Execute() override { ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute()); - ARROW_ASSIGN_OR_RAISE(Expression2 simplified_filter, + ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, SimplifyWithGuarantee(filter_, partition_)); RecordBatchIterator filter_it = @@ -94,8 +94,8 @@ class FilterAndProjectScanTask : public ScanTask { private: std::shared_ptr task_; - Expression2 partition_; - Expression2 filter_; + Expression partition_; + Expression filter_; RecordBatchProjector projector_; }; diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 76015fa365c..c5904b45e06 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -199,7 +199,7 @@ class DatasetFixtureMixin : public ::testing::Test { SetFilter(literal(true)); } - void SetFilter(Expression2 filter) { + void SetFilter(Expression filter) { ASSERT_OK_AND_ASSIGN(options_->filter2, filter.Bind(*schema_)); } @@ -336,8 +336,8 @@ struct MakeFileSystemDatasetMixin { } void MakeDataset(const std::vector& infos, - Expression2 root_partition = literal(true), - std::vector partitions = {}, + Expression root_partition = literal(true), + std::vector partitions = {}, std::shared_ptr s = kBoringSchema) { auto n_fragments = infos.size(); if (partitions.empty()) { @@ -397,8 +397,8 @@ void AssertFragmentsAreFromPath(FragmentIterator it, std::vector ex testing::UnorderedElementsAreArray(expected)); } -static std::vector PartitionExpressionsOf(const FragmentVector& fragments) { - std::vector partition_expressions; +static std::vector PartitionExpressionsOf(const FragmentVector& fragments) { + std::vector partition_expressions; std::transform(fragments.begin(), fragments.end(), std::back_inserter(partition_expressions), [](const std::shared_ptr& fragment) { @@ -408,7 +408,7 @@ static std::vector PartitionExpressionsOf(const FragmentVector& fra } void AssertFragmentsHavePartitionExpressions(std::shared_ptr dataset, - std::vector expected) { + std::vector expected) { ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); for (auto& expr : expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*dataset->schema())); diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index a80b333f0c8..66fed352d0f 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ 
b/cpp/src/arrow/dataset/type_fwd.h @@ -68,7 +68,7 @@ class ParquetFileFragment; class ParquetFileWriter; class ParquetFileWriteOptions; -class Expression2; +class Expression; class Partitioning; class PartitioningFactory; From 9f051864b3535f0e0494d94b468fd223afd1ec31 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 10 Dec 2020 17:29:35 -0500 Subject: [PATCH 07/38] first pass at repairing bindings --- cpp/src/arrow/dataset/expression.h | 25 +++++++++--------- cpp/src/arrow/dataset/expression_test.cc | 3 +++ cpp/src/arrow/datum.h | 4 +-- cpp/src/arrow/util/variant.h | 32 ++++++++++++++---------- cpp/src/arrow/util/variant_test.cc | 13 ++++++++++ r/src/compute.cpp | 2 +- r/src/dataset.cpp | 5 +--- r/src/expression.cpp | 31 +++++++++++++---------- 8 files changed, 69 insertions(+), 46 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index a9ec5403068..0ab98327619 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -70,8 +70,7 @@ class ARROW_DS_EXPORT Expression { /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, - compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, compute::ExecContext* = NULLPTR) const; // XXX someday // Clone all KernelState in this bound expression. If any function referenced by this @@ -124,14 +123,12 @@ class ARROW_DS_EXPORT Expression { }; inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); } -inline bool operator!=(const Expression& l, const Expression& r) { - return !l.Equals(r); -} +inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); } // Factories inline Expression call(std::string function, std::vector arguments, - std::shared_ptr options = NULLPTR) { + std::shared_ptr options = NULLPTR) { Expression::Call call; call.function = std::move(function); call.arguments = std::move(arguments); @@ -142,7 +139,7 @@ inline Expression call(std::string function, std::vector arguments, template ::value>::type> Expression call(std::string function, std::vector arguments, - Options options) { + Options options) { return call(std::move(function), std::move(arguments), std::make_shared(std::move(options))); } @@ -157,8 +154,11 @@ template Expression literal(Arg&& arg) { Datum lit(std::forward(arg)); ValueDescr descr = lit.descr(); - return Expression(std::make_shared(std::move(lit)), - std::move(descr)); + return Expression(std::make_shared(std::move(lit)), std::move(descr)); +} + +inline Expression literal(std::shared_ptr scalar) { + return literal(Datum(std::move(scalar))); } ARROW_DS_EXPORT @@ -172,8 +172,7 @@ Result> ExtractKnownFieldVal /// /// @{ /// -/// These operate on a bound expression and its bound state simultaneously, -/// ensuring that Call Expressions' KernelState can be utilized or reassociated. +/// These operate on bound expressions. /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. @@ -197,7 +196,7 @@ Result ReplaceFieldsWithKnownValues( /// reference to a constant-value field with a literal. 
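
A hedged sketch of SimplifyWithGuarantee in use (the schema and field names here are illustrative; the declaration itself follows just below):

    // A fragment whose partition expression guarantees x == 3 satisfies
    // the filter x > 0 on every row, so the filter simplifies away:
    ARROW_ASSIGN_OR_RAISE(auto filter,
                          greater(field_ref("x"), literal(0)).Bind(*schema));
    auto guarantee = equal(field_ref("x"), literal(3));
    ARROW_ASSIGN_OR_RAISE(auto simplified, SimplifyWithGuarantee(filter, guarantee));
    // simplified is literal(true); the scan can skip filtering this fragment.
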
ARROW_DS_EXPORT Result SimplifyWithGuarantee(Expression, - const Expression& guaranteed_true_predicate); + const Expression& guaranteed_true_predicate); /// @} @@ -219,7 +218,7 @@ Result Deserialize(const Buffer&); // Convenience aliases for factories ARROW_DS_EXPORT Expression project(std::vector values, - std::vector names); + std::vector names); ARROW_DS_EXPORT Expression equal(Expression lhs, Expression rhs); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 4812d06b8c8..d9738798259 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -52,6 +52,9 @@ TEST(Expression, ToString) { EXPECT_EQ(literal(3).ToString(), "3"); + auto ts = *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO)); + EXPECT_EQ(literal(ts).ToString(), "656677413000000000"); + EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), "add(3,FieldRef(beta))"); diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 9c620279c9e..3d1f545e5f7 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -118,8 +118,8 @@ struct ARROW_EXPORT Datum { /// \brief Empty datum, to be populated elsewhere Datum() = default; - Datum(const Datum& other) noexcept = default; - Datum& operator=(const Datum& other) noexcept = default; + Datum(const Datum& other) = default; + Datum& operator=(const Datum& other) = default; Datum(Datum&& other) noexcept = default; Datum& operator=(Datum&& other) noexcept = default; diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h index 89f39ab8917..04ac5e2bb45 100644 --- a/cpp/src/arrow/util/variant.h +++ b/cpp/src/arrow/util/variant.h @@ -96,7 +96,7 @@ template struct all : conditional_t, std::false_type> {}; struct delete_copy_constructor { - template + template struct type { type() = default; type(const type& other) = delete; @@ -105,11 +105,13 @@ struct delete_copy_constructor { }; struct explicit_copy_constructor { - template + template struct type { type() = default; - type(const type& other) { static_cast(other).copy_to(this); } - type& operator=(const type& other) { + type(const type& other) noexcept(NoexceptCopyable) { + static_cast(other).copy_to(this); + } + type& operator=(const type& other) noexcept(NoexceptCopyable) { static_cast(this)->destroy(); static_cast(other).copy_to(this); return *this; @@ -117,11 +119,18 @@ struct explicit_copy_constructor { }; }; +template +using CopyConstructionImpl = + typename conditional_t::value && + std::is_copy_assignable::value)...>::value, + explicit_copy_constructor, delete_copy_constructor>:: + template type::value...>::value>; + template struct VariantStorage { VariantStorage() = default; - VariantStorage(const VariantStorage&) {} - VariantStorage& operator=(const VariantStorage&) { return *this; } + VariantStorage(const VariantStorage&) noexcept {} + VariantStorage& operator=(const VariantStorage&) noexcept { return *this; } VariantStorage(VariantStorage&&) noexcept {} VariantStorage& operator=(VariantStorage&&) noexcept { return *this; } ~VariantStorage() { @@ -192,7 +201,8 @@ struct VariantImpl, H, T...> : VariantImpl, T...> { // Templated to avoid instantiation in case H is not copy constructible template - void copy_to(Void* generic_target) const { + void copy_to(Void* generic_target) const + noexcept(noexcept(H(std::declval()))) { const auto target = static_cast(generic_target); try { if (this->index_ == kIndex) { @@ -243,11 +253,7 @@ struct VariantImpl, H, T...> : VariantImpl, T...> { 
template class Variant : detail::VariantImpl, T...>, - detail::conditional_t< - detail::all<(std::is_copy_constructible::value && - std::is_copy_assignable::value)...>::value, - detail::explicit_copy_constructor, - detail::delete_copy_constructor>::template type> { + detail::CopyConstructionImpl, T...> { template static constexpr uint8_t index_of() { return Impl::index_of(detail::type_constant{}); @@ -338,7 +344,7 @@ class Variant : detail::VariantImpl, T...>, this->index_ = 0; } - template + template friend struct detail::explicit_copy_constructor::type; template diff --git a/cpp/src/arrow/util/variant_test.cc b/cpp/src/arrow/util/variant_test.cc index 9e36f2eb9cf..5d83e00185f 100644 --- a/cpp/src/arrow/util/variant_test.cc +++ b/cpp/src/arrow/util/variant_test.cc @@ -125,6 +125,19 @@ TEST(Variant, CopyConstruction) { EXPECT_NO_THROW(AssertCopyConstruction(CopyAssignThrows{})); } +TEST(Variant, Noexcept) { + struct CopyThrows { + CopyThrows() = default; + CopyThrows(const CopyThrows&) { throw 42; } + CopyThrows& operator=(const CopyThrows&) { throw 42; } + }; + static_assert(!std::is_nothrow_copy_constructible::value, ""); + static_assert(std::is_nothrow_copy_constructible::value, ""); + static_assert(std::is_nothrow_copy_constructible>::value, ""); + static_assert( + !std::is_nothrow_copy_constructible>::value, ""); +} + TEST(Variant, Emplace) { using variant_type = Variant, int>; variant_type v; diff --git a/r/src/compute.cpp b/r/src/compute.cpp index a456ec4711b..12b23dad798 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -191,7 +191,7 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list auto datum_args = arrow::r::from_r_list(args); auto out = ValueOrStop( arrow::compute::CallFunction(func_name, datum_args, opts.get(), gc_context())); - return from_datum(out); + return from_datum(std::move(out)); } #endif diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp index 2ad59677eb0..8d8aa9cff6d 100644 --- a/r/src/dataset.cpp +++ b/r/src/dataset.cpp @@ -312,10 +312,7 @@ void dataset___ScannerBuilder__Project(const std::shared_ptr // [[arrow::export]] void dataset___ScannerBuilder__Filter(const std::shared_ptr& sb, const std::shared_ptr& expr) { - // Expressions converted from R's expressions are typed with R's native type, - // i.e. double, int64_t and bool. 
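
The cast insertion deleted just below became redundant: Bind() now inserts implicit casts itself. A small illustration against kBoringSchema, mirroring the BindCall test later in this series:

    // literal(3) is implicitly cast so "add" sees matching argument types:
    auto expr = call("add", {field_ref("f32"), literal(3)});
    ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*kBoringSchema));
    // expr is now equivalent to call("add", {field_ref("f32"), literal(3.0F)})
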
- auto cast_filter = ValueOrStop(InsertImplicitCasts(*expr, *sb->schema())); - StopIfNotOk(sb->Filter(cast_filter)); + StopIfNotOk(sb->Filter(*expr)); } // [[arrow::export]] diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 575dddc7da7..b1116d27a2f 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -19,93 +19,98 @@ #if defined(ARROW_R_WITH_ARROW) +#include #include namespace ds = ::arrow::dataset; +std::shared_ptr Share(ds::Expression expr) { + return std::make_shared(std::move(expr)); +} + // [[arrow::export]] std::shared_ptr dataset___expr__field_ref(std::string name) { - return ds::field_ref(std::move(name)); + return Share(ds::field_ref(std::move(name))); } // [[arrow::export]] std::shared_ptr dataset___expr__equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::equal(lhs, rhs); + return Share(ds::equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__not_equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::not_equal(lhs, rhs); + return Share(ds::not_equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__greater( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::greater(lhs, rhs); + return Share(ds::greater(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__greater_equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::greater_equal(lhs, rhs); + return Share(ds::greater_equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__less( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::less(lhs, rhs); + return Share(ds::less(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__less_equal( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::less_equal(lhs, rhs); + return Share(ds::less_equal(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__in( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return lhs->In(rhs).Copy(); + return Share(ds::call("is_in", {*lhs}, arrow::compute::SetLookupOptions{rhs})); } // [[arrow::export]] std::shared_ptr dataset___expr__and( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::and_(lhs, rhs); + return Share(ds::and_(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__or( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return ds::or_(lhs, rhs); + return Share(ds::or_(*lhs, *rhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__not( const std::shared_ptr& lhs) { - return ds::not_(lhs); + return Share(ds::not_(*lhs)); } // [[arrow::export]] std::shared_ptr dataset___expr__is_valid( const std::shared_ptr& lhs) { - return lhs->IsValid().Copy(); + return Share(ds::call("is_valid", {*lhs})); } // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { - return ds::scalar(x); + return Share(ds::literal(x)); } // [[arrow::export]] From c691f6b04af36f51161c05528f1245975a2fb04f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 11 Dec 2020 13:30:46 -0500 Subject: [PATCH 08/38] add more scalar cast implementations --- .../compute/kernels/scalar_cast_internal.h | 8 ++ .../compute/kernels/scalar_cast_temporal.cc | 75 +++++++++++++------ .../compute/kernels/scalar_set_lookup.cc | 29 ++----- .../arrow/compute/kernels/util_internal.cc | 23 ++++++ cpp/src/arrow/compute/kernels/util_internal.h | 10 +++ cpp/src/arrow/dataset/expression.cc | 47 ++++++------ cpp/src/arrow/dataset/expression_internal.h | 10 ++- 
cpp/src/arrow/dataset/expression_test.cc | 56 +++++++++----- cpp/src/arrow/dataset/test_util.h | 2 + 9 files changed, 169 insertions(+), 91 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h index 333601cf216..15769ce5f8f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h @@ -21,6 +21,7 @@ #include "arrow/compute/cast.h" // IWYU pragma: export #include "arrow/compute/cast_internal.h" // IWYU pragma: export #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" namespace arrow { @@ -62,6 +63,13 @@ void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) { CastFunctor::Exec)); } +template +void AddSimpleArrayOnlyCast(InputType in_ty, OutputType out_ty, CastFunction* func) { + DCHECK_OK(func->AddKernel( + InType::type_id, {in_ty}, out_ty, + TrivialScalarUnaryAsArraysExec(CastFunctor::Exec))); +} + void ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out); void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index 10b38d75095..7cff41f267e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -126,7 +126,6 @@ struct CastFunctor< enable_if_t<(is_timestamp_type::value && is_timestamp_type::value) || (is_duration_type::value && is_duration_type::value)>> { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -147,7 +146,6 @@ struct CastFunctor< template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -170,7 +168,6 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const CastOptions& options = checked_cast(*ctx->state()).options; @@ -224,7 +221,6 @@ struct CastFunctor::value && is_time_type:: using out_t = typename O::c_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -245,7 +241,6 @@ struct CastFunctor::value && is_time_type:: template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); ShiftTime(ctx, util::MULTIPLY, kMillisecondsInDay, @@ -256,7 +251,6 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); ShiftTime(ctx, util::DIVIDE, kMillisecondsInDay, *batch[0].array(), @@ -264,6 +258,38 @@ struct CastFunctor { } }; +// ---------------------------------------------------------------------- +// date32, date64 to timestamp + +template <> +struct CastFunctor { + static void Exec(KernelContext* ctx, 
const ExecBatch& batch, Datum* out) { + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const auto& out_type = checked_cast(*out->type()); + // get conversion SECOND -> unit + auto conversion = util::GetTimestampConversion(TimeUnit::SECOND, out_type.unit()); + DCHECK_EQ(conversion.first, util::MULTIPLY); + + // multiply to achieve days -> unit + conversion.second *= kMillisecondsInDay / 1000; + ShiftTime(ctx, util::MULTIPLY, conversion.second, *batch[0].array(), + out->mutable_array()); + } +}; + +template <> +struct CastFunctor { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const auto& out_type = checked_cast(*out->type()); + auto conversion = util::GetTimestampConversion(TimeUnit::MILLI, out_type.unit()); + ShiftTime(ctx, util::MULTIPLY, conversion.second, *batch[0].array(), + out->mutable_array()); + } +}; + // ---------------------------------------------------------------------- // String to Timestamp @@ -307,11 +333,11 @@ std::shared_ptr GetDate32Cast() { AddZeroCopyCast(Type::INT32, int32(), date32(), func.get()); // date64 -> date32 - AddSimpleCast(date64(), date32(), func.get()); + AddSimpleArrayOnlyCast(date64(), date32(), func.get()); // timestamp -> date32 - AddSimpleCast(InputType(Type::TIMESTAMP), date32(), - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIMESTAMP), date32(), + func.get()); return func; } @@ -324,11 +350,11 @@ std::shared_ptr GetDate64Cast() { AddZeroCopyCast(Type::INT64, int64(), date64(), func.get()); // date32 -> date64 - AddSimpleCast(date32(), date64(), func.get()); + AddSimpleArrayOnlyCast(date32(), date64(), func.get()); // timestamp -> date64 - AddSimpleCast(InputType(Type::TIMESTAMP), date64(), - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIMESTAMP), date64(), + func.get()); return func; } @@ -358,8 +384,8 @@ std::shared_ptr GetTime32Cast() { AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get()); // time64 -> time32 - AddSimpleCast(InputType(Type::TIME64), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIME64), + kOutputTargetType, func.get()); // time32 -> time32 AddCrossUnitCast(func.get()); @@ -375,8 +401,8 @@ std::shared_ptr GetTime64Cast() { AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); // time32 -> time64 - AddSimpleCast(InputType(Type::TIME32), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIME32), + kOutputTargetType, func.get()); // Between durations AddCrossUnitCast(func.get()); @@ -392,17 +418,18 @@ std::shared_ptr GetTimestampCast() { AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); // From date types - // TODO: ARROW-8876, these casts are not implemented - // AddSimpleCast(InputType(Type::DATE32), - // kOutputTargetType, func.get()); - // AddSimpleCast(InputType(Type::DATE64), - // kOutputTargetType, func.get()); + // TODO: ARROW-8876, these casts are not directly tested + AddSimpleArrayOnlyCast(InputType(Type::DATE32), + kOutputTargetType, func.get()); + AddSimpleArrayOnlyCast(InputType(Type::DATE64), + kOutputTargetType, func.get()); // string -> timestamp - AddSimpleCast(utf8(), kOutputTargetType, func.get()); + AddSimpleArrayOnlyCast(utf8(), kOutputTargetType, + func.get()); // large_string -> timestamp - AddSimpleCast(large_utf8(), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(large_utf8(), kOutputTargetType, + func.get()); // From one timestamp to another 
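
To make the date32 -> timestamp conversion factor above concrete (a worked example, not patch code):

    // date32 stores days; timestamp(NANO) stores nanoseconds.
    constexpr int64_t kMillisecondsInDay = 86400000;
    constexpr int64_t kSecondsPerDay = kMillisecondsInDay / 1000;  // 86400
    constexpr int64_t kNanosPerSecond = 1000000000;  // SECOND -> NANO multiplier
    static_assert(kSecondsPerDay * kNanosPerSecond == 86400000000000LL,
                  "combined days -> timestamp(ns) multiplier");
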
AddCrossUnitCast(func.get()); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index d43083c7538..722f3173a34 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -19,10 +19,10 @@ #include "arrow/array/builder_primitive.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/hashing.h" -#include "arrow/util/optional.h" #include "arrow/visitor_inline.h" namespace arrow { @@ -255,25 +255,8 @@ struct IndexInVisitor { } }; -void ExecArrayOrScalar(KernelContext* ctx, const Datum& in, Datum* out, - std::function array_impl) { - if (in.is_array()) { - KERNEL_RETURN_IF_ERROR(ctx, array_impl(*in.array())); - return; - } - - std::shared_ptr in_array; - std::shared_ptr out_scalar; - KERNEL_RETURN_IF_ERROR(ctx, MakeArrayFromScalar(*in.scalar(), 1).Value(&in_array)); - KERNEL_RETURN_IF_ERROR(ctx, array_impl(*in_array->data())); - KERNEL_RETURN_IF_ERROR(ctx, out->make_array()->GetScalar(0).Value(&out_scalar)); - *out = std::move(out_scalar); -} - void ExecIndexIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExecArrayOrScalar(ctx, batch[0], out, [&](const ArrayData& in) { - return IndexInVisitor(ctx, in, out).Execute(); - }); + KERNEL_RETURN_IF_ERROR(ctx, IndexInVisitor(ctx, *batch[0].array(), out).Execute()); } // ---------------------------------------------------------------------- @@ -360,9 +343,7 @@ struct IsInVisitor { }; void ExecIsIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExecArrayOrScalar(ctx, batch[0], out, [&](const ArrayData& in) { - return IsInVisitor(ctx, in, out).Execute(); - }); + KERNEL_RETURN_IF_ERROR(ctx, IsInVisitor(ctx, *batch[0].array(), out).Execute()); } // Unary set lookup kernels available for the following input types @@ -451,7 +432,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel isin_base; isin_base.init = InitSetLookup; - isin_base.exec = ExecIsIn; + isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn); auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); @@ -468,7 +449,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel index_in_base; index_in_base.init = InitSetLookup; - index_in_base.exec = ExecIndexIn; + index_in_base.exec = TrivialScalarUnaryAsArraysExec(ExecIndexIn); index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto index_in = diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc index 32c6317a104..21b6b706a50 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ -57,6 +57,29 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { return arg; } +ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec) { + return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (out->is_array()) { + return exec(ctx, batch, out); + } + + if (!batch[0].scalar()->is_valid) { + out->scalar()->is_valid = false; + return; + } + + Datum array_in, array_out; + KERNEL_RETURN_IF_ERROR( + ctx, MakeArrayFromScalar(*batch[0].scalar(), 1).As().Value(&array_in)); + 
KERNEL_RETURN_IF_ERROR( + ctx, MakeArrayFromScalar(*out->scalar(), 1).As().Value(&array_out)); + + exec(ctx, ExecBatch{{std::move(array_in)}, 1}, &array_out); + KERNEL_RETURN_IF_ERROR(ctx, + array_out.make_array()->GetScalar(0).As().Value(out)); + }; +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h index 7ab59965752..4aad3804366 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.h +++ b/cpp/src/arrow/compute/kernels/util_internal.h @@ -18,8 +18,12 @@ #pragma once #include +#include +#include "arrow/array/util.h" #include "arrow/buffer.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/type_fwd.h" namespace arrow { namespace compute { @@ -50,6 +54,12 @@ int GetBitWidth(const DataType& type); // rather than duplicating compiled code to do all these in each kernel. PrimitiveArg GetPrimitiveArg(const ArrayData& arr); +// Augment a unary ArrayKernelExec which supports only array-like inputs with support for +// scalar inputs. Scalars will be transformed to 1-long arrays which are passed to the +// original exec. This could be far more efficient, but instead of optimizing this it'd be +// better to support scalar inputs "upstream" in original exec. +ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 6ff53d4190c..cc142cbb2d4 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -335,7 +335,7 @@ Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { auto call_with_cast = *CallNotNull(with_cast); call_with_cast.arguments[0] = std::move(*expr); *expr = Expression(std::make_shared(std::move(call_with_cast)), - ValueDescr{std::move(to_type), expr->descr().shape}); + ValueDescr{std::move(to_type), expr->descr().shape}); return Status::OK(); } @@ -344,7 +344,7 @@ Status InsertImplicitCasts(Expression::Call* call) { DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - if (Comparison::Get(call->function)) { + if (IsSameTypesBinary(call->function)) { for (auto&& argument : call->arguments) { if (auto value_type = GetDictionaryValueType(argument.descr().type)) { RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); @@ -390,7 +390,7 @@ Status InsertImplicitCasts(Expression::Call* call) { } Result Expression::Bind(ValueDescr in, - compute::ExecContext* exec_context) const { + compute::ExecContext* exec_context) const { if (exec_context == nullptr) { compute::ExecContext exec_context; return Bind(std::move(in), &exec_context); @@ -430,7 +430,7 @@ Result Expression::Bind(ValueDescr in, } Result Expression::Bind(const Schema& in_schema, - compute::ExecContext* exec_context) const { + compute::ExecContext* exec_context) const { return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); } @@ -493,9 +493,9 @@ std::array, 2> ArgumentsAndFlippedArguments(const Expression::Call& call) { DCHECK_EQ(call.arguments.size(), 2); return {std::pair{call.arguments[0], - call.arguments[1]}, + call.arguments[1]}, std::pair{call.arguments[1], - call.arguments[0]}}; + call.arguments[0]}}; } template Canonicalize(Expression expr, compute::ExecContext* exec_cont // fold the chain back up const auto& descr = expr.descr(); - auto 
folded = FoldLeft( - chain.fringe.begin(), chain.fringe.end(), - [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { - auto ret = *call; - ret.arguments = {std::move(l), std::move(r)}; - Expression expr(std::make_shared(std::move(ret)), - descr); - AlreadyCanonicalized.Add({expr}); - return expr; - }); + auto folded = + FoldLeft(chain.fringe.begin(), chain.fringe.end(), + [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { + auto ret = *call; + ret.arguments = {std::move(l), std::move(r)}; + Expression expr( + std::make_shared(std::move(ret)), descr); + AlreadyCanonicalized.Add({expr}); + return expr; + }); return std::move(*folded); } @@ -792,9 +792,8 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); - return Expression( - std::make_shared(std::move(flipped_call)), - expr.descr()); + return Expression(std::make_shared(std::move(flipped_call)), + expr.descr()); } } @@ -804,7 +803,7 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont } Result DirectComparisonSimplification(Expression expr, - const Expression::Call& guarantee) { + const Expression::Call& guarantee) { return Modify( std::move(expr), [](Expression expr) { return expr; }, [&guarantee](Expression expr, ...) -> Result { @@ -862,7 +861,7 @@ Result DirectComparisonSimplification(Expression expr, } Result SimplifyWithGuarantee(Expression expr, - const Expression& guaranteed_true_predicate) { + const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); std::unordered_map known_values; @@ -1072,7 +1071,7 @@ Expression and_(Expression lhs, Expression rhs) { Expression and_(const std::vector& operands) { auto folded = FoldLeft(operands.begin(), - operands.end(), and_); + operands.end(), and_); if (folded) { return std::move(*folded); } @@ -1084,8 +1083,8 @@ Expression or_(Expression lhs, Expression rhs) { } Expression or_(const std::vector& operands) { - auto folded = FoldLeft(operands.begin(), - operands.end(), or_); + auto folded = + FoldLeft(operands.begin(), operands.end(), or_); if (folded) { return std::move(*folded); } diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 8c26e037f16..90f28cd23ef 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -216,6 +216,14 @@ inline bool IsSetLookup(const std::string& function) { return function == "is_in" || function == "index_in"; } +inline bool IsSameTypesBinary(const std::string& function) { + if (Comparison::Get(function)) return true; + + static std::unordered_set set{"add", "subtract", "multiply", "divide"}; + + return set.find(function) != set.end(); +} + inline const compute::SetLookupOptions* GetSetLookupOptions( const Expression::Call& call) { if (!IsSetLookup(call.function)) return nullptr; @@ -402,7 +410,7 @@ inline Result> GetFunction( template Result Modify(Expression expr, const PreVisit& pre, - const PostVisitCall& post_call) { + const PostVisitCall& post_call) { ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); auto call = expr.call(); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index d9738798259..91e9dcdc4f1 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -219,30 
+219,36 @@ TEST(Expression, BindLiteral) { } } -void ExpectBindsTo(Expression expr, Expression expected, +void ExpectBindsTo(Expression expr, util::optional<Expression> expected, Expression* bound_out = nullptr) { + if (!expected) { + expected = expr; + } + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); EXPECT_TRUE(bound.IsBound()); - ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema)); - EXPECT_EQ(bound, expected) << " unbound: " << expr.ToString(); + ASSERT_OK_AND_ASSIGN(expected, expected->Bind(*kBoringSchema)); + EXPECT_EQ(bound, *expected) << " unbound: " << expr.ToString(); if (bound_out) { *bound_out = bound; } } +const auto no_change = util::nullopt; + TEST(Expression, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set auto expr = field_ref("alpha"); EXPECT_EQ(expr.descr(), ValueDescr{}); EXPECT_FALSE(expr.IsBound()); - ExpectBindsTo(field_ref("i32"), field_ref("i32"), &expr); + ExpectBindsTo(field_ref("i32"), no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); // if the field is not found, a null scalar will be emitted - ExpectBindsTo(field_ref("no such field"), field_ref("no such field"), &expr); + ExpectBindsTo(field_ref("no such field"), no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); // referencing a field by name is not supported if that name is not unique @@ -258,17 +264,19 @@ TEST(Expression, BindFieldRef) { } TEST(Expression, BindCall) { - auto expr = call("add", {field_ref("a"), field_ref("b")}); + auto expr = call("add", {field_ref("i32"), field_ref("i32_req")}); EXPECT_FALSE(expr.IsBound()); - ASSERT_OK_AND_ASSIGN(expr, - expr.Bind(Schema({field("a", int32()), field("b", int32())}))); + ExpectBindsTo(expr, no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - EXPECT_TRUE(expr.IsBound()); - expr = call("add", {field_ref("a"), literal(3.5)}); - ASSERT_RAISES(NotImplemented, - expr.Bind(Schema({field("a", int32()), field("b", int32())}))); + // literal(3) may be safely cast to float32, so binding this expr casts that literal: + ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}), + call("add", {field_ref("f32"), literal(3.0F)})); + + // literal(3.5) may not be safely cast to int32, so binding this expr fails: + ASSERT_RAISES(Invalid, + call("add", {field_ref("i32"), literal(3.5)}).Bind(*kBoringSchema)); } TEST(Expression, BindWithImplicitCasts) { @@ -284,11 +292,9 @@ TEST(Expression, BindWithImplicitCasts) { } // scalars are directly cast when possible: - ExpectBindsTo( - equal(field_ref("ts_ns"), literal("1990-10-23 10:23:33")), - equal(field_ref("ts_ns"), - literal( - *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO))))); + auto ts_scalar = MakeScalar("1990-10-23")->CastTo(timestamp(TimeUnit::NANO)); + ExpectBindsTo(equal(field_ref("ts_ns"), literal("1990-10-23")), + equal(field_ref("ts_ns"), literal(*ts_scalar))); // cast value_set to argument type auto Opts = [](std::shared_ptr<DataType> type) { @@ -509,7 +515,10 @@ TEST(Expression, FoldConstants) { literal(2), })); - compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]"), true); + compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]")); + + ExpectFoldsTo(call("is_in", {literal(2)}, in_123), literal(true)); + ExpectFoldsTo( call("is_in", {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})}, @@ -932,6 +941,17 @@ TEST(Expression, Filter) { {"i32": 0, "f32": null, "in": 1}, {"i32": 0, "f32": 1.0, "in": 1} ])"); + + ExpectFilter( greater(call("multiply", 
{field_ref("f32"), field_ref("f64")}), literal(0)), R"([ + {"f64": 0.3, "f32": 0.1, "in": 1}, + {"f64": -0.1, "f32": 0.3, "in": 0}, + {"f64": 0.1, "f32": 0.2, "in": 1}, + {"f64": 0.0, "f32": -0.1, "in": 0}, + {"f64": 1.0, "f32": 0.1, "in": 1}, + {"f64": -2.0, "f32": null, "in": null}, + {"f64": 3.0, "f32": 1.0, "in": 1} + ])"); } TEST(Expression, SerializationRoundTrips) { diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index c5904b45e06..1c7c471d3ca 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -53,8 +53,10 @@ const std::shared_ptr kBoringSchema = schema({ field("i32", int32()), field("i32_req", int32(), /*nullable=*/false), field("i64", int64()), + field("date64", date64()), field("f32", float32()), field("f32_req", float32(), /*nullable=*/false), + field("f64", float64()), field("bool", boolean()), field("str", utf8()), field("dict_str", dictionary(int32(), utf8())), From 1083e08ab2856def094ae28d9501acd11c6fb2df Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 14 Dec 2020 11:13:26 -0500 Subject: [PATCH 09/38] use dataset___expr__call to create Call expressions --- r/R/arrowExports.R | 48 +--------- r/R/expression.R | 34 ++++--- r/src/arrowExports.cpp | 197 ++++------------------------------------- r/src/compute.cpp | 11 ++- r/src/expression.cpp | 92 ++++--------------- 5 files changed, 63 insertions(+), 319 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 65c4dcdebbd..11fd99b2321 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -736,52 +736,12 @@ FixedSizeListType__list_size <- function(type){ .Call(`_arrow_FixedSizeListType__list_size` , type) } -dataset___expr__field_ref <- function(name){ - .Call(`_arrow_dataset___expr__field_ref` , name) -} - -dataset___expr__equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__equal` , lhs, rhs) -} - -dataset___expr__not_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__not_equal` , lhs, rhs) -} - -dataset___expr__greater <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__greater` , lhs, rhs) -} - -dataset___expr__greater_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__greater_equal` , lhs, rhs) -} - -dataset___expr__less <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__less` , lhs, rhs) +dataset___expr__call <- function(func_name, argument_list, options){ + .Call(`_arrow_dataset___expr__call` , func_name, argument_list, options) } -dataset___expr__less_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__less_equal` , lhs, rhs) -} - -dataset___expr__in <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__in` , lhs, rhs) -} - -dataset___expr__and <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__and` , lhs, rhs) -} - -dataset___expr__or <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__or` , lhs, rhs) -} - -dataset___expr__not <- function(lhs){ - .Call(`_arrow_dataset___expr__not` , lhs) -} - -dataset___expr__is_valid <- function(lhs){ - .Call(`_arrow_dataset___expr__is_valid` , lhs) +dataset___expr__field_ref <- function(name){ + .Call(`_arrow_dataset___expr__field_ref` , name) } dataset___expr__scalar <- function(x){ diff --git a/r/R/expression.R b/r/R/expression.R index d5623fb7786..8c7e4bca59f 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -182,36 +182,44 @@ Expression$field_ref <- function(name) { Expression$scalar <- function(x) { dataset___expr__scalar(Scalar$create(x)) } +Expression$call <- function(name, arguments, options = list()) { + 
dataset___expr__call(name, arguments, options) +} Expression$compare <- function(OP, e1, e2) { comp_func <- comparison_function_map[[OP]] if (is.null(comp_func)) { stop(OP, " is not a supported comparison function", call. = FALSE) } - comp_func(e1, e2) + Expression$call(comp_func, list(e1, e2)) } comparison_function_map <- list( - "==" = dataset___expr__equal, - "!=" = dataset___expr__not_equal, - ">" = dataset___expr__greater, - ">=" = dataset___expr__greater_equal, - "<" = dataset___expr__less, - "<=" = dataset___expr__less_equal + "==" = "equal", + "!=" = "not_equal", + ">" = "greater", + ">=" = "greater_equal", + "<" = "less", + "<=" = "less_equal" ) Expression$in_ <- function(x, set) { - dataset___expr__in(x, Array$create(set)) + Expression$call("is_in", list(x), + options = list(value_set = Array$create(set), + skip_nulls = TRUE)) } Expression$and <- function(e1, e2) { - dataset___expr__and(e1, e2) + Expression$call("and_kleene", list(e1, e2)) } Expression$or <- function(e1, e2) { - dataset___expr__or(e1, e2) + Expression$call("or_kleene", list(e1, e2)) } Expression$not <- function(e1) { - dataset___expr__not(e1) + Expression$call("invert", list(e1)) } Expression$is_valid <- function(e1) { - dataset___expr__is_valid(e1) + Expression$call("is_valid", list(e1)) +} +Expression$is_null <- function(e1) { + Expression$call("is_null", list(e1)) } #' @export @@ -253,4 +261,4 @@ make_expression <- function(operator, e1, e2) { } #' @export -is.na.Expression <- function(x) !Expression$is_valid(x) +is.na.Expression <- function(x) Expression$is_null(x) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 79de10091c9..670bee20665 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2846,190 +2846,33 @@ extern "C" SEXP _arrow_FixedSizeListType__list_size(SEXP type_sexp){ // expression.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__field_ref(std::string name); -extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ +std::shared_ptr dataset___expr__call(std::string func_name, cpp11::list argument_list, cpp11::list options); +extern "C" SEXP _arrow_dataset___expr__call(SEXP func_name_sexp, SEXP argument_list_sexp, SEXP options_sexp){ BEGIN_CPP11 - arrow::r::Input::type name(name_sexp); - return cpp11::as_sexp(dataset___expr__field_ref(name)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ - Rf_error("Cannot call dataset___expr__field_ref(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__equal(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__not_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__not_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__not_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__not_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__not_equal(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__greater(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__greater(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__greater(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__greater(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__greater(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__greater_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__greater_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__greater_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__greater_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__greater_equal(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__less(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__less(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__less(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__less(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__less(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__less_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__less_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__less_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__less_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__less_equal(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__in(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__in(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__in(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__in(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__in(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__and(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__and(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__and(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__and(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__and(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__or(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__or(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__or(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__or(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__or(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__not(const std::shared_ptr& lhs); -extern "C" SEXP _arrow_dataset___expr__not(SEXP lhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - return cpp11::as_sexp(dataset___expr__not(lhs)); + arrow::r::Input::type func_name(func_name_sexp); + arrow::r::Input::type argument_list(argument_list_sexp); + arrow::r::Input::type options(options_sexp); + return cpp11::as_sexp(dataset___expr__call(func_name, argument_list, options)); END_CPP11 } #else -extern "C" SEXP _arrow_dataset___expr__not(SEXP lhs_sexp){ - Rf_error("Cannot call dataset___expr__not(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +extern "C" SEXP _arrow_dataset___expr__call(SEXP func_name_sexp, SEXP argument_list_sexp, SEXP options_sexp){ + Rf_error("Cannot call dataset___expr__call(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); } #endif // expression.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__is_valid(const std::shared_ptr& lhs); -extern "C" SEXP _arrow_dataset___expr__is_valid(SEXP lhs_sexp){ +std::shared_ptr dataset___expr__field_ref(std::string name); +extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - return cpp11::as_sexp(dataset___expr__is_valid(lhs)); + arrow::r::Input::type name(name_sexp); + return cpp11::as_sexp(dataset___expr__field_ref(name)); END_CPP11 } #else -extern "C" SEXP _arrow_dataset___expr__is_valid(SEXP lhs_sexp){ - Rf_error("Cannot call dataset___expr__is_valid(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ + Rf_error("Cannot call dataset___expr__field_ref(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); } #endif @@ -6605,18 +6448,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_FixedSizeListType__value_field", (DL_FUNC) &_arrow_FixedSizeListType__value_field, 1}, { "_arrow_FixedSizeListType__value_type", (DL_FUNC) &_arrow_FixedSizeListType__value_type, 1}, { "_arrow_FixedSizeListType__list_size", (DL_FUNC) &_arrow_FixedSizeListType__list_size, 1}, + { "_arrow_dataset___expr__call", (DL_FUNC) &_arrow_dataset___expr__call, 3}, { "_arrow_dataset___expr__field_ref", (DL_FUNC) &_arrow_dataset___expr__field_ref, 1}, - { "_arrow_dataset___expr__equal", (DL_FUNC) &_arrow_dataset___expr__equal, 2}, - { "_arrow_dataset___expr__not_equal", (DL_FUNC) &_arrow_dataset___expr__not_equal, 2}, - { "_arrow_dataset___expr__greater", (DL_FUNC) &_arrow_dataset___expr__greater, 2}, - { "_arrow_dataset___expr__greater_equal", (DL_FUNC) &_arrow_dataset___expr__greater_equal, 2}, - { "_arrow_dataset___expr__less", (DL_FUNC) &_arrow_dataset___expr__less, 2}, - { "_arrow_dataset___expr__less_equal", (DL_FUNC) &_arrow_dataset___expr__less_equal, 2}, - { "_arrow_dataset___expr__in", (DL_FUNC) &_arrow_dataset___expr__in, 2}, - { "_arrow_dataset___expr__and", (DL_FUNC) &_arrow_dataset___expr__and, 2}, - { "_arrow_dataset___expr__or", (DL_FUNC) &_arrow_dataset___expr__or, 2}, - { "_arrow_dataset___expr__not", (DL_FUNC) &_arrow_dataset___expr__not, 1}, - { "_arrow_dataset___expr__is_valid", (DL_FUNC) &_arrow_dataset___expr__is_valid, 1}, { "_arrow_dataset___expr__scalar", (DL_FUNC) &_arrow_dataset___expr__scalar, 1}, { "_arrow_dataset___expr__ToString", (DL_FUNC) &_arrow_dataset___expr__ToString, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 12b23dad798..8a75a251d8e 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -126,7 +126,6 @@ arrow::Datum as_cpp(SEXP x) { // This assumes that R objects have already been converted to Arrow objects; // that seems right but should we do the wrapping here too/instead? 
cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x))); - return arrow::Datum(); } } // namespace cpp11 @@ -151,9 +150,7 @@ SEXP from_datum(arrow::Datum datum) { break; } - auto str = datum.ToString(); - cpp11::stop("from_datum: Not implemented for Datum %s", str.c_str()); - return R_NilValue; + cpp11::stop("from_datum: Not implemented for Datum %s", datum.ToString().c_str()); } std::shared_ptr make_compute_options( @@ -182,6 +179,12 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "is_in" || func_name == "index_in") { + using Options = arrow::compute::SetLookupOptions; + return std::make_shared(cpp11::as_cpp(options["value_set"]), + cpp11::as_cpp(options["skip_nulls"])); + } + return nullptr; } diff --git a/r/src/expression.cpp b/r/src/expression.cpp index b1116d27a2f..ddb1e72c309 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -23,94 +23,34 @@ #include namespace ds = ::arrow::dataset; -std::shared_ptr Share(ds::Expression expr) { - return std::make_shared(std::move(expr)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__field_ref(std::string name) { - return Share(ds::field_ref(std::move(name))); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::equal(*lhs, *rhs)); -} +std::shared_ptr make_compute_options( + std::string func_name, cpp11::list options); // [[arrow::export]] -std::shared_ptr dataset___expr__not_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::not_equal(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__greater( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::greater(*lhs, *rhs)); -} +std::shared_ptr dataset___expr__call(std::string func_name, + cpp11::list argument_list, + cpp11::list options) { + std::vector arguments; + for (SEXP argument : argument_list) { + auto argument_ptr = cpp11::as_cpp>(argument); + arguments.push_back(*argument_ptr); + } -// [[arrow::export]] -std::shared_ptr dataset___expr__greater_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::greater_equal(*lhs, *rhs)); -} + auto options_ptr = make_compute_options(func_name, options); -// [[arrow::export]] -std::shared_ptr dataset___expr__less( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::less(*lhs, *rhs)); + return std::make_shared( + ds::call(std::move(func_name), std::move(arguments), std::move(options_ptr))); } // [[arrow::export]] -std::shared_ptr dataset___expr__less_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::less_equal(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__in( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::call("is_in", {*lhs}, arrow::compute::SetLookupOptions{rhs})); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__and( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::and_(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__or( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return Share(ds::or_(*lhs, *rhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__not( - const std::shared_ptr& lhs) { - return Share(ds::not_(*lhs)); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__is_valid( - const std::shared_ptr& lhs) { - return Share(ds::call("is_valid", 
{*lhs})); +std::shared_ptr dataset___expr__field_ref(std::string name) { + return std::make_shared(ds::field_ref(std::move(name))); } // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { - return Share(ds::literal(x)); + return std::make_shared(ds::literal(std::move(x))); } // [[arrow::export]] From 557ca31597e41d5dec803fcd300145bc3723b01f Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 14 Dec 2020 11:16:04 -0500 Subject: [PATCH 10/38] lint fixes --- cpp/src/arrow/dataset/dataset.h | 9 +- cpp/src/arrow/dataset/expression_internal.h | 4 +- cpp/src/arrow/dataset/expression_test.cc | 1 - cpp/src/arrow/dataset/file_base.h | 3 +- cpp/src/arrow/dataset/file_parquet_test.cc | 3 +- cpp/src/arrow/dataset/filter.cc | 2 + cpp/src/arrow/dataset/filter.h | 1 - python/pyarrow/_dataset.pyx | 89 +++++++++----------- python/pyarrow/includes/libarrow_dataset.pxd | 89 +++++--------------- 9 files changed, 69 insertions(+), 132 deletions(-) diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 4ad6ecbe74a..c92381d78c5 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -94,8 +94,7 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment { public: InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, Expression = literal(true)); - explicit InMemoryFragment(RecordBatchVector record_batches, - Expression = literal(true)); + explicit InMemoryFragment(RecordBatchVector record_batches, Expression = literal(true)); Result Scan(std::shared_ptr options, std::shared_ptr context) override; @@ -180,8 +179,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { std::shared_ptr schema) const override; protected: - Result GetFragmentsImpl( - Expression predicate) override; + Result GetFragmentsImpl(Expression predicate) override; std::shared_ptr get_batches_; }; @@ -205,8 +203,7 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { std::shared_ptr schema) const override; protected: - Result GetFragmentsImpl( - Expression predicate) override; + Result GetFragmentsImpl(Expression predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 90f28cd23ef..306f219859a 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -179,7 +179,7 @@ struct Comparison { return GREATER_EQUAL; case GREATER_EQUAL: return LESS_EQUAL; - }; + } DCHECK(false); return NA; } @@ -201,7 +201,7 @@ struct Comparison { return "less_equal"; case GREATER_EQUAL: return "greater_equal"; - }; + } DCHECK(false); return "na"; } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 91e9dcdc4f1..e5301677ed7 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -465,7 +465,6 @@ struct { EXPECT_TRUE(Identical(folded, expr)); } } - } ExpectFoldsTo; TEST(Expression, FoldConstants) { diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 7b67e6f9bf5..bb2aa86ba9b 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -233,8 +233,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { std::string ToString() const; protected: - Result GetFragmentsImpl( - Expression predicate) override; + Result 
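With the per-operator wrappers above collapsed into dataset___expr__call, every
R-side operator reduces to a single compute-registry lookup. A minimal C++
sketch of the same construction path (not part of the patch; it assumes only
field_ref, literal, and call from arrow/dataset/expression.h, and uses compute
function names such as "equal" and "and_kleene"):

    #include "arrow/dataset/expression.h"

    namespace ds = arrow::dataset;

    // (a == 3) and (b > 0), assembled without any operator-specific wrapper:
    ds::Expression ExampleFilter() {
      return ds::call("and_kleene",
                      {ds::call("equal", {ds::field_ref("a"), ds::literal(3)}),
                       ds::call("greater", {ds::field_ref("b"), ds::literal(0)})});
    }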
diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h
index 7b67e6f9bf5..bb2aa86ba9b 100644
--- a/cpp/src/arrow/dataset/file_base.h
+++ b/cpp/src/arrow/dataset/file_base.h
@@ -233,8 +233,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset {
   std::string ToString() const;

  protected:
-  Result<FragmentIterator> GetFragmentsImpl(
-      Expression predicate) override;
+  Result<FragmentIterator> GetFragmentsImpl(Expression predicate) override;

   FileSystemDataset(std::shared_ptr<Schema> schema,
                     Expression root_partition,
                     std::shared_ptr<FileFormat> format,
diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc
index e9182ddf696..67d1fb17120 100644
--- a/cpp/src/arrow/dataset/file_parquet_test.cc
+++ b/cpp/src/arrow/dataset/file_parquet_test.cc
@@ -190,8 +190,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin {
   }

   void CountRowGroupsInFragment(const std::shared_ptr<Fragment>& fragment,
-                                std::vector<int> expected_row_groups,
-                                Expression filter) {
+                                std::vector<int> expected_row_groups, Expression filter) {
     schema_ = opts_->schema();
     ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_));
     auto parquet_fragment = checked_pointer_cast<ParquetFileFragment>(fragment);
diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
index 9852f8c0808..2357896dd7a 100644
--- a/cpp/src/arrow/dataset/filter.cc
+++ b/cpp/src/arrow/dataset/filter.cc
@@ -15,4 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.

+#include "arrow/dataset/filter.h"
+
 // FIXME remove
diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h
index 6dd39e25e1e..9852f8c0808 100644
--- a/cpp/src/arrow/dataset/filter.h
+++ b/cpp/src/arrow/dataset/filter.h
@@ -16,4 +16,3 @@
 // under the License.

 // FIXME remove
-
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index bbe0485fa69..e235d50c74a 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -86,24 +86,22 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=None):

 cdef class Expression(_Weakrefable):

     cdef:
-        shared_ptr[CExpression] wrapped
-        CExpression* expr
+        CExpression expr

     def __init__(self):
         _forbid_instantiation(self.__class__)

-    cdef void init(self, const shared_ptr[CExpression]& sp):
-        self.wrapped = sp
-        self.expr = sp.get()
+    cdef void init(self, const CExpression& sp):
+        self.expr = sp

     @staticmethod
-    cdef wrap(const shared_ptr[CExpression]& sp):
+    cdef wrap(const CExpression& sp):
         cdef Expression self = Expression.__new__(Expression)
         self.init(sp)
         return self

-    cdef inline shared_ptr[CExpression] unwrap(self):
-        return self.wrapped
+    cdef inline CExpression unwrap(self):
+        return self.expr

     def equals(self, Expression other):
         return self.expr.Equals(other.unwrap())
@@ -119,7 +117,7 @@ cdef class Expression(_Weakrefable):
     @staticmethod
     def _deserialize(Buffer buffer not None):
-        c_buffer = pyarrow_unwrap_buffer(buffer)
-        c_expr = GetResultValue(CExpression.Deserialize(deref(c_buffer)))
-        return Expression.wrap(move(c_expr))
+        # Deserialize is not exposed for the new CExpression yet; raise
+        # explicitly rather than referencing an undefined c_expr.
+        raise NotImplementedError(
+            "Expression deserialization is not yet supported")

     def __reduce__(self):
@@ -157,16 +155,16 @@ cdef class Expression(_Weakrefable):
         return Expression.wrap(CMakeNotExpression(self.unwrap()))

     @staticmethod
-    cdef shared_ptr[CExpression] _expr_or_scalar(object expr) except *:
+    cdef CExpression _expr_or_scalar(object expr) except *:
         if isinstance(expr, Expression):
             return (<Expression> expr).unwrap()
         return (<Expression> Expression._scalar(expr)).unwrap()

     def __richcmp__(self, other, int op):
         cdef:
-            shared_ptr[CExpression] c_expr
-            shared_ptr[CExpression] c_left
-            shared_ptr[CExpression] c_right
+            CExpression c_expr
+            CExpression c_left
+            CExpression c_right

         c_left = self.unwrap()
         c_right = Expression._expr_or_scalar(other)
@@ -212,10 +210,17 @@ cdef class Expression(_Weakrefable):

     def isin(self, values):
         """Checks whether the expression is contained in values"""
+        cdef:
+            vector[CExpression] arguments
arguments.push_back(self.expr) + if not isinstance(values, pa.Array): values = pa.array(values) c_values = pyarrow_unwrap_array(values) - return Expression.wrap(self.expr.In(c_values).Copy()) + return Expression.wrap(CMakeCallExpression( + tobytes("is_in"), + move(arguments), + move(SetLookupOptions(values).set_lookup_options))) @staticmethod def _field(str name not None): @@ -231,11 +236,7 @@ cdef class Expression(_Weakrefable): else: scalar = pa.scalar(value) - return Expression.wrap( - shared_ptr[CExpression]( - new CScalarExpression(move(scalar.unwrap())) - ) - ) + return Expression.wrap(CMakeScalarExpression(scalar.unwrap())) _deserialize = Expression._deserialize @@ -288,12 +289,7 @@ cdef class Dataset(_Weakrefable): An Expression which evaluates to true for all data viewed by this Dataset. """ - cdef shared_ptr[CExpression] expr - expr = self.dataset.partition_expression() - if expr.get() == nullptr: - return None - else: - return Expression.wrap(expr) + return Expression.wrap(self.dataset.partition_expression()) def replace_schema(self, Schema schema not None): """ @@ -322,13 +318,13 @@ cdef class Dataset(_Weakrefable): fragments : iterator of Fragment """ cdef: - shared_ptr[CExpression] c_filter + CExpression c_filter CFragmentIterator c_iterator if filter is None: c_fragments = self.dataset.GetFragments() else: - c_filter = _insert_implicit_casts(filter, self.schema) + c_filter = _bind(filter, self.schema) c_fragments = self.dataset.GetFragments(c_filter) for maybe_fragment in c_fragments: @@ -593,19 +589,14 @@ cdef class FileSystemDataset(Dataset): return FileFormat.wrap(self.filesystem_dataset.format()) -cdef shared_ptr[CExpression] _insert_implicit_casts(Expression filter, - Schema schema) except *: +cdef CExpression _bind(Expression filter, Schema schema) except *: assert schema is not None if filter is None: return _true.unwrap() - return GetResultValue( - CInsertImplicitCasts( - deref(filter.unwrap().get()), - deref(pyarrow_unwrap_schema(schema).get()) - ) - ) + return GetResultValue(filter.unwrap().Bind( + deref(pyarrow_unwrap_schema(schema).get()))) cdef class FileWriteOptions(_Weakrefable): @@ -1035,11 +1026,11 @@ cdef class ParquetFileFragment(FileFragment): """ cdef: vector[shared_ptr[CFragment]] c_fragments - shared_ptr[CExpression] c_filter + CExpression c_filter shared_ptr[CFragment] c_fragment schema = schema or self.physical_schema - c_filter = _insert_implicit_casts(filter, schema) + c_filter = _bind(filter, schema) with nogil: c_fragments = move(GetResultValue( self.parquet_file_fragment.SplitByRowGroup(move(c_filter)))) @@ -1072,7 +1063,7 @@ cdef class ParquetFileFragment(FileFragment): ParquetFileFragment """ cdef: - shared_ptr[CExpression] c_filter + CExpression c_filter vector[int] c_row_group_ids shared_ptr[CFragment] c_fragment @@ -1083,7 +1074,7 @@ cdef class ParquetFileFragment(FileFragment): if filter is not None: schema = schema or self.physical_schema - c_filter = _insert_implicit_casts(filter, schema) + c_filter = _bind(filter, schema) with nogil: c_fragment = move(GetResultValue( self.parquet_file_fragment.SubsetWithFilter( @@ -1408,7 +1399,7 @@ cdef class Partitioning(_Weakrefable): return self.wrapped def parse(self, path): - cdef CResult[shared_ptr[CExpression]] result + cdef CResult[CExpression] result result = self.partitioning.Parse(tobytes(path)) return Expression.wrap(GetResultValue(result)) @@ -1643,11 +1634,7 @@ cdef class DatasetFactory(_Weakrefable): @property def root_partition(self): - cdef shared_ptr[CExpression] expr = 
self.factory.root_partition() - if expr.get() == nullptr: - return None - else: - return Expression.wrap(expr) + return Expression.wrap(self.factory.root_partition()) @root_partition.setter def root_partition(self, Expression expr): @@ -2121,14 +2108,14 @@ cdef void _populate_builder(const shared_ptr[CScannerBuilder]& ptr, int batch_size=_DEFAULT_BATCH_SIZE) except *: cdef: CScannerBuilder *builder - builder = ptr.get() - if columns is not None: - check_status(builder.Project([tobytes(c) for c in columns])) - check_status(builder.Filter(_insert_implicit_casts( + check_status(builder.Filter(_bind( filter, pyarrow_wrap_schema(builder.schema())))) + if columns is not None: + check_status(builder.Project([tobytes(c) for c in columns])) + check_status(builder.BatchSize(batch_size)) @@ -2296,12 +2283,12 @@ def _get_partition_keys(Expression partition_expression): is converted to {'part': 'a', 'year': 2016} """ cdef: - shared_ptr[CExpression] expr = partition_expression.unwrap() + CExpression expr = partition_expression.unwrap() pair[c_string, shared_ptr[CScalar]] name_val return { frombytes(name_val.first): pyarrow_wrap_scalar(name_val.second).as_py() - for name_val in GetResultValue(CGetPartitionKeys(deref(expr.get()))) + for name_val in GetResultValue(CGetPartitionKeys(expr)) } diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index a81042920d5..53e48da1703 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -36,63 +36,19 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CExpression "arrow::dataset::Expression": c_bool Equals(const CExpression& other) const - c_bool Equals(const shared_ptr[CExpression]& other) const - CResult[shared_ptr[CDataType]] Validate(const CSchema& schema) const - shared_ptr[CExpression] Assume(const CExpression& given) const - shared_ptr[CExpression] Assume( - const shared_ptr[CExpression]& given) const c_string ToString() const - shared_ptr[CExpression] Copy() const + CResult[CExpression] Bind(const CSchema&) - const CExpression& In(shared_ptr[CArray]) const - const CExpression& IsValid() const - const CExpression& CastTo(shared_ptr[CDataType], CCastOptions) const - const CExpression& CastLike(shared_ptr[CExpression], - CCastOptions) const + cdef CExpression CMakeScalarExpression \ + "arrow::dataset::literal"(shared_ptr[CScalar] value) - @staticmethod - CResult[shared_ptr[CExpression]] Deserialize(const CBuffer& buffer) - CResult[shared_ptr[CBuffer]] Serialize() const - - ctypedef vector[shared_ptr[CExpression]] CExpressionVector \ - "arrow::dataset::ExpressionVector" - - cdef cppclass CScalarExpression \ - "arrow::dataset::ScalarExpression"(CExpression): - CScalarExpression(const shared_ptr[CScalar]& value) - - cdef shared_ptr[CExpression] CMakeFieldExpression \ + cdef CExpression CMakeFieldExpression \ "arrow::dataset::field_ref"(c_string name) - cdef shared_ptr[CExpression] CMakeNotExpression \ - "arrow::dataset::not_"(shared_ptr[CExpression] operand) - cdef shared_ptr[CExpression] CMakeAndExpression \ - "arrow::dataset::and_"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeOrExpression \ - "arrow::dataset::or_"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeEqualExpression \ - "arrow::dataset::equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeNotEqualExpression \ - 
"arrow::dataset::not_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeGreaterExpression \ - "arrow::dataset::greater"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeGreaterEqualExpression \ - "arrow::dataset::greater_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeLessExpression \ - "arrow::dataset::less"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeLessEqualExpression \ - "arrow::dataset::less_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - - cdef CResult[shared_ptr[CExpression]] CInsertImplicitCasts \ - "arrow::dataset::InsertImplicitCasts"( - const CExpression &, const CSchema&) + + cdef CExpression CMakeCallExpression \ + "arrow::dataset::call"(c_string function, + vector[CExpression] arguments, + shared_ptr[CFunctionOptions] options) cdef cppclass CRecordBatchProjector "arrow::dataset::RecordBatchProjector": pass @@ -119,7 +75,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: shared_ptr[CScanOptions] options, shared_ptr[CScanContext] context) c_bool splittable() const c_string type_name() const - const shared_ptr[CExpression]& partition_expression() const + const CExpression& partition_expression() const ctypedef vector[shared_ptr[CFragment]] CFragmentVector \ "arrow::dataset::FragmentVector" @@ -130,7 +86,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CInMemoryFragment "arrow::dataset::InMemoryFragment"( CFragment): CInMemoryFragment(vector[shared_ptr[CRecordBatch]] record_batches, - shared_ptr[CExpression] partition_expression) + CExpression partition_expression) cdef cppclass CScanner "arrow::dataset::Scanner": CScanner(shared_ptr[CDataset], shared_ptr[CScanOptions], @@ -148,8 +104,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment], shared_ptr[CScanContext] scan_context) CStatus Project(const vector[c_string]& columns) - CStatus Filter(const CExpression& filter) - CStatus Filter(shared_ptr[CExpression] filter) + CStatus Filter(CExpression filter) CStatus UseThreads(c_bool use_threads) CStatus BatchSize(int64_t batch_size) CResult[shared_ptr[CScanner]] Finish() @@ -161,8 +116,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CDataset "arrow::dataset::Dataset": const shared_ptr[CSchema] & schema() CFragmentIterator GetFragments() - CFragmentIterator GetFragments(shared_ptr[CExpression] predicate) - const shared_ptr[CExpression] & partition_expression() + CFragmentIterator GetFragments(CExpression predicate) + const CExpression & partition_expression() c_string type_name() CResult[shared_ptr[CDataset]] ReplaceSchema(shared_ptr[CSchema]) @@ -193,8 +148,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[shared_ptr[CDataset]] FinishWithSchema "Finish"( const shared_ptr[CSchema]& schema) CResult[shared_ptr[CDataset]] Finish() - const shared_ptr[CExpression]& root_partition() - CStatus SetRootPartition(shared_ptr[CExpression] partition) + const CExpression& root_partition() + CStatus SetRootPartition(CExpression partition) cdef cppclass CUnionDatasetFactory "arrow::dataset::UnionDatasetFactory": @staticmethod @@ -221,7 +176,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[shared_ptr[CSchema]] Inspect(const CFileSource&) const 
CResult[shared_ptr[CFileFragment]] MakeFragment( CFileSource source, - shared_ptr[CExpression] partition_expression, + CExpression partition_expression, shared_ptr[CSchema] physical_schema) shared_ptr[CFileWriteOptions] DefaultWriteOptions() @@ -240,9 +195,9 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: const vector[int]& row_groups() const shared_ptr[CFileMetaData] metadata() const CResult[vector[shared_ptr[CFragment]]] SplitByRowGroup( - shared_ptr[CExpression] predicate) + CExpression predicate) CResult[shared_ptr[CFragment]] SubsetWithFilter "Subset"( - shared_ptr[CExpression] predicate) + CExpression predicate) CResult[shared_ptr[CFragment]] SubsetWithIds "Subset"( vector[int] row_group_ids) CStatus EnsureCompleteMetadata() @@ -260,7 +215,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: @staticmethod CResult[shared_ptr[CDataset]] Make( shared_ptr[CSchema] schema, - shared_ptr[CExpression] source_partition, + CExpression source_partition, shared_ptr[CFileFormat] format, shared_ptr[CFileSystem] filesystem, vector[shared_ptr[CFileFragment]] fragments) @@ -287,7 +242,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CParquetFileFormatReaderOptions reader_options CResult[shared_ptr[CFileFragment]] MakeFragment( CFileSource source, - shared_ptr[CExpression] partition_expression, + CExpression partition_expression, shared_ptr[CSchema] physical_schema, vector[int] row_groups) @@ -305,7 +260,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CPartitioning "arrow::dataset::Partitioning": c_string type_name() const - CResult[shared_ptr[CExpression]] Parse(const c_string & path) const + CResult[CExpression] Parse(const c_string & path) const const shared_ptr[CSchema] & schema() cdef cppclass CPartitioningFactoryOptions \ From 9d5a57dcd2a7af5693416698e232b001a724a56f Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 15 Dec 2020 13:25:24 -0800 Subject: [PATCH 11/38] Refactor dataset expression code to follow array expression pattern --- r/R/dplyr.R | 2 +- r/R/expression.R | 99 ++++++++++++++---------------------------------- 2 files changed, 29 insertions(+), 72 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index f8de1a525c9..875823032ca 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -248,7 +248,7 @@ filter_mask <- function(.data) { if (query_on_dataset(.data)) { comp_func <- function(operator) { force(operator) - function(e1, e2) make_expression(operator, e1, e2) + function(e1, e2) build_dataset_expression(operator, e1, e2) } var_binder <- function(x) Expression$field_ref(x) } else { diff --git a/r/R/expression.R b/r/R/expression.R index 8c7e4bca59f..f601730f905 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -173,6 +173,31 @@ Expression <- R6Class("Expression", inherit = ArrowObject, ToString = function() dataset___expr__ToString(self) ) ) +Expression$create <- function(name, arguments, options = list()) { + dataset___expr__call(name, arguments, options) +} + +build_dataset_expression <- function(.Generic, e1, e2, ...) 
{ + if (.Generic %in% names(.unary_function_map)) { + expr <- Expression$create(.unary_function_map[[.Generic]], list(e1)) + } else if (.Generic == "%in%") { + # Special-case %in%, which is different from the Array function name + expr <- Expression$create("is_in", list(e1), + options = list( + value_set = Array$create(e2), + skip_nulls = TRUE) + ) + } else { + if (!inherits(e1, "Expression")) { + e1 <- Expression$scalar(e1) + } + if (!inherits(e2, "Expression")) { + e2 <- Expression$scalar(e2) + } + expr <- Expression$create(.binary_function_map[[.Generic]], list(e1, e2), ...) + } + expr +} Expression$field_ref <- function(name) { assert_is(name, "character") @@ -182,83 +207,15 @@ Expression$field_ref <- function(name) { Expression$scalar <- function(x) { dataset___expr__scalar(Scalar$create(x)) } -Expression$call <- function(name, arguments, options = list()) { - dataset___expr__call(name, arguments, options) -} -Expression$compare <- function(OP, e1, e2) { - comp_func <- comparison_function_map[[OP]] - if (is.null(comp_func)) { - stop(OP, " is not a supported comparison function", call. = FALSE) - } - Expression$call(comp_func, list(e1, e2)) -} - -comparison_function_map <- list( - "==" = "equal", - "!=" = "not_equal", - ">" = "greater", - ">=" = "greater_equal", - "<" = "less", - "<=" = "less_equal" -) -Expression$in_ <- function(x, set) { - Expression$call("is_in", list(x), - options = list(value_set = Array$create(set), - skip_nulls = TRUE)) -} -Expression$and <- function(e1, e2) { - Expression$call("and_kleene", list(e1, e2)) -} -Expression$or <- function(e1, e2) { - Expression$call("or_kleene", list(e1, e2)) -} -Expression$not <- function(e1) { - Expression$call("invert", list(e1)) -} -Expression$is_valid <- function(e1) { - Expression$call("is_valid", list(e1)) -} -Expression$is_null <- function(e1) { - Expression$call("is_null", list(e1)) -} #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { - return(Expression$not(e1)) - } - make_expression(.Generic, e1, e2) -} - -make_expression <- function(operator, e1, e2) { - if (operator == "%in%") { - # In doesn't take Scalar, it takes Array - return(Expression$in_(e1, e2)) - } - - # Handle unary functions before touching e2 - if (operator == "is.na") { - return(is.na(e1)) - } - if (operator == "!") { - return(Expression$not(e1)) - } - - # Check for non-expressions and convert to Expressions - if (!inherits(e1, "Expression")) { - e1 <- Expression$scalar(e1) - } - if (!inherits(e2, "Expression")) { - e2 <- Expression$scalar(e2) - } - if (operator == "&") { - Expression$and(e1, e2) - } else if (operator == "|") { - Expression$or(e1, e2) + build_dataset_expression(.Generic, e1) } else { - Expression$compare(operator, e1, e2) + build_dataset_expression(.Generic, e1, e2) } } #' @export -is.na.Expression <- function(x) Expression$is_null(x) +is.na.Expression <- function(x) Expression$create("is_null", list(x)) From a04c7d96a3ccc4a1f5a8a52a728bccebb5c11cf4 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 15 Dec 2020 13:54:33 -0800 Subject: [PATCH 12/38] A little more --- r/R/expression.R | 49 +++++++++++++++++++-------------------------- r/man/Expression.Rd | 14 ++----------- 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index f601730f905..f9e09c2fadd 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -153,18 +153,8 @@ print.array_expression <- function(x, ...) 
{ #' `Expression$field_ref(name)` is used to construct an `Expression` which #' evaluates to the named column in the `Dataset` against which it is evaluated. #' -#' `Expression$compare(OP, e1, e2)` takes two `Expression` operands, constructing -#' an `Expression` which will evaluate these operands then compare them with the -#' relation specified by OP (e.g. "==", "!=", ">", etc.) For example, to filter -#' down to rows where the column named "alpha" is less than 5: -#' `Expression$compare("<", Expression$field_ref("alpha"), Expression$scalar(5))` -#' -#' `Expression$and(e1, e2)`, `Expression$or(e1, e2)`, and `Expression$not(e1)` -#' construct an `Expression` combining their arguments with Boolean operators. -#' -#' `Expression$is_valid(x)` is essentially (an inversion of) `is.na()` for `Expression`s. -#' -#' `Expression$in_(x, set)` evaluates x and returns whether or not it is a member of the set. +#' `Expression$create(function_name, ..., options)` builds a function-call +#' `Expression` containing one or more `Expression`s. #' @name Expression #' @rdname Expression #' @export @@ -173,20 +163,32 @@ Expression <- R6Class("Expression", inherit = ArrowObject, ToString = function() dataset___expr__ToString(self) ) ) -Expression$create <- function(name, arguments, options = list()) { - dataset___expr__call(name, arguments, options) +Expression$create <- function(function_name, + ..., + args = list(...), + options = empty_named_list()) { + assert_that(is.string(function_name)) + dataset___expr__call(function_name, args, options) +} +Expression$field_ref <- function(name) { + assert_that(is.string(name)) + dataset___expr__field_ref(name) +} +Expression$scalar <- function(x) { + dataset___expr__scalar(Scalar$create(x)) } build_dataset_expression <- function(.Generic, e1, e2, ...) { if (.Generic %in% names(.unary_function_map)) { - expr <- Expression$create(.unary_function_map[[.Generic]], list(e1)) + expr <- Expression$create(.unary_function_map[[.Generic]], e1) } else if (.Generic == "%in%") { # Special-case %in%, which is different from the Array function name - expr <- Expression$create("is_in", list(e1), + expr <- Expression$create("is_in", e1, options = list( value_set = Array$create(e2), - skip_nulls = TRUE) + skip_nulls = TRUE ) + ) } else { if (!inherits(e1, "Expression")) { e1 <- Expression$scalar(e1) @@ -194,20 +196,11 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) { if (!inherits(e2, "Expression")) { e2 <- Expression$scalar(e2) } - expr <- Expression$create(.binary_function_map[[.Generic]], list(e1, e2), ...) + expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...) } expr } -Expression$field_ref <- function(name) { - assert_is(name, "character") - assert_that(length(name) == 1) - dataset___expr__field_ref(name) -} -Expression$scalar <- function(x) { - dataset___expr__scalar(Scalar$create(x)) -} - #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { @@ -218,4 +211,4 @@ Ops.Expression <- function(e1, e2) { } #' @export -is.na.Expression <- function(x) Expression$create("is_null", list(x)) +is.na.Expression <- function(x) Expression$create("is_null", x) diff --git a/r/man/Expression.Rd b/r/man/Expression.Rd index 14570eba892..58a6a44c0c0 100644 --- a/r/man/Expression.Rd +++ b/r/man/Expression.Rd @@ -13,16 +13,6 @@ the provided scalar (length-1) R value. \code{Expression$field_ref(name)} is used to construct an \code{Expression} which evaluates to the named column in the \code{Dataset} against which it is evaluated. 
-\code{Expression$compare(OP, e1, e2)} takes two \code{Expression} operands, constructing
-an \code{Expression} which will evaluate these operands then compare them with the
-relation specified by OP (e.g. "==", "!=", ">", etc.) For example, to filter
-down to rows where the column named "alpha" is less than 5:
-\code{Expression$compare("<", Expression$field_ref("alpha"), Expression$scalar(5))}
-
-\code{Expression$and(e1, e2)}, \code{Expression$or(e1, e2)}, and \code{Expression$not(e1)}
-construct an \code{Expression} combining their arguments with Boolean operators.
-
-\code{Expression$is_valid(x)} is essentially (an inversion of) \code{is.na()} for \code{Expression}s.
-
-\code{Expression$in_(x, set)} evaluates x and returns whether or not it is a member of the set.
+\code{Expression$create(function_name, ..., options)} builds a function-call
+\code{Expression} containing one or more \code{Expression}s.
 }

From 2470187e2e19b645ccfcb8134aaf3ef718b15c8e Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Tue, 15 Dec 2020 16:57:45 -0500
Subject: [PATCH 13/38] revamp Expression::ToString, make cast-from-string
 error more informative

---
 .../compute/kernels/scalar_cast_numeric.cc    |  4 +-
 .../compute/kernels/scalar_cast_temporal.cc   |  3 +-
 .../arrow/compute/kernels/scalar_string.cc    |  4 +-
 cpp/src/arrow/dataset/expression.cc           | 93 ++++++++++++-------
 cpp/src/arrow/dataset/expression_internal.h   | 22 +++++
 cpp/src/arrow/dataset/expression_test.cc      | 37 ++++++--
 r/tests/testthat/test-dataset.R               |  4 +-
 r/tests/testthat/test-expression.R            |  2 +-
 8 files changed, 120 insertions(+), 49 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
index 62665d4ea44..6e550fb12c0 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
@@ -282,7 +282,9 @@ struct ParseString {
   OutValue Call(KernelContext* ctx, Arg0Value val) const {
     OutValue result = OutValue(0);
     if (ARROW_PREDICT_FALSE(!ParseValue<OutType>(val.data(), val.size(), &result))) {
-      ctx->SetStatus(Status::Invalid("Failed to parse string: ", val));
+      ctx->SetStatus(Status::Invalid("Failed to parse string: '", val,
+                                     "' as a scalar of type ",
+                                     TypeTraits<OutType>::type_singleton()->ToString()));
     }
     return result;
   }
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
index 7cff41f267e..f1702b7fd51 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
@@ -298,7 +298,8 @@ struct ParseTimestamp {
   OutValue Call(KernelContext* ctx, Arg0Value val) const {
     OutValue result = 0;
     if (ARROW_PREDICT_FALSE(!ParseValue(type, val.data(), val.size(), &result))) {
-      ctx->SetStatus(Status::Invalid("Failed to parse string: ", val));
+      ctx->SetStatus(Status::Invalid("Failed to parse string: '", val,
+                                     "' as a scalar of type ", type.ToString()));
     }
     return result;
   }
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index eff60b80481..2f4a0d45ed2 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -1207,7 +1207,9 @@ struct ParseStrptime {
   int64_t Call(KernelContext* ctx, util::string_view val) const {
     int64_t result = 0;
     if (!(*parser)(val.data(), val.size(), unit, &result)) {
-      ctx->SetStatus(Status::Invalid("Failed to parse string ", val));
+      ctx->SetStatus(Status::Invalid("Failed to parse string: '", val,
+                                     "' as a scalar of type ",
+                                     TimestampType(unit).ToString()));
     }
     return result;
   }
diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc
index cc142cbb2d4..f4600196e78 100644
--- a/cpp/src/arrow/dataset/expression.cc
+++ b/cpp/src/arrow/dataset/expression.cc
@@ -52,6 +52,21 @@ const FieldRef* Expression::field_ref() const {
 std::string Expression::ToString() const {
   if (auto lit = literal()) {
     if (lit->is_scalar()) {
+      switch (lit->type()->id()) {
+        case Type::STRING:
+        case Type::LARGE_STRING:
+          return '"' +
+                 Escape(util::string_view(*lit->scalar_as<BaseBinaryScalar>().value)) +
+                 '"';
+
+        case Type::BINARY:
+        case Type::FIXED_SIZE_BINARY:
+        case Type::LARGE_BINARY:
+          return '"' + lit->scalar_as<BaseBinaryScalar>().value->ToHexString() + '"';
+
+        default:
+          break;
+      }
       return lit->scalar()->ToString();
     }
     return lit->ToString();
@@ -59,7 +74,7 @@ std::string Expression::ToString() const {

   if (auto ref = field_ref()) {
     if (auto name = ref->name()) {
-      return "FieldRef(" + *name + ")";
+      return *name;
     }
     if (auto path = ref->field_path()) {
       return path->ToString();
@@ -68,54 +83,62 @@ std::string Expression::ToString() const {
   }

   auto call = CallNotNull(*this);
+  if (auto cmp = Comparison::Get(call->function)) {
+    return "(" + call->arguments[0].ToString() + " " + Comparison::GetOp(*cmp) + " " +
+           call->arguments[1].ToString() + ")";
+  }
+
+  if (auto options = GetStructOptions(*call)) {
+    std::string out = "{";
+    auto argument = call->arguments.begin();
+    for (const auto& field_name : options->field_names) {
+      out += field_name + "=" + argument++->ToString() + ", ";
+    }
+    out.resize(out.size() - 1);
+    out.back() = '}';
+    return out;
+  }

-  // FIXME represent FunctionOptions
   std::string out = call->function + "(";
   for (const auto& arg : call->arguments) {
-    out += arg.ToString() + ",";
+    out += arg.ToString() + ", ";
   }

-  if (!call->options) {
+  if (call->options == nullptr) {
+    out.resize(out.size() - 1);
     out.back() = ')';
     return out;
   }

-  if (call->options) {
-    out += " {";
-
-    if (auto options = GetSetLookupOptions(*call)) {
-      DCHECK_EQ(options->value_set.kind(), Datum::ARRAY);
-      out += "value_set:" + options->value_set.make_array()->ToString();
-      if (options->skip_nulls) {
-        out += ",skip_nulls";
-      }
-    }
-
-    if (auto options = GetCastOptions(*call)) {
-      out += "to_type:" + options->to_type->ToString();
-      if (options->allow_int_overflow) out += ",allow_int_overflow";
-      if (options->allow_time_truncate) out += ",allow_time_truncate";
-      if (options->allow_time_overflow) out += ",allow_time_overflow";
-      if (options->allow_decimal_truncate) out += ",allow_decimal_truncate";
-      if (options->allow_float_truncate) out += ",allow_float_truncate";
-      if (options->allow_invalid_utf8) out += ",allow_invalid_utf8";
-    }
-
-    if (auto options = GetStructOptions(*call)) {
-      for (const auto& field_name : options->field_names) {
-        out += field_name + ",";
-      }
-      out.resize(out.size() - 1);
+  if (auto options = GetSetLookupOptions(*call)) {
+    DCHECK_EQ(options->value_set.kind(), Datum::ARRAY);
+    out += "value_set=" + options->value_set.make_array()->ToString();
+    if (options->skip_nulls) {
+      out += ", skip_nulls";
     }
+    return out + ")";
+  }

-    if (auto options = GetStrptimeOptions(*call)) {
-      out += "format:" + options->format;
-      out += ",unit:" + internal::ToString(options->unit);
+  if (auto options = GetCastOptions(*call)) {
+    if (options->to_type == nullptr) {
+      return out + "to_type=<INVALID NOT PROVIDED>)";
     }
+    out += "to_type=" + options->to_type->ToString();
+    if (options->allow_int_overflow) out += ", allow_int_overflow";
+    if (options->allow_time_truncate) out += ", allow_time_truncate";
+    if (options->allow_time_overflow) out += ", allow_time_overflow";
+    if (options->allow_decimal_truncate) out += ", allow_decimal_truncate";
+    if (options->allow_float_truncate) out += ", allow_float_truncate";
+    if (options->allow_invalid_utf8) out += ", allow_invalid_utf8";
+    return out + ")";
+  }

-    out += "})";
+  if (auto options = GetStrptimeOptions(*call)) {
+    return out + "format=" + options->format +
+           ", unit=" + internal::ToString(options->unit) + ")";
   }
-  return out;
+
+  return out + "{NON-REPRESENTABLE OPTIONS})";
 }

 void PrintTo(const Expression& expr, std::ostream* os) {
diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h
index 306f219859a..c91bbb4b62f 100644
--- a/cpp/src/arrow/dataset/expression_internal.h
+++ b/cpp/src/arrow/dataset/expression_internal.h
@@ -205,6 +205,28 @@ struct Comparison {
     DCHECK(false);
     return "na";
   }
+
+  static std::string GetOp(type op) {
+    switch (op) {
+      case NA:
+        DCHECK(false) << "unreachable";
+        break;
+      case EQUAL:
+        return "==";
+      case LESS:
+        return "<";
+      case GREATER:
+        return ">";
+      case NOT_EQUAL:
+        return "!=";
+      case LESS_EQUAL:
+        return "<=";
+      case GREATER_EQUAL:
+        return ">=";
+    }
+    DCHECK(false);
+    return "";
+  }
 };

 inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) {
"(a > 3)"); + EXPECT_EQ(not_equal(field_ref("a"), literal("a")).ToString(), "(a != \"a\")"); + EXPECT_EQ(less_equal(field_ref("a"), literal("b")).ToString(), "(a <= \"b\")"); + EXPECT_EQ(greater_equal(field_ref("a"), literal("c")).ToString(), "(a >= \"c\")"); EXPECT_EQ(project( { @@ -80,8 +102,7 @@ TEST(Expression, ToString) { "b", }) .ToString(), - "struct(FieldRef(a),FieldRef(a),3," + in_12.ToString() + - ", {a,renamed_a,three,b})"); + "{a=a, renamed_a=a, three=3, b=" + in_12.ToString() + "}"); } TEST(Expression, Equality) { diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index c976a4807ad..4c8db531411 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -528,7 +528,7 @@ test_that("filter scalar validation doesn't crash (ARROW-7772)", { ds %>% filter(int == "fff", part == 1) %>% collect(), - "error parsing 'fff' as scalar of type int32" + "Failed to parse string: 'fff' as a scalar of type int32" ) }) @@ -683,7 +683,7 @@ test_that("Dataset and query print methods", { "lgl: bool", "integer: int32", "", - "* Filter: (int == 6:double)", + "* Filter: (int == 6)", "* Grouped by lgl", "See $.data for the source Arrow object", sep = "\n" diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 1bf08595758..0c5ef4c12da 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -60,7 +60,7 @@ test_that("C++ expressions", { expect_is(!(f < 4), "Expression") expect_output( print(f > 4), - 'Expression\n(f > 4:double)', + 'Expression\n(f > 4)', fixed = TRUE ) # Interprets that as a list type From d2ccb27a26f101508167690e054faf6f7b73b83c Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 15 Dec 2020 17:13:17 -0500 Subject: [PATCH 14/38] print and_kleene, or_kleene as binary ops --- cpp/src/arrow/dataset/expression.cc | 16 ++++++++++++++-- cpp/src/arrow/dataset/expression_test.cc | 4 ++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index f4600196e78..693f4b67d58 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -83,9 +83,21 @@ std::string Expression::ToString() const { } auto call = CallNotNull(*this); - if (auto cmp = Comparison::Get(call->function)) { - return "(" + call->arguments[0].ToString() + " " + Comparison::GetOp(*cmp) + " " + + auto binary = [&](std::string op) { + return "(" + call->arguments[0].ToString() + " " + op + " " + call->arguments[1].ToString() + ")"; + }; + + if (auto cmp = Comparison::Get(call->function)) { + return binary(Comparison::GetOp(*cmp)); + } + + if (call->function == "and_kleene") { + return binary("and"); + } + + if (call->function == "or_kleene") { + return binary("or"); } if (auto options = GetStructOptions(*call)) { diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 738fab72237..709e19791a8 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -67,6 +67,10 @@ TEST(Expression, ToString) { EXPECT_EQ(in_12.ToString(), "index_in(beta, value_set=[\n 1,\n 2\n])"); + EXPECT_EQ(and_(field_ref("a"), field_ref("b")).ToString(), "(a and b)"); + EXPECT_EQ(or_(field_ref("a"), field_ref("b")).ToString(), "(a or b)"); + EXPECT_EQ(not_(field_ref("a")).ToString(), "invert(a)"); + EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), "cast(a, to_type=int32)"); EXPECT_EQ(cast(field_ref("a"), nullptr).ToString(), 
"cast(a, to_type=)"); From c149640f6365b417222519d122bf67cbf8b7b44b Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 16 Dec 2020 14:22:42 -0500 Subject: [PATCH 15/38] clean up Expression class, extract Parameter --- cpp/src/arrow/dataset/expression.cc | 159 ++++++++++++-------- cpp/src/arrow/dataset/expression.h | 62 ++++---- cpp/src/arrow/dataset/expression_internal.h | 29 ++-- cpp/src/arrow/dataset/expression_test.cc | 36 ++--- cpp/src/arrow/dataset/partition_test.cc | 2 +- 5 files changed, 155 insertions(+), 133 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 693f4b67d58..4c5b693ffb9 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -39,14 +39,54 @@ using internal::checked_pointer_cast; namespace dataset { -const Expression::Call* Expression::call() const { - return util::get_if(impl_.get()); +Expression::Expression(Call call) : impl_(std::make_shared(std::move(call))) {} + +Expression::Expression(Datum literal) + : impl_(std::make_shared(std::move(literal))) {} + +Expression::Expression(Parameter parameter) + : impl_(std::make_shared(std::move(parameter))) {} + +Expression literal(Datum lit) { return Expression(std::move(lit)); } + +Expression field_ref(FieldRef ref) { + return Expression(Expression::Parameter{std::move(ref), {}}); +} + +Expression call(std::string function, std::vector arguments, + std::shared_ptr options) { + Expression::Call call; + call.function_name = std::move(function); + call.arguments = std::move(arguments); + call.options = std::move(options); + return Expression(std::move(call)); } const Datum* Expression::literal() const { return util::get_if(impl_.get()); } const FieldRef* Expression::field_ref() const { - return util::get_if(impl_.get()); + if (auto parameter = util::get_if(impl_.get())) { + return ¶meter->ref; + } + return nullptr; +} + +const Expression::Call* Expression::call() const { + return util::get_if(impl_.get()); +} + +ValueDescr Expression::descr() const { + if (impl_ == nullptr) return {}; + + if (auto lit = literal()) { + return lit->descr(); + } + + if (auto parameter = util::get_if(impl_.get())) { + return parameter->descr; + } + + return CallNotNull(*this)->descr; } std::string Expression::ToString() const { @@ -88,16 +128,14 @@ std::string Expression::ToString() const { call->arguments[1].ToString() + ")"; }; - if (auto cmp = Comparison::Get(call->function)) { + if (auto cmp = Comparison::Get(call->function_name)) { return binary(Comparison::GetOp(*cmp)); } - if (call->function == "and_kleene") { - return binary("and"); - } - - if (call->function == "or_kleene") { - return binary("or"); + constexpr util::string_view kleene = "_kleene"; + if (util::string_view{call->function_name}.ends_with(kleene)) { + auto op = call->function_name.substr(0, call->function_name.size() - kleene.size()); + return binary(std::move(op)); } if (auto options = GetStructOptions(*call)) { @@ -111,7 +149,7 @@ std::string Expression::ToString() const { return out; } - std::string out = call->function + "("; + std::string out = call->function_name + "("; for (const auto& arg : call->arguments) { out += arg.ToString() + ", "; } @@ -178,7 +216,8 @@ bool Expression::Equals(const Expression& other) const { auto call = CallNotNull(*this); auto other_call = CallNotNull(other); - if (call->function != other_call->function || call->kernel != other_call->kernel) { + if (call->function_name != other_call->function_name || + call->kernel != other_call->kernel) { 
return false; } @@ -223,7 +262,7 @@ bool Expression::Equals(const Expression& other) const { } ARROW_LOG(WARNING) << "comparing unknown FunctionOptions for function " - << call->function; + << call->function_name; return false; } @@ -241,7 +280,7 @@ size_t Expression::hash() const { auto call = CallNotNull(*this); - size_t out = std::hash{}(call->function); + size_t out = std::hash{}(call->function_name); for (const auto& arg : call->arguments) { out ^= arg.hash(); } @@ -249,7 +288,7 @@ size_t Expression::hash() const { } bool Expression::IsBound() const { - if (descr_.type == nullptr) return false; + if (descr().type == nullptr) return false; if (auto lit = literal()) return true; @@ -278,20 +317,20 @@ bool Expression::IsScalarExpression() const { if (!arg.IsScalarExpression()) return false; } - if (call->kernel == nullptr) { - // this expression is not bound; make a best guess based on - // the default function registry - if (auto function = compute::GetFunctionRegistry() - ->GetFunction(call->function) - .ValueOr(nullptr)) { - return function->kind() == compute::Function::SCALAR; - } + if (call->function) { + return call->function->kind() == compute::Function::SCALAR; + } - // unknown function or other error; conservatively return false - return false; + // this expression is not bound; make a best guess based on + // the default function registry + if (auto function = compute::GetFunctionRegistry() + ->GetFunction(call->function_name) + .ValueOr(nullptr)) { + return function->kind() == compute::Function::SCALAR; } - return call->function_kind == compute::Function::SCALAR; + // unknown function or other error; conservatively return false + return false; } bool Expression::IsNullLiteral() const { @@ -305,7 +344,7 @@ bool Expression::IsNullLiteral() const { } bool Expression::IsSatisfiable() const { - if (descr_.type && descr_.type->id() == Type::NA) { + if (descr().type && descr().type->id() == Type::NA) { return false; } @@ -369,9 +408,9 @@ Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { auto call_with_cast = *CallNotNull(with_cast); call_with_cast.arguments[0] = std::move(*expr); - *expr = Expression(std::make_shared(std::move(call_with_cast)), - ValueDescr{std::move(to_type), expr->descr().shape}); + call_with_cast.descr = ValueDescr{std::move(to_type), expr->descr().shape}; + *expr = Expression(std::move(call_with_cast)); return Status::OK(); } @@ -379,7 +418,7 @@ Status InsertImplicitCasts(Expression::Call* call) { DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - if (IsSameTypesBinary(call->function)) { + if (IsSameTypesBinary(call->function_name)) { for (auto&& argument : call->arguments) { if (auto value_type = GetDictionaryValueType(argument.descr().type)) { RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); @@ -389,9 +428,10 @@ Status InsertImplicitCasts(Expression::Call* call) { if (call->arguments[0].descr().shape == ValueDescr::SCALAR) { // argument 0 is scalar so casting is cheap return MaybeInsertCast(call->arguments[1].descr().type, &call->arguments[0]); - } else { - return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); } + + // cast argument 1 unconditionally + return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); } if (auto options = GetSetLookupOptions(*call)) { @@ -435,15 +475,13 @@ Result Expression::Bind(ValueDescr in, if (auto ref = field_ref()) { ARROW_ASSIGN_OR_RAISE(auto field, 
ref->GetOneOrNone(*in.type)); - auto out = *this; - out.descr_ = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return out; + auto descr = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); + return Expression{Parameter{*ref, std::move(descr)}}; } auto bound_call = *CallNotNull(*this); - ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(bound_call, exec_context)); - bound_call.function_kind = function->kind(); + ARROW_ASSIGN_OR_RAISE(bound_call.function, GetFunction(bound_call, exec_context)); for (auto&& argument : bound_call.arguments) { ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); @@ -451,17 +489,18 @@ Result Expression::Bind(ValueDescr in, RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); auto descrs = GetDescriptors(bound_call.arguments); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, function->DispatchExact(descrs)); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel, bound_call.function->DispatchExact(descrs)); compute::KernelContext kernel_context(exec_context); ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, InitKernelState(bound_call, exec_context)); kernel_context.SetState(bound_call.kernel_state.get()); - ARROW_ASSIGN_OR_RAISE(auto descr, bound_call.kernel->signature->out_type().Resolve( - &kernel_context, descrs)); + ARROW_ASSIGN_OR_RAISE( + bound_call.descr, + bound_call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); - return Expression(std::make_shared(std::move(bound_call)), std::move(descr)); + return Expression(std::move(bound_call)); } Result Expression::Bind(const Schema& in_schema, @@ -547,7 +586,7 @@ util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { util::optional GetNullHandling( const Expression::Call& call) { - if (call.function_kind == compute::Function::SCALAR) { + if (call.function && call.function->kind() == compute::Function::SCALAR) { return static_cast(call.kernel)->null_handling; } return util::nullopt; @@ -619,7 +658,7 @@ Result FoldConstants(Expression expr) { } } - if (call->function == "and_kleene") { + if (call->function_name == "and_kleene") { for (auto args : ArgumentsAndFlippedArguments(*call)) { // true and x == x if (args.first == literal(true)) return args.second; @@ -633,7 +672,7 @@ Result FoldConstants(Expression expr) { return expr; } - if (call->function == "or_kleene") { + if (call->function_name == "or_kleene") { for (auto args : ArgumentsAndFlippedArguments(*call)) { // false or x == x if (args.first == literal(false)) return args.second; @@ -654,7 +693,7 @@ Result FoldConstants(Expression expr) { inline std::vector GuaranteeConjunctionMembers( const Expression& guaranteed_true_predicate) { auto guarantee = guaranteed_true_predicate.call(); - if (!guarantee || guarantee->function != "and_kleene") { + if (!guarantee || guarantee->function_name != "and_kleene") { return {guaranteed_true_predicate}; } return FlattenedAssociativeChain(guaranteed_true_predicate).fringe; @@ -672,7 +711,7 @@ Status ExtractKnownFieldValuesImpl( auto call = expr.call(); if (!call) return true; - if (call->function == "equal") { + if (call->function_name == "equal") { auto ref = call->arguments[0].field_ref(); auto lit = call->arguments[1].literal(); return !(ref && lit); @@ -741,7 +780,7 @@ inline bool IsBinaryAssociativeCommutative(const Expression::Call& call) { "and", "or", "and_kleene", "or_kleene", "xor", "multiply", "add", "multiply_checked", "add_checked"}; - auto it = binary_associative_commutative.find(call.function); + auto it = 
binary_associative_commutative.find(call.function_name); return it != binary_associative_commutative.end(); } @@ -798,37 +837,35 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering); // fold the chain back up - const auto& descr = expr.descr(); auto folded = FoldLeft(chain.fringe.begin(), chain.fringe.end(), - [call, &descr, &AlreadyCanonicalized](Expression l, Expression r) { - auto ret = *call; - ret.arguments = {std::move(l), std::move(r)}; - Expression expr( - std::make_shared(std::move(ret)), descr); + [call, &AlreadyCanonicalized](Expression l, Expression r) { + auto canonicalized_call = *call; + canonicalized_call.arguments = {std::move(l), std::move(r)}; + Expression expr(std::move(canonicalized_call)); AlreadyCanonicalized.Add({expr}); return expr; }); return std::move(*folded); } - if (auto cmp = Comparison::Get(call->function)) { + if (auto cmp = Comparison::Get(call->function_name)) { if (call->arguments[0].literal() && !call->arguments[1].literal()) { // ensure that literals are on comparisons' RHS auto flipped_call = *call; - flipped_call.function = Comparison::GetName(Comparison::GetFlipped(*cmp)); + flipped_call.function_name = + Comparison::GetName(Comparison::GetFlipped(*cmp)); // look up the flipped kernel // TODO extract a helper for use here and in Bind ARROW_ASSIGN_OR_RAISE( auto function, - exec_context->func_registry()->GetFunction(flipped_call.function)); + exec_context->func_registry()->GetFunction(flipped_call.function_name)); auto descrs = GetDescriptors(flipped_call.arguments); ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); - return Expression(std::make_shared(std::move(flipped_call)), - expr.descr()); + return Expression(std::move(flipped_call)); } } @@ -847,7 +884,7 @@ Result DirectComparisonSimplification(Expression expr, // Ensure both calls are comparisons with equal LHS and scalar RHS auto cmp = Comparison::Get(expr); - auto cmp_guarantee = Comparison::Get(guarantee.function); + auto cmp_guarantee = Comparison::Get(guarantee.function_name); if (!cmp || !cmp_guarantee) return expr; if (call->arguments[0] != guarantee.arguments[0]) return expr; @@ -961,7 +998,7 @@ Result> Serialize(const Expression& expr) { } auto call = CallNotNull(expr); - metadata_->Append("call", call->function); + metadata_->Append("call", call->function_name); for (const auto& argument : call->arguments) { RETURN_NOT_OK(Visit(argument)); @@ -973,7 +1010,7 @@ Result> Serialize(const Expression& expr) { metadata_->Append("options", std::move(value)); } - metadata_->Append("end", call->function); + metadata_->Append("end", call->function_name); return Status::OK(); } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 0ab98327619..6137666abff 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -48,15 +48,15 @@ namespace dataset { class ARROW_DS_EXPORT Expression { public: struct Call { - std::string function; + std::string function_name; std::vector arguments; std::shared_ptr options; // post-Bind properties: const compute::Kernel* kernel = NULLPTR; - compute::Function::Kind - function_kind; // XXX give Kernel a non-owning pointer to its Function + std::shared_ptr function; std::shared_ptr kernel_state; + ValueDescr descr; }; std::string ToString() const; @@ -102,20 +102,23 @@ class ARROW_DS_EXPORT Expression { const 
Datum* literal() const;
   const FieldRef* field_ref() const;

-  const ValueDescr& descr() const { return descr_; }
-
-  using Impl = util::Variant<Datum, FieldRef, Call>;
-
-  explicit Expression(std::shared_ptr<Impl> impl, ValueDescr descr = {})
-      : impl_(std::move(impl)), descr_(std::move(descr)) {}
+  ValueDescr descr() const;
+  // XXX someday
+  // NullGeneralization::type nullable() const;
+
+  struct Parameter {
+    FieldRef ref;
+    ValueDescr descr;
+  };

   Expression() = default;
+  explicit Expression(Call call);
+  explicit Expression(Datum literal);
+  explicit Expression(Parameter parameter);

  private:
+  using Impl = util::Variant<Datum, Parameter, Call>;
   std::shared_ptr<Impl> impl_;
-  ValueDescr descr_;
-  // XXX someday
-  // NullGeneralization::type evaluates_to_null_;

   ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r);
@@ -127,15 +130,21 @@ inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); }

 // Factories

-inline Expression call(std::string function, std::vector<Expression> arguments,
-                       std::shared_ptr<compute::FunctionOptions> options = NULLPTR) {
-  Expression::Call call;
-  call.function = std::move(function);
-  call.arguments = std::move(arguments);
-  call.options = std::move(options);
-  return Expression(std::make_shared<Expression::Impl>(std::move(call)));
+ARROW_DS_EXPORT
+Expression literal(Datum lit);
+
+template <typename Arg>
+Expression literal(Arg&& arg) {
+  return literal(Datum(std::forward<Arg>(arg)));
 }

+ARROW_DS_EXPORT
+Expression field_ref(FieldRef ref);
+
+ARROW_DS_EXPORT
+Expression call(std::string function, std::vector<Expression> arguments,
+                std::shared_ptr<compute::FunctionOptions> options = NULLPTR);
+
 template <typename Options,
           typename = typename std::enable_if<
               std::is_base_of<compute::FunctionOptions, Options>::value>::type>
 Expression call(std::string function, std::vector<Expression> arguments,
                 Options options) {
@@ -144,23 +153,6 @@ Expression call(std::string function, std::vector<Expression> arguments,
               std::make_shared<Options>(std::move(options)));
 }

-template <typename... Args>
-Expression field_ref(Args&&... args) {
-  return Expression(
-      std::make_shared<Expression::Impl>(FieldRef(std::forward<Args>(args)...)));
-}
-
-template <typename Arg>
-Expression literal(Arg&& arg) {
-  Datum lit(std::forward<Arg>(arg));
-  ValueDescr descr = lit.descr();
-  return Expression(std::make_shared<Expression::Impl>(std::move(lit)), std::move(descr));
-}
-
-inline Expression literal(std::shared_ptr<Scalar> scalar) {
-  return literal(Datum(std::move(scalar)));
-}
-
 ARROW_DS_EXPORT
 std::vector<FieldRef> FieldsInExpression(const Expression&);

diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h
index c91bbb4b62f..eba972dc18b 100644
--- a/cpp/src/arrow/dataset/expression_internal.h
+++ b/cpp/src/arrow/dataset/expression_internal.h
@@ -135,7 +135,7 @@ struct Comparison {

   static const type* Get(const Expression& expr) {
     if (auto call = expr.call()) {
-      return Comparison::Get(call->function);
+      return Comparison::Get(call->function_name);
     }
     return nullptr;
   }
@@ -230,7 +230,7 @@ struct Comparison {
 };

 inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) {
-  if (call.function != "cast") return nullptr;
+  if (call.function_name != "cast") return nullptr;
   return checked_cast<const compute::CastOptions*>(call.options.get());
 }

@@ -248,17 +248,17 @@ inline bool IsSameTypesBinary(const std::string& function) {

 inline const compute::SetLookupOptions* GetSetLookupOptions(
     const Expression::Call& call) {
-  if (!IsSetLookup(call.function)) return nullptr;
+  if (!IsSetLookup(call.function_name)) return nullptr;
   return checked_cast<const compute::SetLookupOptions*>(call.options.get());
 }

 inline const compute::StructOptions* GetStructOptions(const Expression::Call& call) {
-  if (call.function != "struct") return nullptr;
+  if (call.function_name != "struct") return nullptr;
   return checked_cast<const compute::StructOptions*>(call.options.get());
 }

 inline const 
compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call& call) { - if (call.function != "strptime") return nullptr; + if (call.function_name != "strptime") return nullptr; return checked_cast(call.options.get()); } @@ -321,7 +321,7 @@ inline Result> FunctionOptionsToStructScalar( {"value_set", "skip_nulls"}); } - if (call.function == "cast") { + if (call.function_name == "cast") { auto options = checked_cast(call.options.get()); return Finish( { @@ -344,7 +344,7 @@ inline Result> FunctionOptionsToStructScalar( }); } - return Status::NotImplemented("conversion of options for ", call.function); + return Status::NotImplemented("conversion of options for ", call.function_name); } inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, @@ -354,7 +354,7 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, return Status::OK(); } - if (IsSetLookup(call->function)) { + if (IsSetLookup(call->function_name)) { ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set")); ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls")); call->options = std::make_shared( @@ -363,7 +363,7 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, return Status::OK(); } - if (call->function == "cast") { + if (call->function_name == "cast") { auto options = std::make_shared(); ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder")); options->to_type = to_type_holder->type; @@ -384,7 +384,7 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, return Status::OK(); } - return Status::NotImplemented("conversion of options for ", call->function); + return Status::NotImplemented("conversion of options for ", call->function_name); } struct FlattenedAssociativeChain { @@ -399,7 +399,7 @@ struct FlattenedAssociativeChain { while (it != fringe.end()) { auto sub_call = it->call(); - if (!sub_call || sub_call->function != call->function) { + if (!sub_call || sub_call->function_name != call->function_name) { ++it; continue; } @@ -422,8 +422,8 @@ struct FlattenedAssociativeChain { inline Result> GetFunction( const Expression::Call& call, compute::ExecContext* exec_context) { - if (call.function != "cast") { - return exec_context->func_registry()->GetFunction(call.function); + if (call.function_name != "cast") { + return exec_context->func_registry()->GetFunction(call.function_name); } // XXX this special case is strange; why not make "cast" a ScalarFunction? 
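  // (One plausible answer, sketched here as orientation rather than taken from
  // the patch itself: "cast" is registered as a MetaFunction whose output type
  // is chosen by CastOptions::to_type rather than by the argument types, so a
  // kernel cannot be dispatched from the input ValueDescrs alone. A concrete
  // cast function can instead be fetched per target type, roughly:
  //
  //   ARROW_ASSIGN_OR_RAISE(auto fn, compute::GetCastFunction(to_type));
  //
  // where GetCastFunction is the existing helper declared in
  // arrow/compute/cast.h.)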
   const auto& to_type = checked_cast<const compute::CastOptions&>(*call.options).to_type;
@@ -453,8 +453,7 @@ Result<Expression> Modify(Expression expr, const PreVisit& pre,
 
   if (at_least_one_modified) {
     // reconstruct the call expression with the modified arguments
-    auto modified_expr = Expression(
-        std::make_shared<Expression::Impl>(std::move(modified_call)), expr.descr());
+    auto modified_expr = Expression(std::move(modified_call));
 
     return post_call(std::move(modified_expr), &expr);
   }
diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc
index 709e19791a8..3b286e378d0 100644
--- a/cpp/src/arrow/dataset/expression_test.cc
+++ b/cpp/src/arrow/dataset/expression_test.cc
@@ -282,8 +282,9 @@ TEST(Expression, BindFieldRef) {
                                  {field("alpha", int32()), field("alpha", float32())})));
 
   // referencing nested fields is supported
-  ASSERT_OK_AND_ASSIGN(expr, field_ref("a", "b").Bind(
-                                 Schema({field("a", struct_({field("b", int32())}))})));
+  ASSERT_OK_AND_ASSIGN(expr,
+                       field_ref(FieldRef("a", "b"))
+                           .Bind(Schema({field("a", struct_({field("b", int32())}))})));
   EXPECT_TRUE(expr.IsBound());
   EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
 }
@@ -476,6 +477,13 @@ TEST(Expression, ExecuteDictionaryTransparent) {
   ])"));
 }
 
+void ExpectIdenticalIfUnchanged(Expression modified, Expression original) {
+  if (modified == original) {
+    // no change -> must be identical
+    EXPECT_TRUE(Identical(modified, original)) << " " << original.ToString();
+  }
+}
+
 struct {
   void operator()(Expression expr, Expression expected) {
     ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
@@ -484,11 +492,7 @@ struct {
     ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(expr));
 
     EXPECT_EQ(folded, expected);
-
-    if (folded == expr) {
-      // no change -> must be identical
-      EXPECT_TRUE(Identical(folded, expr));
-    }
+    ExpectIdenticalIfUnchanged(folded, expr);
   }
 } ExpectFoldsTo;
@@ -629,11 +633,7 @@ TEST(Expression, ReplaceFieldsWithKnownValues) {
                          ReplaceFieldsWithKnownValues(known_values, expr));
 
     EXPECT_EQ(replaced, expected);
-
-    if (replaced == expr) {
-      // no change -> must be identical
-      EXPECT_TRUE(Identical(replaced, expr));
-    }
+    ExpectIdenticalIfUnchanged(replaced, expr);
   };
 
   std::unordered_map<FieldRef, Datum, FieldRef::Hash> i32_is_3{{"i32", Datum(3)}};
@@ -677,18 +677,14 @@ struct {
 
     ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound));
 
     EXPECT_EQ(actual, expected);
-
-    if (actual == expr) {
-      // no change -> must be identical
-      EXPECT_TRUE(Identical(actual, expr));
-    }
+    ExpectIdenticalIfUnchanged(actual, bound);
   }
 } ExpectCanonicalizesTo;
 
 TEST(Expression, CanonicalizeTrivial) {
   ExpectCanonicalizesTo(literal(1), literal(1));
 
-  ExpectCanonicalizesTo(field_ref("b"), field_ref("b"));
+  ExpectCanonicalizesTo(field_ref("i32"), field_ref("i32"));
 
   ExpectCanonicalizesTo(equal(field_ref("i32"), field_ref("i32_req")),
                         equal(field_ref("i32"), field_ref("i32_req")));
@@ -752,9 +748,7 @@ struct Simplify {
         << " guarantee: " << guarantee.ToString() << "\n"
         << (simplified == bound ? " (no change)\n" : "");
 
-    if (simplified == bound) {
-      EXPECT_TRUE(Identical(simplified, bound));
-    }
+    ExpectIdenticalIfUnchanged(simplified, bound);
   }
 
   void ExpectUnchanged() { Expect(expr); }
   void Expect(bool constant) { Expect(literal(constant)); }
diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc
index 840e2f0303f..8610ff1e891 100644
--- a/cpp/src/arrow/dataset/partition_test.cc
+++ b/cpp/src/arrow/dataset/partition_test.cc
@@ -429,7 +429,7 @@ TEST_F(TestPartitioning, Set) {
         set.push_back(checked_cast<const Int32Scalar&>(*s).value);
       }
 
-      subexpressions.push_back(call("is_in", {field_ref(matches[1])},
+      subexpressions.push_back(call("is_in", {field_ref(std::string(matches[1]))},
                                     compute::SetLookupOptions{ints(set)}));
     }
     return and_(std::move(subexpressions));

From cc1fd25a099598ae798f0c5812a7e7a083e6cbbb Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 16 Dec 2020 14:33:02 -0500
Subject: [PATCH 16/38] revert variant noexcept changes

---
 cpp/src/arrow/datum.h              |  4 ++--
 cpp/src/arrow/util/variant.h       | 32 ++++++++++++------------------
 cpp/src/arrow/util/variant_test.cc | 13 ------------
 3 files changed, 15 insertions(+), 34 deletions(-)

diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h
index 3d1f545e5f7..8479f7ad366 100644
--- a/cpp/src/arrow/datum.h
+++ b/cpp/src/arrow/datum.h
@@ -120,8 +120,8 @@ struct ARROW_EXPORT Datum {
 
   Datum(const Datum& other) = default;
   Datum& operator=(const Datum& other) = default;
-  Datum(Datum&& other) noexcept = default;
-  Datum& operator=(Datum&& other) noexcept = default;
+  Datum(Datum&& other) = default;
+  Datum& operator=(Datum&& other) = default;
 
   Datum(std::shared_ptr<Scalar> value)  // NOLINT implicit conversion
       : value(std::move(value)) {}
diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h
index 04ac5e2bb45..89f39ab8917 100644
--- a/cpp/src/arrow/util/variant.h
+++ b/cpp/src/arrow/util/variant.h
@@ -96,7 +96,7 @@ template <typename H, typename... T>
 struct all<H, T...> : conditional_t<H::value, all<T...>, std::false_type> {};
 
 struct delete_copy_constructor {
-  template <typename, bool NoexceptCopyable>
+  template <typename>
   struct type {
     type() = default;
     type(const type& other) = delete;
     type& operator=(const type& other) = delete;
   };
 };
 
 struct explicit_copy_constructor {
-  template <typename Copyable, bool NoexceptCopyable>
+  template <typename Copyable>
   struct type {
     type() = default;
-    type(const type& other) noexcept(NoexceptCopyable) {
-      static_cast<const Copyable&>(other).copy_to(this);
-    }
-    type& operator=(const type& other) noexcept(NoexceptCopyable) {
+    type(const type& other) { static_cast<const Copyable&>(other).copy_to(this); }
+    type& operator=(const type& other) {
       static_cast<Copyable*>(this)->destroy();
       static_cast<const Copyable&>(other).copy_to(this);
       return *this;
     }
   };
 };
 
-template <typename Self, typename... T>
-using CopyConstructionImpl =
-    typename conditional_t<all<(std::is_copy_constructible<T>::value &&
-                                std::is_copy_assignable<T>::value)...>::value,
-                           explicit_copy_constructor, delete_copy_constructor>::
-        template type<Self, all<std::is_nothrow_copy_constructible<T>::value...>::value>;
-
 template <typename... T>
 struct VariantStorage {
   VariantStorage() = default;
-  VariantStorage(const VariantStorage&) noexcept {}
-  VariantStorage& operator=(const VariantStorage&) noexcept { return *this; }
+  VariantStorage(const VariantStorage&) {}
+  VariantStorage& operator=(const VariantStorage&) { return *this; }
   VariantStorage(VariantStorage&&) noexcept {}
   VariantStorage& operator=(VariantStorage&&) noexcept { return *this; }
   ~VariantStorage() {
@@ -201,8 +192,7 @@ struct VariantImpl<Variant<M...>, H, T...> : VariantImpl<Variant<M...>, T...> {
 
   // Templated to avoid instantiation in case H is not copy constructible
   template <typename Void>
-  void copy_to(Void* generic_target) const
-      noexcept(noexcept(H(std::declval<const H&>()))) {
+  void copy_to(Void* generic_target) const {
     const auto target = static_cast<VariantImpl*>(generic_target);
     try {
       if (this->index_ == kIndex) {
@@ -253,7 +243,11 @@ struct VariantImpl<Variant<M...>, H, T...> : VariantImpl<Variant<M...>, T...> {
 
 template <typename... T>
 class Variant : detail::VariantImpl<Variant<T...>, T...>,
-                detail::CopyConstructionImpl<Variant<T...>, T...> {
+                detail::conditional_t<
+                    detail::all<(std::is_copy_constructible<T>::value &&
+                                 std::is_copy_assignable<T>::value)...>::value,
+                    detail::explicit_copy_constructor,
+                    detail::delete_copy_constructor>::template type<Variant<T...>> {
   template <typename U>
   static constexpr uint8_t index_of() {
     return Impl::index_of(detail::type_constant<U>{});
   }
@@ -344,7 +338,7 @@ class Variant : detail::VariantImpl<Variant<T...>, T...>,
     this->index_ = 0;
   }
 
-  template <typename V, bool NoexceptCopyable>
+  template <typename V>
   friend struct detail::explicit_copy_constructor::type;
 
   template <typename V>
diff --git a/cpp/src/arrow/util/variant_test.cc b/cpp/src/arrow/util/variant_test.cc
index 5d83e00185f..9e36f2eb9cf 100644
--- a/cpp/src/arrow/util/variant_test.cc
+++ b/cpp/src/arrow/util/variant_test.cc
@@ -125,19 +125,6 @@ TEST(Variant, CopyConstruction) {
   EXPECT_NO_THROW(AssertCopyConstruction(CopyAssignThrows{}));
 }
 
-TEST(Variant, Noexcept) {
-  struct CopyThrows {
-    CopyThrows() = default;
-    CopyThrows(const CopyThrows&) { throw 42; }
-    CopyThrows& operator=(const CopyThrows&) { throw 42; }
-  };
-  static_assert(!std::is_nothrow_copy_constructible<CopyThrows>::value, "");
-  static_assert(std::is_nothrow_copy_constructible<int>::value, "");
-  static_assert(std::is_nothrow_copy_constructible<Variant<int>>::value, "");
-  static_assert(
-      !std::is_nothrow_copy_constructible<Variant<int, CopyThrows>>::value, "");
-}
-
 TEST(Variant, Emplace) {
   using variant_type = Variant<std::unique_ptr<int>, int>;
   variant_type v;

From 6f2cc2cf91e1ca312711aad7a08ca914882655b6 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 16 Dec 2020 21:26:11 -0500
Subject: [PATCH 17/38] get python binding building

---
 python/pyarrow/_dataset.pyx                  | 149 ++++++++-----------
 python/pyarrow/compute.py                    |   2 +-
 python/pyarrow/includes/libarrow.pxd         |   4 +
 python/pyarrow/includes/libarrow_dataset.pxd |  17 ++-
 python/pyarrow/tests/test_dataset.py         |  13 --
 5 files changed, 77 insertions(+), 108 deletions(-)

diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index e235d50c74a..7b96e87633b 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -116,111 +116,79 @@ cdef class Expression(_Weakrefable):
 
     @staticmethod
     def _deserialize(Buffer buffer not None):
-        c_buffer = pyarrow_unwrap_buffer(buffer)
-        #c_expr = GetResultValue(CExpression.Deserialize(deref(c_buffer)))
-        return Expression.wrap(move(c_expr))
+        return Expression.wrap(GetResultValue(CDeserializeExpression(
+            deref(pyarrow_unwrap_buffer(buffer)))))
 
     def __reduce__(self):
-        buffer = pyarrow_wrap_buffer(GetResultValue(self.expr.Serialize()))
+        buffer = pyarrow_wrap_buffer(GetResultValue(
+            CSerializeExpression(self.expr)))
         return Expression._deserialize, (buffer,)
 
-    def validate(self, Schema schema not None):
-        """Validate this expression for execution against a schema.
-
-        This will check that all reference fields are present (fields not in
-        the schema will be replaced with null) and all subexpressions are
-        executable. Returns the type to which this expression will evaluate.
-
-        Parameters
-        ----------
-        schema : Schema
-            Schema to execute the expression on.
+    @staticmethod
+    cdef Expression _expr_or_scalar(object expr):
+        if isinstance(expr, Expression):
+            return (<Expression> expr)
+        return (<Expression> Expression._scalar(expr))
 
-        Returns
-        -------
-        type : DataType
-        """
+    @staticmethod
+    cdef Expression _call(str function_name, list arguments,
+                          shared_ptr[CFunctionOptions] options=(
+                              <shared_ptr[CFunctionOptions]> nullptr)):
         cdef:
-            shared_ptr[CSchema] sp_schema
-            CResult[shared_ptr[CDataType]] result
-        sp_schema = pyarrow_unwrap_schema(schema)
-        result = self.expr.Validate(deref(sp_schema))
-        return pyarrow_wrap_data_type(GetResultValue(result))
+            vector[CExpression] c_arguments
 
-    def assume(self, Expression given):
-        """Simplify to an equivalent Expression given assumed constraints."""
-        return Expression.wrap(self.expr.Assume(given.unwrap()))
+        for argument in arguments:
+            c_arguments.push_back((<Expression> argument).expr)
 
-    def __invert__(self):
-        return Expression.wrap(CMakeNotExpression(self.unwrap()))
-
-    @staticmethod
-    cdef CExpression _expr_or_scalar(object expr) except *:
-        if isinstance(expr, Expression):
-            return (<Expression> expr).unwrap()
-        return (<Expression> Expression._scalar(expr)).unwrap()
+        return Expression.wrap(CMakeCallExpression(tobytes(function_name),
+                                                   move(c_arguments), options))
 
     def __richcmp__(self, other, int op):
-        cdef:
-            CExpression c_expr
-            CExpression c_left
-            CExpression c_right
-
-        c_left = self.unwrap()
-        c_right = Expression._expr_or_scalar(other)
-
-        if op == Py_EQ:
-            c_expr = CMakeEqualExpression(move(c_left), move(c_right))
-        elif op == Py_NE:
-            c_expr = CMakeNotEqualExpression(move(c_left), move(c_right))
-        elif op == Py_GT:
-            c_expr = CMakeGreaterExpression(move(c_left), move(c_right))
-        elif op == Py_GE:
-            c_expr = CMakeGreaterEqualExpression(move(c_left), move(c_right))
-        elif op == Py_LT:
-            c_expr = CMakeLessExpression(move(c_left), move(c_right))
-        elif op == Py_LE:
-            c_expr = CMakeLessEqualExpression(move(c_left), move(c_right))
-
-        return Expression.wrap(c_expr)
+        other = Expression._expr_or_scalar(other)
+        return Expression._call({
+            Py_EQ: "equal",
+            Py_NE: "not_equal",
+            Py_GT: "greater",
+            Py_GE: "greater_equal",
+            Py_LT: "less",
+            Py_LE: "less_equal",
+        }[op], [self, other])
+
+    def __invert__(self):
+        return Expression._call("invert", [self])
 
     def __and__(Expression self, other):
-        c_other = Expression._expr_or_scalar(other)
-        return Expression.wrap(CMakeAndExpression(self.wrapped,
-                                                  move(c_other)))
+        other = Expression._expr_or_scalar(other)
+        return Expression._call("and_kleene", [self, other])
 
     def __or__(Expression self, other):
-        c_other = Expression._expr_or_scalar(other)
-        return Expression.wrap(CMakeOrExpression(self.wrapped,
-                                                 move(c_other)))
+        other = Expression._expr_or_scalar(other)
+        return Expression._call("or_kleene", [self, other])
 
     def is_valid(self):
         """Checks whether the expression is not-null (valid)"""
-        return Expression.wrap(self.expr.IsValid().Copy())
+        return Expression._call("is_valid", [self])
 
     def cast(self, type, bint safe=True):
         """Explicitly change the expression's data type"""
-        cdef CCastOptions options
-        if safe:
-            options = CCastOptions.Safe()
-        else:
-            options = CCastOptions.Unsafe()
-        c_type = pyarrow_unwrap_data_type(ensure_type(type))
-        return Expression.wrap(self.expr.CastTo(c_type, options).Copy())
+        cdef shared_ptr[CCastOptions] c_options
+        c_options.reset(new CCastOptions(safe))
+        c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type))
+        return Expression._call("cast", [self],
+                                <shared_ptr[CFunctionOptions]> c_options)
 
     def isin(self, values):
         """Checks whether the expression is contained in values"""
         cdef:
-            vector[CExpression] arguments
-        arguments.push_back(self.expr)
+            shared_ptr[CFunctionOptions] c_options
+            CDatum c_values
 
         if not isinstance(values, pa.Array):
             values = pa.array(values)
-        c_values = pyarrow_unwrap_array(values)
 
-        return Expression.wrap(CMakeCallExpression(
-            tobytes("is_in"),
-            move(arguments),
-            move(SetLookupOptions(values).set_lookup_options)))
+
+        c_values = CDatum(pyarrow_unwrap_array(values))
+        c_options.reset(new CSetLookupOptions(c_values, True))
+        return Expression._call("is_in", [self], c_options)
 
     @staticmethod
     def _field(str name not None):
@@ -322,10 +290,11 @@ cdef class Dataset(_Weakrefable):
             CFragmentIterator c_iterator
 
         if filter is None:
-            c_fragments = self.dataset.GetFragments()
+            c_fragments = move(GetResultValue(self.dataset.GetFragments()))
         else:
             c_filter = _bind(filter, self.schema)
-            c_fragments = self.dataset.GetFragments(c_filter)
+            c_fragments = move(GetResultValue(
+                self.dataset.GetFragments(c_filter)))
 
         for maybe_fragment in c_fragments:
             yield Fragment.wrap(GetResultValue(move(maybe_fragment)))
@@ -596,7 +565,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *:
         return _true.unwrap()
 
     return GetResultValue(filter.unwrap().Bind(
-        deref(pyarrow_unwrap_schema(schema).get())))
+        deref(pyarrow_unwrap_schema(schema).get())))
 
 
 cdef class FileWriteOptions(_Weakrefable):
@@ -2264,7 +2233,8 @@ cdef class Scanner(_Weakrefable):
     def get_fragments(self):
         """Returns an iterator over the fragments in this scan.
        """
-        cdef CFragmentIterator c_fragments = self.scanner.GetFragments()
+        cdef CFragmentIterator c_fragments = move(GetResultValue(
+            self.scanner.GetFragments()))
 
         for maybe_fragment in c_fragments:
             yield Fragment.wrap(GetResultValue(move(maybe_fragment)))
@@ -2284,12 +2254,15 @@ def _get_partition_keys(Expression partition_expression):
     """
     cdef:
         CExpression expr = partition_expression.unwrap()
-        pair[c_string, shared_ptr[CScalar]] name_val
-
-    return {
-        frombytes(name_val.first): pyarrow_wrap_scalar(name_val.second).as_py()
-        for name_val in GetResultValue(CGetPartitionKeys(expr))
-    }
+        pair[CFieldRef, CDatum] name_val
+
+    out = {}
+    for ref_val in GetResultValue(CExtractKnownFieldValues(expr)):
+        assert ref_val.first.name() != nullptr
+        assert ref_val.second.kind() == DatumType_SCALAR
+        val = pyarrow_wrap_scalar(name_val.second.scalar()).as_py()
+        out[deref(ref_val.first.name())] = val
+    return out
 
 
 def _filesystemdataset_write(
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index ddfd8057db2..ddd77174d13 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -62,7 +62,7 @@ def _get_arg_names(func):
         arg_names = ["left", "right"]
     else:
         raise NotImplementedError(
-            "unsupported arity: {}".format(func.arity))
+            f"unsupported arity: {func.arity} (function: {func.name})")
 
     return arg_names
 
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 45a39061919..aaee9fb5351 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -405,6 +405,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         CFieldRef()
         CFieldRef(c_string name)
         CFieldRef(int index)
+        const c_string* name() const
+
+    cdef cppclass CFieldRefHash" arrow::FieldRef::Hash":
+        pass
 
     cdef cppclass CStructType" arrow::StructType"(CDataType):
         CStructType(const vector[shared_ptr[CField]]& fields)
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index 53e48da1703..98e6c20bf23 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -50,6 +50,11 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
             vector[CExpression] arguments,
             shared_ptr[CFunctionOptions] options)
 
+    cdef CResult[shared_ptr[CBuffer]] CSerializeExpression \
+        "arrow::dataset::Serialize"(const CExpression&)
+    cdef CResult[CExpression] CDeserializeExpression \
+        "arrow::dataset::Deserialize"(const CBuffer&)
+
     cdef cppclass CRecordBatchProjector "arrow::dataset::RecordBatchProjector":
         pass
 
@@ -95,7 +100,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
                        shared_ptr[CScanContext])
         CResult[CScanTaskIterator] Scan()
         CResult[shared_ptr[CTable]] ToTable()
-        CFragmentIterator GetFragments()
+        CResult[CFragmentIterator] GetFragments()
         const shared_ptr[CScanOptions]& options()
 
     cdef cppclass CScannerBuilder "arrow::dataset::ScannerBuilder":
@@ -115,8 +120,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
 
     cdef cppclass CDataset "arrow::dataset::Dataset":
         const shared_ptr[CSchema] & schema()
-        CFragmentIterator GetFragments()
-        CFragmentIterator GetFragments(CExpression predicate)
+        CResult[CFragmentIterator] GetFragments()
+        CResult[CFragmentIterator] GetFragments(CExpression predicate)
         const CExpression & partition_expression()
         c_string type_name()
@@ -301,9 +306,9 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
             const CExpression& partition_expression,
             CRecordBatchProjector* projector)
 
-    cdef CResult[unordered_map[c_string, shared_ptr[CScalar]]] \
-        CGetPartitionKeys "arrow::dataset::KeyValuePartitioning::GetKeys"(
-            const CExpression& partition_expression)
+    cdef CResult[unordered_map[CFieldRef, CDatum, CFieldRefHash]] \
+        CExtractKnownFieldValues "arrow::dataset::ExtractKnownFieldValues"(
+            const CExpression& partition_expression)
 
    cdef cppclass CFileSystemFactoryOptions \
            "arrow::dataset::FileSystemFactoryOptions":
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 3bda5692128..a783c3b4273 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -408,19 +408,6 @@ def test_expression_serialization():
     f = ds.scalar({'a': 1})
     g = ds.scalar(pa.scalar(1))
 
-    condition = ds.field('i64') > 5
-    schema = pa.schema([
-        pa.field('i64', pa.int64()),
-        pa.field('f64', pa.float64())
-    ])
-    assert condition.validate(schema) == pa.bool_()
-
-    assert condition.assume(ds.field('i64') == 5).equals(
-        ds.scalar(False))
-
-    assert condition.assume(ds.field('i64') == 7).equals(
-        ds.scalar(True))
-
     all_exprs = [a, b, c, d, e, f, g, a == b, a > b, a & b, a | b, ~c,
                  d.is_valid(), a.cast(pa.int32(), safe=False),
                  a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),

From 44017b9a489b1c6776c6a586a5825b7b79c08324 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Thu, 17 Dec 2020 13:00:57 -0500
Subject: [PATCH 18/38] fix doc generation for struct function

---
 cpp/src/arrow/compute/cast.cc        |  2 +-
 cpp/src/arrow/compute/function.cc    | 11 ++++++++---
 python/pyarrow/compute.py            | 10 ++++++----
 python/pyarrow/tests/test_compute.py | 11 ++++++++---
 4 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
index b55dfb38c15..5c332aedf73 100644
--- a/cpp/src/arrow/compute/cast.cc
+++ b/cpp/src/arrow/compute/cast.cc
@@ -121,7 +121,7 @@ class CastMetaFunction : public MetaFunction {
 const FunctionDoc struct_doc{"Wrap Arrays into a StructArray",
                              ("Names of the StructArray's fields are\n"
                               "specified through StructOptions."),
-                             {},
+                             {"*args"},
                              "StructOptions"};
 
 Result<ValueDescr> StructResolve(KernelContext* ctx,
diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc
index ab6766b1cdf..7a868853db7 100644
--- a/cpp/src/arrow/compute/function.cc
+++ b/cpp/src/arrow/compute/function.cc
@@ -152,10 +152,15 @@ Result<Datum> Function::Execute(const std::vector<Datum>& args,
 Status Function::Validate() const {
   if (!doc_->summary.empty()) {
     // Documentation given, check its contents
-    if (static_cast<int>(doc_->arg_names.size()) != arity_.num_args) {
-      return Status::Invalid("In function '", name_,
-                             "': ", "number of argument names != function arity");
+    int arg_count = static_cast<int>(doc_->arg_names.size());
+    if (arg_count == arity_.num_args) {
+      return Status::OK();
     }
+    if (arity_.is_varargs && arg_count == arity_.num_args + 1) {
+      return Status::OK();
+    }
+    return Status::Invalid("In function '", name_,
+                           "': ", "number of argument names != function arity");
   }
   return Status::OK();
 }
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index ddd77174d13..8650a38345b 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -116,7 +116,7 @@ def _decorate_compute_function(wrapper, exposed_name, func, option_class):
         doc_pieces.append("""\
 options : pyarrow.compute.{0}, optional
     Parameters altering compute function semantics
-**kwargs: optional
+**kwargs : optional
     Parameters for {0} constructor. Either `options`
     or `**kwargs` can be passed, but not both at the same time.
 """.format(option_class.__name__))
@@ -160,14 +160,14 @@ def _handle_options(name, option_class, options, kwargs):
 
 _wrapper_template = dedent("""\
     def make_wrapper(func, option_class):
-        def {func_name}({args_sig}, *, memory_pool=None):
+        def {func_name}({args_sig}{kwonly}, memory_pool=None):
             return func.call([{args_sig}], None, memory_pool)
        return {func_name}
    """)
 
 _wrapper_options_template = dedent("""\
    def make_wrapper(func, option_class):
-        def {func_name}({args_sig}, *, options=None, memory_pool=None,
+        def {func_name}({args_sig}{kwonly}, options=None, memory_pool=None,
                        **kwargs):
            options = _handle_options({func_name!r}, option_class, options,
                                      kwargs)
@@ -180,6 +180,7 @@ def _wrap_function(name, func):
    option_class = _get_options_class(func)
    arg_names = _get_arg_names(func)
    args_sig = ', '.join(arg_names)
+    kwonly = '' if arg_names[-1].startswith('*') else ', *'
 
    # Generate templated wrapper, so that the signature matches
    # the documented argument names.
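    # For illustration (an assumed example, not part of the templates above):
    # a binary function documented with arg_names ["left", "right"] still gets
    # a keyword-only memory_pool,
    #
    #     def add(left, right, *, memory_pool=None): ...
    #
    # while a varargs function like "struct", documented as ["*args"], must not
    # insert the bare "*" marker, since "def struct(*args, *, ...)" would be a
    # syntax error:
    #
    #     def struct(*args, memory_pool=None): ...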
@@ -188,7 +189,8 @@ def _wrap_function(name, func):
        template = _wrapper_options_template
    else:
        template = _wrapper_template
-    exec(template.format(func_name=name, args_sig=args_sig), globals(), ns)
+    exec(template.format(func_name=name, args_sig=args_sig, kwonly=kwonly),
+         globals(), ns)
    wrapper = ns['make_wrapper'](func, option_class)
    return _decorate_compute_function(wrapper, name, func, option_class)
 
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 30555a6ebb5..57cfe2d25ea 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -83,7 +83,11 @@ def test_exported_functions():
    functions = exported_functions
    assert len(functions) >= 10
    for func in functions:
-        args = [object()] * func.__arrow_compute_function__['arity']
+        arity = func.__arrow_compute_function__['arity']
+        if arity is Ellipsis:
+            args = [object()] * 3
+        else:
+            args = [object()] * arity
        with pytest.raises(TypeError,
                           match="Got unexpected argument type "
                                 "<class 'object'> for compute function"):
@@ -172,7 +176,8 @@ def test_function_attributes():
        kernels = func.kernels
        assert func.num_kernels == len(kernels)
        assert all(isinstance(ker, pc.Kernel) for ker in kernels)
-        assert func.arity >= 1  # no varargs functions for now
+        if func.arity is not Ellipsis:
+            assert func.arity >= 1
        repr(func)
        for ker in kernels:
            repr(ker)
@@ -402,7 +407,7 @@ def test_generated_docstrings():
            If not passed, will allocate memory from the default memory pool.
        options : pyarrow.compute.MinMaxOptions, optional
            Parameters altering compute function semantics
-        **kwargs: optional
+        **kwargs : optional
            Parameters for MinMaxOptions constructor. Either `options`
            or `**kwargs` can be passed, but not both at the same time.
        """)

From 13aca8170bf4767680e9e7a4288dc0a32615a66e Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Thu, 17 Dec 2020 21:37:55 -0500
Subject: [PATCH 19/38] review comments

---
 .../compute/kernels/scalar_cast_internal.cc   | 17 +++----
 .../compute/kernels/scalar_cast_temporal.cc   |  6 ++-
 .../arrow/compute/kernels/util_internal.cc    | 13 +++--
 cpp/src/arrow/compute/type_fwd.h              |  6 +++
 cpp/src/arrow/dataset/expression.cc           |  8 +++
 cpp/src/arrow/dataset/expression.h            |  9 ++--
 cpp/src/arrow/dataset/expression_internal.h   | 29 ++++++-----
 cpp/src/arrow/dataset/expression_test.cc      |  4 +-
 cpp/src/arrow/dataset/partition.cc            |  1 +
 cpp/src/arrow/dataset/partition_test.cc       |  1 +
 cpp/src/arrow/dataset/scanner_test.cc         | 36 +++++--------
 cpp/src/arrow/datum.cc                        | 18 +++++--
 cpp/src/arrow/datum.h                         |  5 ++
 cpp/src/arrow/type.cc                         | 51 -------------------
 cpp/src/arrow/type.h                          |  5 --
 15 files changed, 87 insertions(+), 122 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
index 7abf1c618da..f8dde20e3aa 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
@@ -149,16 +149,10 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Dat
 // ----------------------------------------------------------------------
 
 void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
-  auto Finish = [&](Result<Datum> result) {
-    if (!result.ok()) {
-      ctx->SetStatus(result.status());
-      return;
-    }
-    *out = *result;
-  };
-
   if (out->is_scalar()) {
-    return Finish(batch[0].scalar_as<DictionaryScalar>().GetEncodedValue());
+    KERNEL_ASSIGN_OR_RAISE(*out, ctx,
+                           batch[0].scalar_as<DictionaryScalar>().GetEncodedValue());
+    return;
   }
 
   DictionaryArray dict_arr(batch[0].array());
@@ -172,8 +166,9 @@ void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
     return;
   }
 
-  return Finish(Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()),
-                     /*options=*/TakeOptions::Defaults(), ctx->exec_context()));
+  KERNEL_ASSIGN_OR_RAISE(*out, ctx,
+                         Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()),
+                              TakeOptions::Defaults(), ctx->exec_context()));
 }
 
 void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
index f1702b7fd51..99c1401d1b8 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
@@ -284,9 +284,11 @@ struct CastFunctor<TimestampType, Date64Type> {
     DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
 
     const auto& out_type = checked_cast<const TimestampType&>(*out->type());
+
+    // date64 is ms since epoch
    auto conversion = util::GetTimestampConversion(TimeUnit::MILLI, out_type.unit());
-    ShiftTime<int64_t, int64_t>(ctx, util::MULTIPLY, conversion.second, *batch[0].array(),
-                                out->mutable_array());
+    ShiftTime<int64_t, int64_t>(ctx, conversion.first, conversion.second,
+                                *batch[0].array(), out->mutable_array());
  }
 };
 
diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc
index 21b6b706a50..3d21f5b1494 100644
--- a/cpp/src/arrow/compute/kernels/util_internal.cc
+++ b/cpp/src/arrow/compute/kernels/util_internal.cc
@@ -68,15 +68,14 @@ ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec) {
      return;
    }
 
-    Datum array_in, array_out;
-    KERNEL_RETURN_IF_ERROR(
-        ctx, MakeArrayFromScalar(*batch[0].scalar(), 1).As<Datum>().Value(&array_in));
-    KERNEL_RETURN_IF_ERROR(
-        ctx, MakeArrayFromScalar(*out->scalar(), 1).As<Datum>().Value(&array_out));
+    KERNEL_ASSIGN_OR_RAISE(Datum array_in, ctx,
+                           MakeArrayFromScalar(*batch[0].scalar(), 1));
+
+    KERNEL_ASSIGN_OR_RAISE(Datum array_out, ctx, MakeArrayFromScalar(*out->scalar(), 1));
 
    exec(ctx, ExecBatch{{std::move(array_in)}, 1}, &array_out);
-    KERNEL_RETURN_IF_ERROR(ctx,
-                           array_out.make_array()->GetScalar(0).As<Datum>().Value(out));
+
+    KERNEL_ASSIGN_OR_RAISE(*out, ctx, array_out.make_array()->GetScalar(0));
  };
 }
 
diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h
index 4b05ceccb5a..9888e610aa7 100644
--- a/cpp/src/arrow/compute/type_fwd.h
+++ b/cpp/src/arrow/compute/type_fwd.h
@@ -20,9 +20,13 @@
 namespace arrow {
 
 struct Datum;
+struct ValueDescr;
 
 namespace compute {
 
+class Function;
+struct FunctionOptions;
+
 struct CastOptions;
 
 class ExecContext;
@@ -33,5 +37,7 @@ struct ScalarKernel;
 struct ScalarAggregateKernel;
 struct VectorKernel;
 
+struct KernelState;
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc
index 4c5b693ffb9..637010cfed2 100644
--- a/cpp/src/arrow/dataset/expression.cc
+++ b/cpp/src/arrow/dataset/expression.cc
@@ -26,6 +26,7 @@
 #include "arrow/io/memory.h"
 #include "arrow/ipc/reader.h"
 #include "arrow/ipc/writer.h"
+#include "arrow/util/atomic_shared_ptr.h"
 #include "arrow/util/key_value_metadata.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/optional.h"
@@ -279,11 +280,18 @@ size_t Expression::hash() const {
  }
 
  auto call = CallNotNull(*this);
+  if (call->hash != nullptr) {
+    return call->hash->load();
+  }
 
  size_t out = std::hash<std::string>{}(call->function_name);
  for (const auto& arg : call->arguments) {
    out ^= arg.hash();
  }
+
+  std::shared_ptr<std::atomic<size_t>> expected = nullptr;
+  internal::atomic_compare_exchange_strong(&const_cast<Call*>(call)->hash, &expected,
+                                           std::make_shared<std::atomic<size_t>>(out));
  return out;
 }
 
diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h
index 6137666abff..f625ca694e1 100644
--- a/cpp/src/arrow/dataset/expression.h
+++ b/cpp/src/arrow/dataset/expression.h
@@ -19,21 +19,17 @@
 
 #pragma once
 
-#include <functional>
+#include <atomic>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <utility>
 #include <vector>
 
-#include "arrow/chunked_array.h"
-#include "arrow/compute/api_scalar.h"
-#include "arrow/compute/cast.h"
+#include "arrow/compute/type_fwd.h"
 #include "arrow/dataset/type_fwd.h"
 #include "arrow/dataset/visibility.h"
 #include "arrow/datum.h"
-#include "arrow/result.h"
-#include "arrow/scalar.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/variant.h"
 
@@ -51,6 +47,7 @@ class ARROW_DS_EXPORT Expression {
    std::string function_name;
    std::vector<Expression> arguments;
    std::shared_ptr<compute::FunctionOptions> options;
+    std::shared_ptr<std::atomic<size_t>> hash;
 
    // post-Bind properties:
    const compute::Kernel* kernel = NULLPTR;
diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h
index eba972dc18b..77e555aa7a6 100644
--- a/cpp/src/arrow/dataset/expression_internal.h
+++ b/cpp/src/arrow/dataset/expression_internal.h
@@ -21,7 +21,8 @@
 #include <unordered_map>
 #include <vector>
 
-#include "arrow/compute/api_vector.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
 #include "arrow/compute/registry.h"
 #include "arrow/record_batch.h"
 #include "arrow/table.h"
@@ -92,11 +93,11 @@ inline Result<Datum> GetDatumField(const FieldRef& ref, const Datum& input) {
  FieldPath path;
 
  if (auto type = input.type()) {
-    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.type()));
-  } else if (input.kind() == Datum::RECORD_BATCH) {
-    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.record_batch()->schema()));
-  } else if (input.kind() == Datum::TABLE) {
-    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*input.table()->schema()));
+    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*type));
+  } else if (auto schema = input.schema()) {
+    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*schema));
+  } else {
+    return Status::NotImplemented("retrieving fields from datum ", input.ToString());
  }
 
  if (path) {
@@ -123,14 +124,14 @@ struct Comparison {
  };
 
  static const type* Get(const std::string& function) {
-    static std::unordered_map<std::string, type> flipped_comparisons{
+    static std::unordered_map<std::string, type> map{
        {"equal", EQUAL},     {"not_equal", NOT_EQUAL},
        {"less", LESS},       {"less_equal", LESS_EQUAL},
        {"greater", GREATER}, {"greater_equal", GREATER_EQUAL},
    };
 
-    auto it = flipped_comparisons.find(function);
-    return it != flipped_comparisons.end() ? &it->second : nullptr;
+    auto it = map.find(function);
+    return it != map.end() ? &it->second : nullptr;
  }
 
  static const type* Get(const Expression& expr) {
@@ -262,13 +263,12 @@ inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call
   return checked_cast<const compute::StrptimeOptions*>(call.options.get());
 }
 
-inline const std::shared_ptr<DataType>& GetDictionaryValueType(
+inline std::shared_ptr<DataType> GetDictionaryValueType(
     const std::shared_ptr<DataType>& type) {
   if (type && type->id() == Type::DICTIONARY) {
     return checked_cast<const DictionaryType&>(*type).value_type();
   }
-  static std::shared_ptr<DataType> null;
-  return null;
+  return nullptr;
 }
 
 inline Status EnsureNotDictionary(ValueDescr* descr) {
@@ -410,7 +410,10 @@ struct FlattenedAssociativeChain {
       exprs.push_back(std::move(*it));
 
       it = fringe.erase(it);
-      it = fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end());
+
+      auto index = it - fringe.begin();
+      fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end());
+      it = fringe.begin() + index;
       // NB: no increment so we hit sub_call's first argument next iteration
     }
diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc
index 3b286e378d0..175e045918d 100644
--- a/cpp/src/arrow/dataset/expression_test.cc
+++ b/cpp/src/arrow/dataset/expression_test.cc
@@ -558,7 +558,7 @@ TEST(Expression, FoldConstantsBoolean) {
   // test and_kleene/or_kleene-specific optimizations
   auto one = literal(1);
   auto two = literal(2);
-  auto whatever = call("equal", {call("add", {one, field_ref("i32")}), two});
+  auto whatever = equal(call("add", {one, field_ref("i32")}), two);
 
   auto true_ = literal(true);
   auto false_ = literal(false);
@@ -696,7 +696,7 @@ TEST(Expression, CanonicalizeAnd) {
   auto null_ = literal(std::make_shared<BooleanScalar>());
 
   auto b = field_ref("bool");
-  auto c = call("equal", {literal(1), literal(2)});
+  auto c = equal(literal(1), literal(2));
 
   // no change possible:
   ExpectCanonicalizesTo(and_(b, c), and_(b, c));
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index ba4f0c46b8f..2822c4a15b6 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -27,6 +27,7 @@
 #include "arrow/array/builder_dict.h"
 #include "arrow/compute/api_scalar.h"
 #include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
 #include "arrow/dataset/dataset_internal.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/scalar.h"
diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc
index 8610ff1e891..2260eb219da 100644
--- a/cpp/src/arrow/dataset/partition_test.cc
+++ b/cpp/src/arrow/dataset/partition_test.cc
@@ -26,6 +26,7 @@
 #include <string>
 #include <vector>
 
+#include "arrow/compute/api_scalar.h"
 #include "arrow/dataset/scanner_internal.h"
 #include "arrow/dataset/test_util.h"
 #include "arrow/filesystem/path_util.h"
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 7261bdb8037..f8f959e3a28 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -184,25 +184,17 @@ TEST_F(TestScannerBuilder, TestFilter) {
   ScannerBuilder builder(dataset_, ctx_);
 
   ASSERT_OK(builder.Filter(literal(true)));
-  ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)})));
-  ASSERT_OK(builder.Filter(
-      call("or_kleene", {
-                            call("equal", {field_ref("i64"), literal(10)}),
-                            call("equal", {field_ref("b"), literal(true)}),
-                        })));
-
-  ASSERT_OK(builder.Filter(call("equal", {field_ref("i64"), literal(10)})));
-
-  ASSERT_RAISES(
-      Invalid, builder.Filter(call("equal", {field_ref("not_a_column"), literal(true)})));
-
-  ASSERT_RAISES(
-      Invalid,
-      builder.Filter(
-          call("or_kleene", {
-                                call("equal", {field_ref("i64"), literal(10)}),
-                                call("equal", {field_ref("not_a_column"), literal(true)}),
-                            })));
+  ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal(10))));
+  ASSERT_OK(builder.Filter(or_(equal(field_ref("i64"), literal(10)),
+                               equal(field_ref("b"), literal(true)))));
+
+  ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal(10))));
+
+  ASSERT_RAISES(Invalid, builder.Filter(equal(field_ref("not_a_column"), literal(true))));
+
+  ASSERT_RAISES(Invalid,
+                builder.Filter(or_(equal(field_ref("i64"), literal(10)),
+                                   equal(field_ref("not_a_column"), literal(true)))));
 }
 
 using testing::ElementsAre;
@@ -215,7 +207,7 @@ TEST(ScanOptions, TestMaterializedFields) {
   auto opts = ScanOptions::Make(schema({}));
   EXPECT_THAT(opts->MaterializedFields(), IsEmpty());
 
-  opts->filter2 = call("equal", {field_ref("i32"), literal(10)});
+  opts->filter2 = equal(field_ref("i32"), literal(10));
   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32"));
 
   opts = ScanOptions::Make(schema({i32, i64}));
@@ -224,10 +216,10 @@ TEST(ScanOptions, TestMaterializedFields) {
   opts = opts->ReplaceSchema(schema({i32}));
   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32"));
 
-  opts->filter2 = call("equal", {field_ref("i32"), literal(10)});
+  opts->filter2 = equal(field_ref("i32"), literal(10));
   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32"));
 
-  opts->filter2 = call("equal", {field_ref("i64"), literal(10)});
+  opts->filter2 = equal(field_ref("i64"), literal(10));
   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64"));
 }
 
diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc
index 0f2fa8ebe90..786110996dc 100644
--- a/cpp/src/arrow/datum.cc
+++ b/cpp/src/arrow/datum.cc
@@ -89,12 +89,24 @@ std::shared_ptr<Array> Datum::make_array() const {
 std::shared_ptr<DataType> Datum::type() const {
   if (this->kind() == Datum::ARRAY) {
     return util::get<std::shared_ptr<ArrayData>>(this->value)->type;
-  } else if (this->kind() == Datum::CHUNKED_ARRAY) {
+  }
+  if (this->kind() == Datum::CHUNKED_ARRAY) {
     return util::get<std::shared_ptr<ChunkedArray>>(this->value)->type();
-  } else if (this->kind() == Datum::SCALAR) {
+  }
+  if (this->kind() == Datum::SCALAR) {
     return util::get<std::shared_ptr<Scalar>>(this->value)->type;
   }
-  return NULLPTR;
+  return nullptr;
+}
+
+std::shared_ptr<Schema> Datum::schema() const {
+  if (this->kind() == Datum::RECORD_BATCH) {
+    return util::get<std::shared_ptr<RecordBatch>>(this->value)->schema();
+  }
+  if (this->kind() == Datum::TABLE) {
+    return util::get<std::shared_ptr<Table>>(this->value)->schema();
+  }
+  return nullptr;
 }
 
 int64_t Datum::length() const {
diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h
index 8479f7ad366..fb783ea5261 100644
--- a/cpp/src/arrow/datum.h
+++ b/cpp/src/arrow/datum.h
@@ -252,6 +252,11 @@ struct ARROW_EXPORT Datum {
   /// \return nullptr if no type
   std::shared_ptr<DataType> type() const;
 
+  /// \brief The schema of the variant, if any
+  ///
+  /// \return nullptr if no schema
+  std::shared_ptr<Schema> schema() const;
+
   /// \brief The value length of the variant, if any
   ///
   /// \return kUnknownLength if no type
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 20dc0f14ca0..ad889b3eb24 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -30,12 +30,10 @@
 #include <vector>
 
 #include "arrow/array.h"
-#include "arrow/chunked_array.h"
 #include "arrow/compare.h"
 #include "arrow/record_batch.h"
 #include "arrow/result.h"
 #include "arrow/status.h"
-#include "arrow/table.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/hash_util.h"
 #include "arrow/util/hashing.h"
@@ -900,8 +898,6 @@ std::string FieldPath::ToString() const {
 struct FieldPathGetImpl {
   static const DataType& GetType(const ArrayData& data) { return *data.type; }
 
-  static const DataType& GetType(const ChunkedArray& array) { return *array.type(); }
-
   static void Summarize(const FieldVector& fields, std::stringstream* ss) {
     *ss << "{ ";
     for (const auto& field : fields) {
@@ -996,32 +992,6 @@ struct FieldPathGetImpl {
        path, &child_data, [](const std::shared_ptr<ArrayData>& data) {
          return &data->child_data;
        });
  }
-
-  static Result<std::shared_ptr<ChunkedArray>> Get(
-      const FieldPath* path, const ChunkedArrayVector& columns_arg) {
-    ChunkedArrayVector columns = columns_arg;
-
-    return FieldPathGetImpl::Get(
-        path, &columns, [&](const std::shared_ptr<ChunkedArray>& a) {
-          columns.clear();
-
-          for (int i = 0; i < a->type()->num_fields(); ++i) {
-            ArrayVector child_chunks;
-
-            for (const auto& chunk : a->chunks()) {
-              auto child_chunk = MakeArray(chunk->data()->child_data[i]);
-              child_chunks.push_back(std::move(child_chunk));
-            }
-
-            auto child_column = std::make_shared<ChunkedArray>(
-                std::move(child_chunks), a->type()->field(i)->type());
-
-            columns.emplace_back(std::move(child_column));
-          }
-
-          return &columns;
-        });
-  }
 };
 
 Result<std::shared_ptr<Field>> FieldPath::Get(const Schema& schema) const {
@@ -1045,10 +1015,6 @@ Result<std::shared_ptr<Array>> FieldPath::Get(const RecordBatch& batch) const {
   return MakeArray(std::move(data));
 }
 
-Result<std::shared_ptr<ChunkedArray>> FieldPath::Get(const Table& table) const {
-  return FieldPathGetImpl::Get(this, table.columns());
-}
-
 Result<std::shared_ptr<Array>> FieldPath::Get(const Array& array) const {
   ARROW_ASSIGN_OR_RAISE(auto data, Get(*array.data()));
   return MakeArray(std::move(data));
 }
@@ -1058,15 +1024,6 @@ Result<std::shared_ptr<Array>> FieldPath::Get(const Array& array) const {
   return FieldPathGetImpl::Get(this, data.child_data);
 }
 
-Result<std::shared_ptr<ChunkedArray>> FieldPath::Get(const ChunkedArray& array) const {
-  FieldPath prefixed_with_0 = *this;
-  prefixed_with_0.indices_.insert(prefixed_with_0.indices_.begin(), 0);
-
-  ChunkedArrayVector vec;
-  vec.emplace_back(const_cast<ChunkedArray*>(&array), [](...) {});
-  return FieldPathGetImpl::Get(&prefixed_with_0, vec);
-}
-
 FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {
   DCHECK_GT(util::get<FieldPath>(impl_).indices().size(), 0);
 }
 
@@ -1320,18 +1277,10 @@ std::vector<FieldPath> FieldRef::FindAll(const Array& array) const {
   return FindAll(*array.type());
 }
 
-std::vector<FieldPath> FieldRef::FindAll(const ChunkedArray& array) const {
-  return FindAll(*array.type());
-}
-
 std::vector<FieldPath> FieldRef::FindAll(const RecordBatch& batch) const {
   return FindAll(*batch.schema());
 }
 
-std::vector<FieldPath> FieldRef::FindAll(const Table& table) const {
-  return FindAll(*table.schema());
-}
-
 void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); }
 
 // ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index cadb819b9f2..f0fa04f40be 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1424,15 +1424,10 @@ class ARROW_EXPORT FieldPath {
 
   /// \brief Retrieve the referenced column from a RecordBatch or Table
   Result<std::shared_ptr<Array>> Get(const RecordBatch& batch) const;
-  Result<std::shared_ptr<ChunkedArray>> Get(const Table& table) const;
 
   /// \brief Retrieve the referenced child from an Array, ArrayData, or ChunkedArray
   Result<std::shared_ptr<Array>> Get(const Array& array) const;
   Result<std::shared_ptr<ArrayData>> Get(const ArrayData& data) const;
-  Result<std::shared_ptr<ChunkedArray>> Get(const ChunkedArray& array) const;
-
-  /// \brief Retrieve the reference child from a Datum
-  Result<Datum> Get(const Datum& datum) const;
 
  private:
   std::vector<int> indices_;

From a88685e541819fb5510c600a44b0e19aba80c1a0 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Fri, 18 Dec 2020 10:33:49 -0500
Subject: [PATCH 20/38] repair linkage on Expression friends

---
 cpp/src/arrow/dataset/expression.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h
index f625ca694e1..4b414d4d142 100644
--- a/cpp/src/arrow/dataset/expression.h
+++ b/cpp/src/arrow/dataset/expression.h
@@ -117,9 +117,9 @@ class ARROW_DS_EXPORT Expression {
   using Impl = util::Variant<Datum, Parameter, Call>;
   std::shared_ptr<Impl> impl_;
 
-  ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r);
+  ARROW_DS_EXPORT friend bool Identical(const Expression& l, const Expression& r);
 
-  ARROW_EXPORT friend void PrintTo(const Expression&, std::ostream*);
+  ARROW_DS_EXPORT friend void PrintTo(const Expression&, std::ostream*);
 };
 
 inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); }

From 9be418b57153831a0be97de99e6743efc768e80d Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Fri, 18 Dec 2020 11:32:45 -0500
Subject: [PATCH 21/38] centos-7 doesn't recognize this initializer list

---
 cpp/src/arrow/dataset/expression.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc
index 637010cfed2..848055ac5aa 100644
--- a/cpp/src/arrow/dataset/expression.cc
+++ b/cpp/src/arrow/dataset/expression.cc
@@ -389,9 +389,10 @@ Result<std::shared_ptr<compute::KernelState>> InitKernelState(
   if (!call.kernel->init) return nullptr;
 
   compute::KernelContext kernel_context(exec_context);
-  auto kernel_state = call.kernel->init(
-      &kernel_context, {call.kernel, GetDescriptors(call.arguments), call.options.get()});
+  compute::KernelInitArgs kernel_init_args{call.kernel, GetDescriptors(call.arguments),
+                                           call.options.get()};
 
+  auto kernel_state = call.kernel->init(&kernel_context, kernel_init_args);
   RETURN_NOT_OK(kernel_context.status());
   return std::move(kernel_state);
 }

From 82c32defb2a9f0188c7c68198daac326f119b1fe Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Fri, 18 Dec 2020 11:33:15 -0500
Subject: [PATCH 22/38] get test_{dataset,parquet}.py passing

---
 .../compute/kernels/scalar_cast_string.cc     |   12 +-
 python/pyarrow/_dataset.pyx                   |    6 +-
 python/pyarrow/includes/libarrow.pxd          |   13 +-
 python/pyarrow/tests/test_dataset.py          |    4 +-
 python/pyarrow/tests/test_parquet.py          | 4528 +++++++++++++++++
 5 files changed, 4549 insertions(+), 14 deletions(-)
 create mode 100644 python/pyarrow/tests/test_parquet.py

diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
index 9d3812e5967..ca9c33a8f1b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
@@ -49,6 +49,7 @@ struct NumericToStringCastFunctor {
   using FormatterType = StringFormatter<I>;
 
   static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    DCHECK(out->is_array());
     const ArrayData& input = *batch[0].array();
     ArrayData* output = out->mutable_array();
     ctx->SetStatus(Convert(ctx, input, output));
@@ -160,6 +161,7 @@ struct BinaryToBinaryCastFunctor {
   using output_offset_type = typename O::offset_type;
 
   static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    DCHECK(out->is_array());
     const CastOptions& options = checked_cast<const CastState&>(*ctx->state()).options;
     const ArrayData& input = *batch[0].array();
@@ -194,7 +196,8 @@ void AddNumberToStringCasts(CastFunction* func) {
   auto out_ty = TypeTraits<OutType>::type_singleton();
 
   DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty,
-                            NumericToStringCastFunctor<OutType, BooleanType>::Exec,
+                            TrivialScalarUnaryAsArraysExec(
+                                NumericToStringCastFunctor<OutType, BooleanType>::Exec),
                            NullHandling::COMPUTED_NO_PREALLOCATE));
 
  for (const std::shared_ptr<DataType>& in_ty : NumericTypes()) {
@@ -210,9 +213,10 @@ void AddBinaryToBinaryCast(CastFunction* func) {
   auto in_ty = TypeTraits<InType>::type_singleton();
   auto out_ty = TypeTraits<OutType>::type_singleton();
 
-  DCHECK_OK(func->AddKernel(OutType::type_id, {in_ty}, out_ty,
-                            BinaryToBinaryCastFunctor<OutType, InType>::Exec,
-                            NullHandling::COMPUTED_NO_PREALLOCATE));
+  DCHECK_OK(func->AddKernel(
+      OutType::type_id, {in_ty}, out_ty,
+      TrivialScalarUnaryAsArraysExec(BinaryToBinaryCastFunctor<OutType, InType>::Exec),
+      NullHandling::COMPUTED_NO_PREALLOCATE));
 }
 
 }  // namespace
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 7b96e87633b..710e3c9c2c4 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -2254,14 +2254,14 @@ def _get_partition_keys(Expression partition_expression):
     """
     cdef:
         CExpression expr = partition_expression.unwrap()
-        pair[CFieldRef, CDatum] name_val
+        pair[CFieldRef, CDatum] ref_val
 
     out = {}
     for ref_val in GetResultValue(CExtractKnownFieldValues(expr)):
         assert ref_val.first.name() != nullptr
         assert ref_val.second.kind() == DatumType_SCALAR
-        val = pyarrow_wrap_scalar(name_val.second.scalar()).as_py()
-        out[deref(ref_val.first.name())] = val
+        val = pyarrow_wrap_scalar(ref_val.second.scalar())
+        out[frombytes(deref(ref_val.first.name()))] = val.as_py()
     return out
 
 
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index aaee9fb5351..fcdf8ed4179 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1830,13 +1830,14 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
         CDatum(const shared_ptr[CRecordBatch]& value)
         CDatum(const shared_ptr[CTable]& value)
 
-        DatumType kind()
+        DatumType kind() const
+        c_string ToString() const
 
-        shared_ptr[CArrayData] array()
-        shared_ptr[CChunkedArray] chunked_array()
-        shared_ptr[CRecordBatch] record_batch()
-        shared_ptr[CTable] table()
-        shared_ptr[CScalar] scalar()
+        const shared_ptr[CArrayData]& array() const
+        const shared_ptr[CChunkedArray]& chunked_array() const
+        const shared_ptr[CRecordBatch]& record_batch() const
+        const shared_ptr[CTable]& table() const
+        const shared_ptr[CScalar]& scalar() const
 
     cdef cppclass CSetLookupOptions \
             "arrow::compute::SetLookupOptions"(CFunctionOptions):
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index a783c3b4273..0ab9d95398d 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -768,7 +768,9 @@ def assert_yields_projected(fragment, row_slice,
 
     # Fragments don't contain the partition's columns if not provided to the
     # `to_table(schema=...)` method.
-    with pytest.raises(ValueError, match="Field named 'part' not found"):
+    pattern = (r'No match for FieldRef.Name\(part\) in ' +
+               fragment.physical_schema.to_string(False, False, False))
+    with pytest.raises(ValueError, match=pattern):
         new_fragment = parquet_format.make_fragment(
             fragment.path, fragment.filesystem,
             partition_expression=fragment.partition_expression)
diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py
new file mode 100644
index 00000000000..6fa001cc758
--- /dev/null
+++ b/python/pyarrow/tests/test_parquet.py
@@ -0,0 +1,4528 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import OrderedDict
+import datetime
+import decimal
+from distutils.version import LooseVersion
+import io
+import json
+import os
+import pytest
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.pandas_compat import _pandas_api
+from pyarrow.tests import util
+from pyarrow.util import guid
+from pyarrow.filesystem import LocalFileSystem, FileSystem
+from pyarrow import fs
+
+
+try:
+    import pyarrow.parquet as pq
+except ImportError:
+    pq = None
+
+
+try:
+    import pandas as pd
+    import pandas.testing as tm
+    from .pandas_examples import dataframe_with_arrays, dataframe_with_lists
+except ImportError:
+    pd = tm = None
+
+
+# Marks all of the tests in this module
+# Ignore these with pytest ... -m 'not parquet'
-m 'not parquet' +pytestmark = pytest.mark.parquet + + +@pytest.fixture(scope='module') +def datadir(datadir): + return datadir / 'parquet' + + +parametrize_legacy_dataset = pytest.mark.parametrize( + "use_legacy_dataset", + [True, pytest.param(False, marks=pytest.mark.dataset)]) +parametrize_legacy_dataset_not_supported = pytest.mark.parametrize( + "use_legacy_dataset", [True, pytest.param(False, marks=pytest.mark.skip)]) +parametrize_legacy_dataset_fixed = pytest.mark.parametrize( + "use_legacy_dataset", [pytest.param(True, marks=pytest.mark.xfail), + pytest.param(False, marks=pytest.mark.dataset)]) + + +def _write_table(table, path, **kwargs): + # So we see the ImportError somewhere + import pyarrow.parquet as pq + + if _pandas_api.is_data_frame(table): + table = pa.Table.from_pandas(table) + + pq.write_table(table, path, **kwargs) + return table + + +def _read_table(*args, **kwargs): + table = pq.read_table(*args, **kwargs) + table.validate(full=True) + return table + + +def _roundtrip_table(table, read_table_kwargs=None, + write_table_kwargs=None, use_legacy_dataset=True): + read_table_kwargs = read_table_kwargs or {} + write_table_kwargs = write_table_kwargs or {} + + writer = pa.BufferOutputStream() + _write_table(table, writer, **write_table_kwargs) + reader = pa.BufferReader(writer.getvalue()) + return _read_table(reader, use_legacy_dataset=use_legacy_dataset, + **read_table_kwargs) + + +def _check_roundtrip(table, expected=None, read_table_kwargs=None, + use_legacy_dataset=True, **write_table_kwargs): + if expected is None: + expected = table + + read_table_kwargs = read_table_kwargs or {} + + # intentionally check twice + result = _roundtrip_table(table, read_table_kwargs=read_table_kwargs, + write_table_kwargs=write_table_kwargs, + use_legacy_dataset=use_legacy_dataset) + assert result.equals(expected) + result = _roundtrip_table(result, read_table_kwargs=read_table_kwargs, + write_table_kwargs=write_table_kwargs, + use_legacy_dataset=use_legacy_dataset) + assert result.equals(expected) + + +def _roundtrip_pandas_dataframe(df, write_kwargs, use_legacy_dataset=True): + table = pa.Table.from_pandas(df) + result = _roundtrip_table( + table, write_table_kwargs=write_kwargs, + use_legacy_dataset=use_legacy_dataset) + return result.to_pandas() + + +def test_large_binary(): + data = [b'foo', b'bar'] * 50 + for type in [pa.large_binary(), pa.large_string()]: + arr = pa.array(data, type=type) + table = pa.Table.from_arrays([arr], names=['strs']) + for use_dictionary in [False, True]: + _check_roundtrip(table, use_dictionary=use_dictionary) + + +@pytest.mark.large_memory +def test_large_binary_huge(): + s = b'xy' * 997 + data = [s] * ((1 << 33) // len(s)) + for type in [pa.large_binary(), pa.large_string()]: + arr = pa.array(data, type=type) + table = pa.Table.from_arrays([arr], names=['strs']) + for use_dictionary in [False, True]: + _check_roundtrip(table, use_dictionary=use_dictionary) + del arr, table + + +@pytest.mark.large_memory +def test_large_binary_overflow(): + s = b'x' * (1 << 31) + arr = pa.array([s], type=pa.large_binary()) + table = pa.Table.from_arrays([arr], names=['strs']) + for use_dictionary in [False, True]: + writer = pa.BufferOutputStream() + with pytest.raises( + pa.ArrowInvalid, + match="Parquet cannot store strings with size 2GB or more"): + _write_table(table, writer, use_dictionary=use_dictionary) + + +@parametrize_legacy_dataset +@pytest.mark.parametrize('dtype', [int, float]) +def test_single_pylist_column_roundtrip(tempdir, dtype, use_legacy_dataset): + 
filename = tempdir / 'single_{}_column.parquet'.format(dtype.__name__) + data = [pa.array(list(map(dtype, range(5))))] + table = pa.Table.from_arrays(data, names=['a']) + _write_table(table, filename) + table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset) + for i in range(table.num_columns): + col_written = table[i] + col_read = table_read[i] + assert table.field(i).name == table_read.field(i).name + assert col_read.num_chunks == 1 + data_written = col_written.chunk(0) + data_read = col_read.chunk(0) + assert data_written.equals(data_read) + + +def alltypes_sample(size=10000, seed=0, categorical=False): + np.random.seed(seed) + arrays = { + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16), + 'uint32': np.arange(size, dtype=np.uint32), + 'uint64': np.arange(size, dtype=np.uint64), + 'int8': np.arange(size, dtype=np.int8), + 'int16': np.arange(size, dtype=np.int16), + 'int32': np.arange(size, dtype=np.int32), + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + # TODO(wesm): Test other timestamp resolutions now that arrow supports + # them + 'datetime': np.arange("2016-01-01T00:00:00.001", size, + dtype='datetime64[ms]'), + 'str': pd.Series([str(x) for x in range(size)]), + 'empty_str': [''] * size, + 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None], + 'null': [None] * size, + 'null_list': [None] * 2 + [[None] * (x % 4) for x in range(size - 2)], + } + if categorical: + arrays['str_category'] = arrays['str'].astype('category') + return pd.DataFrame(arrays) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.parametrize('chunk_size', [None, 1000]) +def test_pandas_parquet_2_0_roundtrip(tempdir, chunk_size, use_legacy_dataset): + df = alltypes_sample(size=10000, categorical=True) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + assert arrow_table.schema.pandas_metadata is not None + + _write_table(arrow_table, filename, version="2.0", + coerce_timestamps='ms', chunk_size=chunk_size) + table_read = pq.read_pandas( + filename, use_legacy_dataset=use_legacy_dataset) + assert table_read.schema.pandas_metadata is not None + + read_metadata = table_read.schema.metadata + assert arrow_table.schema.metadata == read_metadata + + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +def test_parquet_invalid_version(tempdir): + table = pa.table({'a': [1, 2, 3]}) + with pytest.raises(ValueError, match="Unsupported Parquet format version"): + _write_table(table, tempdir / 'test_version.parquet', version="2.2") + with pytest.raises(ValueError, match="Unsupported Parquet data page " + + "version"): + _write_table(table, tempdir / 'test_version.parquet', + data_page_version="2.2") + + +@parametrize_legacy_dataset +def test_set_data_page_size(use_legacy_dataset): + arr = pa.array([1, 2, 3] * 100000) + t = pa.Table.from_arrays([arr], names=['f0']) + + # 128K, 512K + page_sizes = [2 << 16, 2 << 18] + for target_page_size in page_sizes: + _check_roundtrip(t, data_page_size=target_page_size, + use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_chunked_table_write(use_legacy_dataset): + # ARROW-232 + tables = [] + batch = pa.RecordBatch.from_pandas(alltypes_sample(size=10)) + tables.append(pa.Table.from_batches([batch] * 3)) + df, _ = dataframe_with_lists() + batch =
pa.RecordBatch.from_pandas(df) + tables.append(pa.Table.from_batches([batch] * 3)) + + for data_page_version in ['1.0', '2.0']: + for use_dictionary in [True, False]: + for table in tables: + _check_roundtrip( + table, version='2.0', + use_legacy_dataset=use_legacy_dataset, + data_page_version=data_page_version, + use_dictionary=use_dictionary) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_memory_map(tempdir, use_legacy_dataset): + df = alltypes_sample(size=10) + + table = pa.Table.from_pandas(df) + _check_roundtrip(table, read_table_kwargs={'memory_map': True}, + version='2.0', use_legacy_dataset=use_legacy_dataset) + + filename = str(tempdir / 'tmp_file') + with open(filename, 'wb') as f: + _write_table(table, f, version='2.0') + table_read = pq.read_pandas(filename, memory_map=True, + use_legacy_dataset=use_legacy_dataset) + assert table_read.equals(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_enable_buffered_stream(tempdir, use_legacy_dataset): + df = alltypes_sample(size=10) + + table = pa.Table.from_pandas(df) + _check_roundtrip(table, read_table_kwargs={'buffer_size': 1025}, + version='2.0', use_legacy_dataset=use_legacy_dataset) + + filename = str(tempdir / 'tmp_file') + with open(filename, 'wb') as f: + _write_table(table, f, version='2.0') + table_read = pq.read_pandas(filename, buffer_size=4096, + use_legacy_dataset=use_legacy_dataset) + assert table_read.equals(table) + + +@parametrize_legacy_dataset +def test_special_chars_filename(tempdir, use_legacy_dataset): + table = pa.Table.from_arrays([pa.array([42])], ["ints"]) + filename = "foo # bar" + path = tempdir / filename + assert not path.exists() + _write_table(table, str(path)) + assert path.exists() + table_read = _read_table(str(path), use_legacy_dataset=use_legacy_dataset) + assert table_read.equals(table) + + +@pytest.mark.slow +def test_file_with_over_int16_max_row_groups(): + # PARQUET-1857: Parquet encryption support introduced an INT16_MAX upper + # limit on the number of row groups, but this limit only impacts files with + # encrypted row group metadata because of the int16 row group ordinal used + # in the Parquet Thrift metadata.
Unencrypted files are not impacted, so + # this test checks that it works (even if it isn't a good idea) + t = pa.table([list(range(40000))], names=['f0']) + _check_roundtrip(t, row_group_size=1) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_empty_table_roundtrip(use_legacy_dataset): + df = alltypes_sample(size=10) + + # Create a non-empty table to infer the types correctly, then slice to 0 + table = pa.Table.from_pandas(df) + table = pa.Table.from_arrays( + [col.chunk(0)[:0] for col in table.itercolumns()], + names=table.schema.names) + + assert table.schema.field('null').type == pa.null() + assert table.schema.field('null_list').type == pa.list_(pa.null()) + _check_roundtrip( + table, version='2.0', use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_empty_table_no_columns(use_legacy_dataset): + df = pd.DataFrame() + empty = pa.Table.from_pandas(df, preserve_index=False) + _check_roundtrip(empty, use_legacy_dataset=use_legacy_dataset) + + +@parametrize_legacy_dataset +def test_empty_lists_table_roundtrip(use_legacy_dataset): + # ARROW-2744: Shouldn't crash when writing an array of empty lists + arr = pa.array([[], []], type=pa.list_(pa.int32())) + table = pa.Table.from_arrays([arr], ["A"]) + _check_roundtrip(table, use_legacy_dataset=use_legacy_dataset) + + +@parametrize_legacy_dataset +def test_nested_list_nonnullable_roundtrip_bug(use_legacy_dataset): + # Reproduce failure in ARROW-5630 + typ = pa.list_(pa.field("item", pa.float32(), False)) + num_rows = 10000 + t = pa.table([ + pa.array(([[0] * ((i + 5) % 10) for i in range(0, 10)] * + (num_rows // 10)), type=typ) + ], ['a']) + _check_roundtrip( + t, data_page_size=4096, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_datetime_tz(use_legacy_dataset): + s = pd.Series([datetime.datetime(2017, 9, 6)]) + s = s.dt.tz_localize('utc') + + s.index = s + + # Both a column and an index to hit both use cases + df = pd.DataFrame({'tz_aware': s, + 'tz_eastern': s.dt.tz_convert('US/Eastern')}, + index=s) + + f = io.BytesIO() + + arrow_table = pa.Table.from_pandas(df) + + _write_table(arrow_table, f, coerce_timestamps='ms') + f.seek(0) + + table_read = pq.read_pandas(f, use_legacy_dataset=use_legacy_dataset) + + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_datetime_timezone_tzinfo(use_legacy_dataset): + value = datetime.datetime(2018, 1, 1, 1, 23, 45, + tzinfo=datetime.timezone.utc) + df = pd.DataFrame({'foo': [value]}) + + _roundtrip_pandas_dataframe( + df, write_kwargs={}, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +def test_pandas_parquet_custom_metadata(tempdir): + df = alltypes_sample(size=10000) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + assert b'pandas' in arrow_table.schema.metadata + + _write_table(arrow_table, filename, version='2.0', coerce_timestamps='ms') + + metadata = pq.read_metadata(filename).metadata + assert b'pandas' in metadata + + js = json.loads(metadata[b'pandas'].decode('utf8')) + assert js['index_columns'] == [{'kind': 'range', + 'name': None, + 'start': 0, 'stop': 10000, + 'step': 1}] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_column_multiindex(tempdir, use_legacy_dataset): + df = alltypes_sample(size=10) + df.columns = pd.MultiIndex.from_tuples( + list(zip(df.columns, df.columns[::-1])), + 
names=['level_1', 'level_2'] + ) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + assert arrow_table.schema.pandas_metadata is not None + + _write_table(arrow_table, filename, version='2.0', coerce_timestamps='ms') + + table_read = pq.read_pandas( + filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_2_0_roundtrip_read_pandas_no_index_written( + tempdir, use_legacy_dataset +): + df = alltypes_sample(size=10000) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + js = arrow_table.schema.pandas_metadata + assert not js['index_columns'] + # ARROW-2170 + # While index_columns should be empty, columns still needs to be filled. + assert js['columns'] + + _write_table(arrow_table, filename, version='2.0', coerce_timestamps='ms') + table_read = pq.read_pandas( + filename, use_legacy_dataset=use_legacy_dataset) + + js = table_read.schema.pandas_metadata + assert not js['index_columns'] + + read_metadata = table_read.schema.metadata + assert arrow_table.schema.metadata == read_metadata + + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_1_0_roundtrip(tempdir, use_legacy_dataset): + size = 10000 + np.random.seed(0) + df = pd.DataFrame({ + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16), + 'uint32': np.arange(size, dtype=np.uint32), + 'uint64': np.arange(size, dtype=np.uint64), + 'int8': np.arange(size, dtype=np.int8), + 'int16': np.arange(size, dtype=np.int16), + 'int32': np.arange(size, dtype=np.int32), + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + 'str': [str(x) for x in range(size)], + 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None], + 'empty_str': [''] * size + }) + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + _write_table(arrow_table, filename, version='1.0') + table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + + # We pass uint32_t as int64_t if we write Parquet version 1.0 + df['uint32'] = df['uint32'].values.astype(np.int64) + + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_multiple_path_types(tempdir, use_legacy_dataset): + # Test compatibility with PEP 519 path-like objects + path = tempdir / 'zzz.parquet' + df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)}) + _write_table(df, path) + table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + # Test compatibility with plain string paths + path = str(tempdir) + '/zzz.parquet' + df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)}) + _write_table(df, path) + table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.dataset +@parametrize_legacy_dataset +@pytest.mark.parametrize("filesystem", [ + None, fs.LocalFileSystem(), LocalFileSystem._get_instance() +]) +def test_relative_paths(tempdir, use_legacy_dataset, filesystem): + # reading and writing from
relative paths + table = pa.table({"a": [1, 2, 3]}) + + # reading + pq.write_table(table, str(tempdir / "data.parquet")) + with util.change_cwd(tempdir): + result = pq.read_table("data.parquet", filesystem=filesystem, + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + # writing + with util.change_cwd(tempdir): + pq.write_table(table, "data2.parquet", filesystem=filesystem) + result = pq.read_table(tempdir / "data2.parquet") + assert result.equals(table) + + +@parametrize_legacy_dataset_fixed +def test_filesystem_uri(tempdir, use_legacy_dataset): + table = pa.table({"a": [1, 2, 3]}) + + directory = tempdir / "data_dir" + directory.mkdir() + path = directory / "data.parquet" + pq.write_table(table, str(path)) + + # filesystem object + result = pq.read_table( + path, filesystem=fs.LocalFileSystem(), + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + # filesystem URI + result = pq.read_table( + "data_dir/data.parquet", filesystem=util._filesystem_uri(tempdir), + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + +@parametrize_legacy_dataset +def test_read_non_existing_file(use_legacy_dataset): + # ensure we have a proper error message + with pytest.raises(FileNotFoundError): + pq.read_table('i-am-not-existing.parquet') + + +# TODO(dataset) duplicate column selection actually gives duplicate columns now +@pytest.mark.pandas +@parametrize_legacy_dataset_not_supported +def test_pandas_column_selection(tempdir, use_legacy_dataset): + size = 10000 + np.random.seed(0) + df = pd.DataFrame({ + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16) + }) + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + _write_table(arrow_table, filename) + table_read = _read_table( + filename, columns=['uint8'], use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + + tm.assert_frame_equal(df[['uint8']], df_read) + + # ARROW-4267: Selection of duplicate columns still leads to these columns + # being read uniquely. 
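+ # (A minimal sketch of the same property at the Table level, assuming + # the deduplication happens before the conversion to pandas; + # `column_names` is the standard pyarrow Table attribute.) + assert _read_table( + filename, columns=['uint8', 'uint8'], + use_legacy_dataset=use_legacy_dataset).column_names == ['uint8']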
+ table_read = _read_table( + filename, columns=['uint8', 'uint8'], + use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + + tm.assert_frame_equal(df[['uint8']], df_read) + + +def _random_integers(size, dtype): + # We do not generate integers outside the int64 range + platform_int_info = np.iinfo('int_') + iinfo = np.iinfo(dtype) + return np.random.randint(max(iinfo.min, platform_int_info.min), + min(iinfo.max, platform_int_info.max), + size=size).astype(dtype) + + +def _test_dataframe(size=10000, seed=0): + np.random.seed(seed) + df = pd.DataFrame({ + 'uint8': _random_integers(size, np.uint8), + 'uint16': _random_integers(size, np.uint16), + 'uint32': _random_integers(size, np.uint32), + 'uint64': _random_integers(size, np.uint64), + 'int8': _random_integers(size, np.int8), + 'int16': _random_integers(size, np.int16), + 'int32': _random_integers(size, np.int32), + 'int64': _random_integers(size, np.int64), + 'float32': np.random.randn(size).astype(np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + 'strings': [util.rands(10) for i in range(size)], + 'all_none': [None] * size, + 'all_none_category': [None] * size + }) + # TODO(PARQUET-1015) + # df['all_none_category'] = df['all_none_category'].astype('category') + return df + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_native_file_roundtrip(tempdir, use_legacy_dataset): + df = _test_dataframe(10000) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + _write_table(arrow_table, imos, version="2.0") + buf = imos.getvalue() + reader = pa.BufferReader(buf) + df_read = _read_table( + reader, use_legacy_dataset=use_legacy_dataset).to_pandas() + tm.assert_frame_equal(df, df_read) + + +@parametrize_legacy_dataset +def test_parquet_read_from_buffer(tempdir, use_legacy_dataset): + # reading from a buffer from python's open() + table = pa.table({"a": [1, 2, 3]}) + pq.write_table(table, str(tempdir / "data.parquet")) + + with open(str(tempdir / "data.parquet"), "rb") as f: + result = pq.read_table(f, use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + with open(str(tempdir / "data.parquet"), "rb") as f: + result = pq.read_table(pa.PythonFile(f), + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_incremental_file_build(tempdir, use_legacy_dataset): + df = _test_dataframe(100) + df['unique_id'] = 0 + + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + out = pa.BufferOutputStream() + + writer = pq.ParquetWriter(out, arrow_table.schema, version='2.0') + + frames = [] + for i in range(10): + df['unique_id'] = i + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + writer.write_table(arrow_table) + + frames.append(df.copy()) + + writer.close() + + buf = out.getvalue() + result = _read_table( + pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) + + expected = pd.concat(frames, ignore_index=True) + tm.assert_frame_equal(result.to_pandas(), expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_pandas_column_subset(tempdir, use_legacy_dataset): + df = _test_dataframe(10000) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + _write_table(arrow_table, imos, version="2.0") + buf = imos.getvalue() + reader = pa.BufferReader(buf) + df_read = pq.read_pandas( + reader, columns=['strings', 'uint8'], + use_legacy_dataset=use_legacy_dataset + 
).to_pandas() + tm.assert_frame_equal(df[['strings', 'uint8']], df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_empty_roundtrip(tempdir, use_legacy_dataset): + df = _test_dataframe(0) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + _write_table(arrow_table, imos, version="2.0") + buf = imos.getvalue() + reader = pa.BufferReader(buf) + df_read = _read_table( + reader, use_legacy_dataset=use_legacy_dataset).to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +def test_pandas_can_write_nested_data(tempdir): + data = { + "agg_col": [ + {"page_type": 1}, + {"record_type": 1}, + {"non_consecutive_home": 0}, + ], + "uid_first": "1001" + } + df = pd.DataFrame(data=data) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + # This succeeds under V2 + _write_table(arrow_table, imos) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_pyfile_roundtrip(tempdir, use_legacy_dataset): + filename = tempdir / 'pandas_pyfile_roundtrip.parquet' + size = 5 + df = pd.DataFrame({ + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + 'strings': ['foo', 'bar', None, 'baz', 'qux'] + }) + + arrow_table = pa.Table.from_pandas(df) + + with filename.open('wb') as f: + _write_table(arrow_table, f, version="1.0") + + data = io.BytesIO(filename.read_bytes()) + + table_read = _read_table(data, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_configuration_options(tempdir, use_legacy_dataset): + size = 10000 + np.random.seed(0) + df = pd.DataFrame({ + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16), + 'uint32': np.arange(size, dtype=np.uint32), + 'uint64': np.arange(size, dtype=np.uint64), + 'int8': np.arange(size, dtype=np.int8), + 'int16': np.arange(size, dtype=np.int16), + 'int32': np.arange(size, dtype=np.int32), + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0 + }) + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + + for use_dictionary in [True, False]: + _write_table(arrow_table, filename, version='2.0', + use_dictionary=use_dictionary) + table_read = _read_table( + filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + for write_statistics in [True, False]: + _write_table(arrow_table, filename, version='2.0', + write_statistics=write_statistics) + table_read = _read_table(filename, + use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + for compression in ['NONE', 'SNAPPY', 'GZIP', 'LZ4', 'ZSTD']: + if (compression != 'NONE' and + not pa.lib.Codec.is_available(compression)): + continue + _write_table(arrow_table, filename, version='2.0', + compression=compression) + table_read = _read_table( + filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +def make_sample_file(table_or_df): + if isinstance(table_or_df, pa.Table): + a_table = table_or_df + else: + a_table = pa.Table.from_pandas(table_or_df) + + buf =
io.BytesIO() + _write_table(a_table, buf, compression='SNAPPY', version='2.0', + coerce_timestamps='ms') + + buf.seek(0) + return pq.ParquetFile(buf) + + +@parametrize_legacy_dataset +def test_byte_stream_split(use_legacy_dataset): + # This is only a smoke test. + arr_float = pa.array(list(map(float, range(100)))) + arr_int = pa.array(list(map(int, range(100)))) + data_float = [arr_float, arr_float] + table = pa.Table.from_arrays(data_float, names=['a', 'b']) + + # Check with byte_stream_split for both columns. + _check_roundtrip(table, expected=table, compression="gzip", + use_dictionary=False, use_byte_stream_split=True) + + # Check with byte_stream_split for column 'b' and dictionary + # for column 'a'. + _check_roundtrip(table, expected=table, compression="gzip", + use_dictionary=['a'], + use_byte_stream_split=['b']) + + # Check with a collision for both columns. + _check_roundtrip(table, expected=table, compression="gzip", + use_dictionary=['a', 'b'], + use_byte_stream_split=['a', 'b']) + + # Check with mixed column types. + mixed_table = pa.Table.from_arrays([arr_float, arr_int], + names=['a', 'b']) + _check_roundtrip(mixed_table, expected=mixed_table, + use_dictionary=['b'], + use_byte_stream_split=['a']) + + # Try to use the wrong data type with the byte_stream_split encoding. + # This should throw an exception. + table = pa.Table.from_arrays([arr_int], names=['tmp']) + with pytest.raises(IOError): + _check_roundtrip(table, expected=table, use_byte_stream_split=True, + use_dictionary=False, + use_legacy_dataset=use_legacy_dataset) + + +@parametrize_legacy_dataset +def test_compression_level(use_legacy_dataset): + arr = pa.array(list(map(int, range(1000)))) + data = [arr, arr] + table = pa.Table.from_arrays(data, names=['a', 'b']) + + # Check one compression level. + _check_roundtrip(table, expected=table, compression="gzip", + compression_level=1, + use_legacy_dataset=use_legacy_dataset) + + # Check another one to make sure that compression_level=1 does not + # coincide with the default one in Arrow. + _check_roundtrip(table, expected=table, compression="gzip", + compression_level=5, + use_legacy_dataset=use_legacy_dataset) + + # Check that the user can provide a compression per column + _check_roundtrip(table, expected=table, + compression={'a': "gzip", 'b': "snappy"}, + use_legacy_dataset=use_legacy_dataset) + + # Check that the user can provide a compression level per column + _check_roundtrip(table, expected=table, compression="gzip", + compression_level={'a': 2, 'b': 3}, + use_legacy_dataset=use_legacy_dataset) + + # Check that specifying a compression level for a codec which does not + # allow specifying one results in an error. + # Uncompressed, snappy, lz4 and lzo do not support specifying a compression + # level. + # GZIP (zlib) allows for specifying a compression level but, as of + # version 1.2.11, the valid range is [-1, 9].
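+ # (A minimal sketch of the accepted boundary, assuming the gzip codec + # is built in: level 9, the top of zlib's documented range, should + # simply round-trip.) + if pa.lib.Codec.is_available('gzip'): + _check_roundtrip(table, expected=table, compression="gzip", + compression_level=9, + use_legacy_dataset=use_legacy_dataset)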
+ invalid_combinations = [("snappy", 4), ("lz4", 5), ("gzip", -1337), + ("None", 444), ("lzo", 14)] + buf = io.BytesIO() + for (codec, level) in invalid_combinations: + with pytest.raises((ValueError, OSError)): + _write_table(table, buf, compression=codec, + compression_level=level) + + +@pytest.mark.pandas +def test_parquet_metadata_api(): + df = alltypes_sample(size=10000) + df = df.reindex(columns=sorted(df.columns)) + df.index = np.random.randint(0, 1000000, size=len(df)) + + fileh = make_sample_file(df) + ncols = len(df.columns) + + # Series of sniff tests + meta = fileh.metadata + repr(meta) + assert meta.num_rows == len(df) + assert meta.num_columns == ncols + 1 # +1 for index + assert meta.num_row_groups == 1 + assert meta.format_version == '2.0' + assert 'parquet-cpp' in meta.created_by + assert isinstance(meta.serialized_size, int) + assert isinstance(meta.metadata, dict) + + # Schema + schema = fileh.schema + assert meta.schema is schema + assert len(schema) == ncols + 1 # +1 for index + repr(schema) + + col = schema[0] + repr(col) + assert col.name == df.columns[0] + assert col.max_definition_level == 1 + assert col.max_repetition_level == 0 + assert col.max_repetition_level == 0 + + assert col.physical_type == 'BOOLEAN' + assert col.converted_type == 'NONE' + + with pytest.raises(IndexError): + schema[ncols + 1] # +1 for index + + with pytest.raises(IndexError): + schema[-1] + + # Row group + for rg in range(meta.num_row_groups): + rg_meta = meta.row_group(rg) + assert isinstance(rg_meta, pq.RowGroupMetaData) + repr(rg_meta) + + for col in range(rg_meta.num_columns): + col_meta = rg_meta.column(col) + assert isinstance(col_meta, pq.ColumnChunkMetaData) + repr(col_meta) + + with pytest.raises(IndexError): + meta.row_group(-1) + + with pytest.raises(IndexError): + meta.row_group(meta.num_row_groups + 1) + + rg_meta = meta.row_group(0) + assert rg_meta.num_rows == len(df) + assert rg_meta.num_columns == ncols + 1 # +1 for index + assert rg_meta.total_byte_size > 0 + + with pytest.raises(IndexError): + col_meta = rg_meta.column(-1) + + with pytest.raises(IndexError): + col_meta = rg_meta.column(ncols + 2) + + col_meta = rg_meta.column(0) + assert col_meta.file_offset > 0 + assert col_meta.file_path == '' # created from BytesIO + assert col_meta.physical_type == 'BOOLEAN' + assert col_meta.num_values == 10000 + assert col_meta.path_in_schema == 'bool' + assert col_meta.is_stats_set is True + assert isinstance(col_meta.statistics, pq.Statistics) + assert col_meta.compression == 'SNAPPY' + assert col_meta.encodings == ('PLAIN', 'RLE') + assert col_meta.has_dictionary_page is False + assert col_meta.dictionary_page_offset is None + assert col_meta.data_page_offset > 0 + assert col_meta.total_compressed_size > 0 + assert col_meta.total_uncompressed_size > 0 + with pytest.raises(NotImplementedError): + col_meta.has_index_page + with pytest.raises(NotImplementedError): + col_meta.index_page_offset + + +def test_parquet_metadata_lifetime(tempdir): + # ARROW-6642 - ensure that chained access keeps parent objects alive + table = pa.table({'a': [1, 2, 3]}) + pq.write_table(table, tempdir / 'test_metadata_segfault.parquet') + dataset = pq.ParquetDataset(tempdir / 'test_metadata_segfault.parquet') + dataset.pieces[0].get_metadata().row_group(0).column(0).statistics + + +@pytest.mark.pandas +@pytest.mark.parametrize( + ( + 'data', + 'type', + 'physical_type', + 'min_value', + 'max_value', + 'null_count', + 'num_values', + 'distinct_count' + ), + [ + ([1, 2, 2, None, 4], pa.uint8(), 'INT32', 1, 
4, 1, 4, 0), + ([1, 2, 2, None, 4], pa.uint16(), 'INT32', 1, 4, 1, 4, 0), + ([1, 2, 2, None, 4], pa.uint32(), 'INT32', 1, 4, 1, 4, 0), + ([1, 2, 2, None, 4], pa.uint64(), 'INT64', 1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int8(), 'INT32', -1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int16(), 'INT32', -1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int32(), 'INT32', -1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int64(), 'INT64', -1, 4, 1, 4, 0), + ( + [-1.1, 2.2, 2.3, None, 4.4], pa.float32(), + 'FLOAT', -1.1, 4.4, 1, 4, 0 + ), + ( + [-1.1, 2.2, 2.3, None, 4.4], pa.float64(), + 'DOUBLE', -1.1, 4.4, 1, 4, 0 + ), + ( + ['', 'b', chr(1000), None, 'aaa'], pa.binary(), + 'BYTE_ARRAY', b'', chr(1000).encode('utf-8'), 1, 4, 0 + ), + ( + [True, False, False, True, True], pa.bool_(), + 'BOOLEAN', False, True, 0, 5, 0 + ), + ( + [b'\x00', b'b', b'12', None, b'aaa'], pa.binary(), + 'BYTE_ARRAY', b'\x00', b'b', 1, 4, 0 + ), + ] +) +def test_parquet_column_statistics_api(data, type, physical_type, min_value, + max_value, null_count, num_values, + distinct_count): + df = pd.DataFrame({'data': data}) + schema = pa.schema([pa.field('data', type)]) + table = pa.Table.from_pandas(df, schema=schema, safe=False) + fileh = make_sample_file(table) + + meta = fileh.metadata + + rg_meta = meta.row_group(0) + col_meta = rg_meta.column(0) + + stat = col_meta.statistics + assert stat.has_min_max + assert _close(type, stat.min, min_value) + assert _close(type, stat.max, max_value) + assert stat.null_count == null_count + assert stat.num_values == num_values + # TODO(kszucs): until the parquet-cpp API exposes a HasDistinctCount + # method, a missing distinct_count is represented as zero instead of None + assert stat.distinct_count == distinct_count + assert stat.physical_type == physical_type + + +# ARROW-6339 +@pytest.mark.pandas +def test_parquet_raise_on_unset_statistics(): + df = pd.DataFrame({"t": pd.Series([pd.NaT], dtype="datetime64[ns]")}) + meta = make_sample_file(pa.Table.from_pandas(df)).metadata + + assert not meta.row_group(0).column(0).statistics.has_min_max + assert meta.row_group(0).column(0).statistics.max is None + + +def _close(type, left, right): + if type == pa.float32(): + return abs(left - right) < 1E-7 + elif type == pa.float64(): + return abs(left - right) < 1E-13 + else: + return left == right + + +def test_statistics_convert_logical_types(tempdir): + # ARROW-5166, ARROW-4139 + + # (min, max, type) + cases = [(10, 11164359321221007157, pa.uint64()), + (10, 4294967295, pa.uint32()), + ("ähnlich", "öffentlich", pa.utf8()), + (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000), + pa.time32('ms')), + (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000), + pa.time64('us')), + (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000), + datetime.datetime(2019, 6, 25, 0, 0, 0, 1000), + pa.timestamp('ms')), + (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000), + datetime.datetime(2019, 6, 25, 0, 0, 0, 1000), + pa.timestamp('us'))] + + for i, (min_val, max_val, typ) in enumerate(cases): + t = pa.Table.from_arrays([pa.array([min_val, max_val], type=typ)], + ['col']) + path = str(tempdir / ('example{}.parquet'.format(i))) + pq.write_table(t, path, version='2.0') + pf = pq.ParquetFile(path) + stats = pf.metadata.row_group(0).column(0).statistics + assert stats.min == min_val + assert stats.max == max_val + + +def test_parquet_metadata_empty_to_dict(tempdir): + # https://issues.apache.org/jira/browse/ARROW-10146 + table = pa.table({"a": pa.array([], type="int64")}) + pq.write_table(table, tempdir /
"data.parquet") + metadata = pq.read_metadata(tempdir / "data.parquet") + # ensure this doesn't error / statistics set to None + metadata_dict = metadata.to_dict() + assert len(metadata_dict["row_groups"]) == 1 + assert len(metadata_dict["row_groups"][0]["columns"]) == 1 + assert metadata_dict["row_groups"][0]["columns"][0]["statistics"] is None + + +def test_parquet_write_disable_statistics(tempdir): + table = pa.Table.from_pydict( + OrderedDict([ + ('a', pa.array([1, 2, 3])), + ('b', pa.array(['a', 'b', 'c'])) + ]) + ) + _write_table(table, tempdir / 'data.parquet') + meta = pq.read_metadata(tempdir / 'data.parquet') + for col in [0, 1]: + cc = meta.row_group(0).column(col) + assert cc.is_stats_set is True + assert cc.statistics is not None + + _write_table(table, tempdir / 'data2.parquet', write_statistics=False) + meta = pq.read_metadata(tempdir / 'data2.parquet') + for col in [0, 1]: + cc = meta.row_group(0).column(col) + assert cc.is_stats_set is False + assert cc.statistics is None + + _write_table(table, tempdir / 'data3.parquet', write_statistics=['a']) + meta = pq.read_metadata(tempdir / 'data3.parquet') + cc_a = meta.row_group(0).column(0) + cc_b = meta.row_group(0).column(1) + assert cc_a.is_stats_set is True + assert cc_b.is_stats_set is False + assert cc_a.statistics is not None + assert cc_b.statistics is None + + +@pytest.mark.pandas +def test_compare_schemas(): + df = alltypes_sample(size=10000) + + fileh = make_sample_file(df) + fileh2 = make_sample_file(df) + fileh3 = make_sample_file(df[df.columns[::2]]) + + # ParquetSchema + assert isinstance(fileh.schema, pq.ParquetSchema) + assert fileh.schema.equals(fileh.schema) + assert fileh.schema == fileh.schema + assert fileh.schema.equals(fileh2.schema) + assert fileh.schema == fileh2.schema + assert fileh.schema != 'arbitrary object' + assert not fileh.schema.equals(fileh3.schema) + assert fileh.schema != fileh3.schema + + # ColumnSchema + assert isinstance(fileh.schema[0], pq.ColumnSchema) + assert fileh.schema[0].equals(fileh.schema[0]) + assert fileh.schema[0] == fileh.schema[0] + assert not fileh.schema[0].equals(fileh.schema[1]) + assert fileh.schema[0] != fileh.schema[1] + assert fileh.schema[0] != 'arbitrary object' + + +def test_validate_schema_write_table(tempdir): + # ARROW-2926 + simple_fields = [ + pa.field('POS', pa.uint32()), + pa.field('desc', pa.string()) + ] + + simple_schema = pa.schema(simple_fields) + + # simple_table schema does not match simple_schema + simple_from_array = [pa.array([1]), pa.array(['bla'])] + simple_table = pa.Table.from_arrays(simple_from_array, ['POS', 'desc']) + + path = tempdir / 'simple_validate_schema.parquet' + + with pq.ParquetWriter(path, simple_schema, + version='2.0', + compression='snappy', flavor='spark') as w: + with pytest.raises(ValueError): + w.write_table(simple_table) + + +@pytest.mark.pandas +def test_column_of_arrays(tempdir): + df, schema = dataframe_with_arrays() + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, schema=schema) + _write_table(arrow_table, filename, version="2.0", coerce_timestamps='ms') + table_read = _read_table(filename) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +def test_coerce_timestamps(tempdir): + from collections import OrderedDict + # ARROW-622 + arrays = OrderedDict() + fields = [pa.field('datetime64', + pa.list_(pa.timestamp('ms')))] + arrays['datetime64'] = [ + np.array(['2007-07-13T01:23:34.123456789', + None, + 
'2010-08-13T05:46:57.437699912'], + dtype='datetime64[ms]'), + None, + None, + np.array(['2007-07-13T02', + None, + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ms]'), + ] + + df = pd.DataFrame(arrays) + schema = pa.schema(fields) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, schema=schema) + + _write_table(arrow_table, filename, version="2.0", coerce_timestamps='us') + table_read = _read_table(filename) + df_read = table_read.to_pandas() + + df_expected = df.copy() + for i, x in enumerate(df_expected['datetime64']): + if isinstance(x, np.ndarray): + df_expected['datetime64'][i] = x.astype('M8[us]') + + tm.assert_frame_equal(df_expected, df_read) + + with pytest.raises(ValueError): + _write_table(arrow_table, filename, version='2.0', + coerce_timestamps='unknown') + + +@pytest.mark.pandas +def test_coerce_timestamps_truncated(tempdir): + """ + ARROW-2555: Test that we can truncate timestamps when coercing if + explicitly allowed. + """ + dt_us = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1, + second=1, microsecond=1) + dt_ms = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1, + second=1) + + fields_us = [pa.field('datetime64', pa.timestamp('us'))] + arrays_us = {'datetime64': [dt_us, dt_ms]} + + df_us = pd.DataFrame(arrays_us) + schema_us = pa.schema(fields_us) + + filename = tempdir / 'pandas_truncated.parquet' + table_us = pa.Table.from_pandas(df_us, schema=schema_us) + + _write_table(table_us, filename, version="2.0", coerce_timestamps='ms', + allow_truncated_timestamps=True) + table_ms = _read_table(filename) + df_ms = table_ms.to_pandas() + + arrays_expected = {'datetime64': [dt_ms, dt_ms]} + df_expected = pd.DataFrame(arrays_expected) + tm.assert_frame_equal(df_expected, df_ms) + + +@pytest.mark.pandas +def test_column_of_lists(tempdir): + df, schema = dataframe_with_lists(parquet_compatible=True) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, schema=schema) + _write_table(arrow_table, filename, version='2.0') + table_read = _read_table(filename) + df_read = table_read.to_pandas() + + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +def test_date_time_types(tempdir): + t1 = pa.date32() + data1 = np.array([17259, 17260, 17261], dtype='int32') + a1 = pa.array(data1, type=t1) + + t2 = pa.date64() + data2 = data1.astype('int64') * 86400000 + a2 = pa.array(data2, type=t2) + + t3 = pa.timestamp('us') + start = pd.Timestamp('2001-01-01').value / 1000 + data3 = np.array([start, start + 1, start + 2], dtype='int64') + a3 = pa.array(data3, type=t3) + + t4 = pa.time32('ms') + data4 = np.arange(3, dtype='i4') + a4 = pa.array(data4, type=t4) + + t5 = pa.time64('us') + a5 = pa.array(data4.astype('int64'), type=t5) + + t6 = pa.time32('s') + a6 = pa.array(data4, type=t6) + + ex_t6 = pa.time32('ms') + ex_a6 = pa.array(data4 * 1000, type=ex_t6) + + t7 = pa.timestamp('ns') + start = pd.Timestamp('2001-01-01').value + data7 = np.array([start, start + 1000, start + 2000], + dtype='int64') + a7 = pa.array(data7, type=t7) + + table = pa.Table.from_arrays([a1, a2, a3, a4, a5, a6, a7], + ['date32', 'date64', 'timestamp[us]', + 'time32[s]', 'time64[us]', + 'time32_from64[s]', + 'timestamp[ns]']) + + # date64 as date32 + # time32[s] to time32[ms] + expected = pa.Table.from_arrays([a1, a1, a3, a4, a5, ex_a6, a7], + ['date32', 'date64', 'timestamp[us]', + 'time32[s]', 'time64[us]', + 'time32_from64[s]', + 'timestamp[ns]']) + + _check_roundtrip(table, expected=expected, 
version='2.0') + + t0 = pa.timestamp('ms') + data0 = np.arange(4, dtype='int64') + a0 = pa.array(data0, type=t0) + + t1 = pa.timestamp('us') + data1 = np.arange(4, dtype='int64') + a1 = pa.array(data1, type=t1) + + t2 = pa.timestamp('ns') + data2 = np.arange(4, dtype='int64') + a2 = pa.array(data2, type=t2) + + table = pa.Table.from_arrays([a0, a1, a2], + ['ts[ms]', 'ts[us]', 'ts[ns]']) + expected = pa.Table.from_arrays([a0, a1, a2], + ['ts[ms]', 'ts[us]', 'ts[ns]']) + + # int64 for all timestamps supported by default + filename = tempdir / 'int64_timestamps.parquet' + _write_table(table, filename, version='2.0') + parquet_schema = pq.ParquetFile(filename).schema + for i in range(3): + assert parquet_schema.column(i).physical_type == 'INT64' + read_table = _read_table(filename) + assert read_table.equals(expected) + + t0_ns = pa.timestamp('ns') + data0_ns = np.array(data0 * 1000000, dtype='int64') + a0_ns = pa.array(data0_ns, type=t0_ns) + + t1_ns = pa.timestamp('ns') + data1_ns = np.array(data1 * 1000, dtype='int64') + a1_ns = pa.array(data1_ns, type=t1_ns) + + expected = pa.Table.from_arrays([a0_ns, a1_ns, a2], + ['ts[ms]', 'ts[us]', 'ts[ns]']) + + # int96 nanosecond timestamps produced upon request + filename = tempdir / 'explicit_int96_timestamps.parquet' + _write_table(table, filename, version='2.0', + use_deprecated_int96_timestamps=True) + parquet_schema = pq.ParquetFile(filename).schema + for i in range(3): + assert parquet_schema.column(i).physical_type == 'INT96' + read_table = _read_table(filename) + assert read_table.equals(expected) + + # int96 nanosecond timestamps implied by flavor 'spark' + filename = tempdir / 'spark_int96_timestamps.parquet' + _write_table(table, filename, version='2.0', + flavor='spark') + parquet_schema = pq.ParquetFile(filename).schema + for i in range(3): + assert parquet_schema.column(i).physical_type == 'INT96' + read_table = _read_table(filename) + assert read_table.equals(expected) + + +def test_timestamp_restore_timezone(): + # ARROW-5888, restore timezone from serialized metadata + ty = pa.timestamp('ms', tz='America/New_York') + arr = pa.array([1, 2, 3], type=ty) + t = pa.table([arr], names=['f0']) + _check_roundtrip(t) + + +@pytest.mark.pandas +def test_list_of_datetime_time_roundtrip(): + # ARROW-4135 + times = pd.to_datetime(['09:00', '09:30', '10:00', '10:30', '11:00', + '11:30', '12:00']) + df = pd.DataFrame({'time': [times.time]}) + _roundtrip_pandas_dataframe(df, write_kwargs={}) + + +@pytest.mark.pandas +def test_parquet_version_timestamp_differences(): + i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000 + + d_s = np.arange(i_s, i_s + 10, 1, dtype='int64') + d_ms = d_s * 1000 + d_us = d_ms * 1000 + d_ns = d_us * 1000 + + a_s = pa.array(d_s, type=pa.timestamp('s')) + a_ms = pa.array(d_ms, type=pa.timestamp('ms')) + a_us = pa.array(d_us, type=pa.timestamp('us')) + a_ns = pa.array(d_ns, type=pa.timestamp('ns')) + + names = ['ts:s', 'ts:ms', 'ts:us', 'ts:ns'] + table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names) + + # Using Parquet version 1.0, seconds should be coerced to milliseconds + # and nanoseconds should be coerced to microseconds by default + expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_us], names) + _check_roundtrip(table, expected) + + # Using Parquet version 2.0, seconds should be coerced to milliseconds + # and nanoseconds should be retained by default + expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_ns], names) + _check_roundtrip(table, expected, version='2.0') + + # Using Parquet 
version 1.0, coercing to milliseconds or microseconds + # is allowed + expected = pa.Table.from_arrays([a_ms, a_ms, a_ms, a_ms], names) + _check_roundtrip(table, expected, coerce_timestamps='ms') + + # Using Parquet version 2.0, coercing to milliseconds or microseconds + # is allowed + expected = pa.Table.from_arrays([a_us, a_us, a_us, a_us], names) + _check_roundtrip(table, expected, version='2.0', coerce_timestamps='us') + + # TODO: after pyarrow allows coerce_timestamps='ns', tests like the + # following should pass ... + + # Using Parquet version 1.0, coercing to nanoseconds is not allowed + # expected = None + # with pytest.raises(NotImplementedError): + # _roundtrip_table(table, coerce_timestamps='ns') + + # Using Parquet version 2.0, coercing to nanoseconds is allowed + # expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names) + # _check_roundtrip(table, expected, version='2.0', coerce_timestamps='ns') + + # For either Parquet version, coercing to nanoseconds is allowed + # if Int96 storage is used + expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names) + _check_roundtrip(table, expected, + use_deprecated_int96_timestamps=True) + _check_roundtrip(table, expected, version='2.0', + use_deprecated_int96_timestamps=True) + + +def test_large_list_records(): + # This was fixed in PARQUET-1100 + + list_lengths = np.random.randint(0, 500, size=50) + list_lengths[::10] = 0 + + list_values = [list(map(int, np.random.randint(0, 100, size=x))) + if i % 8 else None + for i, x in enumerate(list_lengths)] + + a1 = pa.array(list_values) + + table = pa.Table.from_arrays([a1], ['int_lists']) + _check_roundtrip(table) + + +def test_sanitized_spark_field_names(): + a0 = pa.array([0, 1, 2, 3, 4]) + name = 'prohib; ,\t{}' + table = pa.Table.from_arrays([a0], [name]) + + result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'}) + + expected_name = 'prohib______' + assert result.schema[0].name == expected_name + + +@pytest.mark.pandas +def test_spark_flavor_preserves_pandas_metadata(): + df = _test_dataframe(size=100) + df.index = np.arange(0, 10 * len(df), 10) + df.index.name = 'foo' + + result = _roundtrip_pandas_dataframe(df, {'version': '2.0', + 'flavor': 'spark'}) + tm.assert_frame_equal(result, df) + + +def test_fixed_size_binary(): + t0 = pa.binary(10) + data = [b'fooooooooo', None, b'barooooooo', b'quxooooooo'] + a0 = pa.array(data, type=t0) + + table = pa.Table.from_arrays([a0], + ['binary[10]']) + _check_roundtrip(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_multithreaded_read(use_legacy_dataset): + df = alltypes_sample(size=10000) + + table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(table, buf, compression='SNAPPY', version='2.0') + + buf.seek(0) + table1 = _read_table( + buf, use_threads=True, use_legacy_dataset=use_legacy_dataset) + + buf.seek(0) + table2 = _read_table( + buf, use_threads=False, use_legacy_dataset=use_legacy_dataset) + + assert table1.equals(table2) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_min_chunksize(use_legacy_dataset): + data = pd.DataFrame([np.arange(4)], columns=['A', 'B', 'C', 'D']) + table = pa.Table.from_pandas(data.reset_index()) + + buf = io.BytesIO() + _write_table(table, buf, chunk_size=-1) + + buf.seek(0) + result = _read_table(buf, use_legacy_dataset=use_legacy_dataset) + + assert result.equals(table) + + with pytest.raises(ValueError): + _write_table(table, buf, chunk_size=0) + + +@pytest.mark.pandas +def test_pass_separate_metadata(): + # ARROW-471 + df = 
alltypes_sample(size=10000) + + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, compression='snappy', version='2.0') + + buf.seek(0) + metadata = pq.read_metadata(buf) + + buf.seek(0) + + fileh = pq.ParquetFile(buf, metadata=metadata) + + tm.assert_frame_equal(df, fileh.read().to_pandas()) + + +@pytest.mark.pandas +def test_read_single_row_group(): + # ARROW-471 + N, K = 10000, 4 + df = alltypes_sample(size=N) + + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.0') + + buf.seek(0) + + pf = pq.ParquetFile(buf) + + assert pf.num_row_groups == K + + row_groups = [pf.read_row_group(i) for i in range(K)] + result = pa.concat_tables(row_groups) + tm.assert_frame_equal(df, result.to_pandas()) + + +@pytest.mark.pandas +def test_read_single_row_group_with_column_subset(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.0') + + buf.seek(0) + pf = pq.ParquetFile(buf) + + cols = list(df.columns[:2]) + row_groups = [pf.read_row_group(i, columns=cols) for i in range(K)] + result = pa.concat_tables(row_groups) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + # ARROW-4267: Selection of duplicate columns still leads to these columns + # being read uniquely. + row_groups = [pf.read_row_group(i, columns=cols + cols) for i in range(K)] + result = pa.concat_tables(row_groups) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + +@pytest.mark.pandas +def test_read_multiple_row_groups(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.0') + + buf.seek(0) + + pf = pq.ParquetFile(buf) + + assert pf.num_row_groups == K + + result = pf.read_row_groups(range(K)) + tm.assert_frame_equal(df, result.to_pandas()) + + +@pytest.mark.pandas +def test_read_multiple_row_groups_with_column_subset(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.0') + + buf.seek(0) + pf = pq.ParquetFile(buf) + + cols = list(df.columns[:2]) + result = pf.read_row_groups(range(K), columns=cols) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + # ARROW-4267: Selection of duplicate columns still leads to these columns + # being read uniquely. 
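+ # (A minimal sketch, assuming the same deduplication also applies when + # reading a single row group directly:) + assert pf.read_row_group(0, columns=cols + cols).column_names == cols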
+ result = pf.read_row_groups(range(K), columns=cols + cols) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + +@pytest.mark.pandas +def test_scan_contents(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.0') + + buf.seek(0) + pf = pq.ParquetFile(buf) + + assert pf.scan_contents() == 10000 + assert pf.scan_contents(df.columns[:4]) == 10000 + + +@pytest.mark.pandas +def test_parquet_piece_read(tempdir): + df = _test_dataframe(1000) + table = pa.Table.from_pandas(df) + + path = tempdir / 'parquet_piece_read.parquet' + _write_table(table, path, version='2.0') + + piece1 = pq.ParquetDatasetPiece(path) + + result = piece1.read() + assert result.equals(table) + + +@pytest.mark.pandas +def test_parquet_piece_open_and_get_metadata(tempdir): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df) + + path = tempdir / 'parquet_piece_read.parquet' + _write_table(table, path, version='2.0') + + piece = pq.ParquetDatasetPiece(path) + table1 = piece.read() + assert isinstance(table1, pa.Table) + meta1 = piece.get_metadata() + assert isinstance(meta1, pq.FileMetaData) + + assert table.equals(table1) + + +def test_parquet_piece_basics(): + path = '/baz.parq' + + piece1 = pq.ParquetDatasetPiece(path) + piece2 = pq.ParquetDatasetPiece(path, row_group=1) + piece3 = pq.ParquetDatasetPiece( + path, row_group=1, partition_keys=[('foo', 0), ('bar', 1)]) + + assert str(piece1) == path + assert str(piece2) == '/baz.parq | row_group=1' + assert str(piece3) == 'partition[foo=0, bar=1] /baz.parq | row_group=1' + + assert piece1 == piece1 + assert piece2 == piece2 + assert piece3 == piece3 + assert piece1 != piece3 + + +def test_partition_set_dictionary_type(): + set1 = pq.PartitionSet('key1', ['foo', 'bar', 'baz']) + set2 = pq.PartitionSet('key2', [2007, 2008, 2009]) + + assert isinstance(set1.dictionary, pa.StringArray) + assert isinstance(set2.dictionary, pa.IntegerArray) + + set3 = pq.PartitionSet('key2', [datetime.datetime(2007, 1, 1)]) + with pytest.raises(TypeError): + set3.dictionary + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_partitioned_directory(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + _partition_test_for_filesystem(fs, tempdir, use_legacy_dataset) + + +@pytest.mark.pandas +def test_create_parquet_dataset_multi_threaded(tempdir): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + _partition_test_for_filesystem(fs, base_path) + + manifest = pq.ParquetManifest(base_path, filesystem=fs, + metadata_nthreads=1) + dataset = pq.ParquetDataset(base_path, filesystem=fs, metadata_nthreads=16) + assert len(dataset.pieces) > 0 + partitions = dataset.partitions + assert len(partitions.partition_names) > 0 + assert partitions.partition_names == manifest.partitions.partition_names + assert len(partitions.levels) == len(manifest.partitions.levels) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_partitioned_columns_selection(tempdir, use_legacy_dataset): + # ARROW-3861 - do not include partition columns in resulting table when + # `columns` keyword was passed without those columns + fs = LocalFileSystem._get_instance() + base_path = tempdir + _partition_test_for_filesystem(fs, base_path) + + dataset = pq.ParquetDataset( + base_path, use_legacy_dataset=use_legacy_dataset) + result = dataset.read(columns=["values"]) + if use_legacy_dataset: + # ParquetDataset implementation 
always includes the partition columns + # automatically, and we can't easily "fix" this since dask relies on + # this behaviour (ARROW-8644) + assert result.column_names == ["values", "foo", "bar"] + else: + assert result.column_names == ["values"] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_equivalency(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1] + string_keys = ['a', 'b', 'c'] + boolean_keys = [True, False] + partition_spec = [ + ['integer', integer_keys], + ['string', string_keys], + ['boolean', boolean_keys] + ] + + df = pd.DataFrame({ + 'integer': np.array(integer_keys, dtype='i4').repeat(15), + 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2), + 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5), + 3), + }, columns=['integer', 'string', 'boolean']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + # Old filters syntax: + # integer == 1 AND string != b AND boolean == True + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[('integer', '=', 1), ('string', '!=', 'b'), + ('boolean', '==', True)], + use_legacy_dataset=use_legacy_dataset, + ) + table = dataset.read() + result_df = (table.to_pandas().reset_index(drop=True)) + + assert 0 not in result_df['integer'].values + assert 'b' not in result_df['string'].values + assert False not in result_df['boolean'].values + + # filters in disjunctive normal form: + # (integer == 1 AND string != b AND boolean == True) OR + # (integer == 2 AND boolean == False) + # TODO(ARROW-3388): boolean columns are reconstructed as string + filters = [ + [ + ('integer', '=', 1), + ('string', '!=', 'b'), + ('boolean', '==', 'True') + ], + [('integer', '=', 0), ('boolean', '==', 'False')] + ] + dataset = pq.ParquetDataset( + base_path, filesystem=fs, filters=filters, + use_legacy_dataset=use_legacy_dataset) + table = dataset.read() + result_df = table.to_pandas().reset_index(drop=True) + + # Check that all rows in the DF fulfill the filter + # Pandas 0.23.x has problems with indexing constant memoryviews in + # categoricals. Thus we need to make an explicit copy here with np.array. + df_filter_1 = (np.array(result_df['integer']) == 1) \ + & (np.array(result_df['string']) != 'b') \ + & (np.array(result_df['boolean']) == 'True') + df_filter_2 = (np.array(result_df['integer']) == 0) \ + & (np.array(result_df['boolean']) == 'False') + assert df_filter_1.sum() > 0 + assert df_filter_2.sum() > 0 + assert result_df.shape[0] == (df_filter_1.sum() + df_filter_2.sum()) + + if use_legacy_dataset: + # Check for \0 in predicate values. Until they are correctly + # implemented in ARROW-3391, they would otherwise lead to weird + # results with the current code. 
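+ # (For illustration, `nul_value` is a local example of such a + # predicate value: an ordinary partition key with an embedded NUL.) + nul_value = '1\0a' + assert '\x00' in nul_value and len(nul_value) == 3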
+ with pytest.raises(NotImplementedError): + filters = [[('string', '==', b'1\0a')]] + pq.ParquetDataset(base_path, filesystem=fs, filters=filters) + with pytest.raises(NotImplementedError): + filters = [[('string', '==', '1\0a')]] + pq.ParquetDataset(base_path, filesystem=fs, filters=filters) + else: + for filters in [[[('string', '==', b'1\0a')]], + [[('string', '==', '1\0a')]]]: + dataset = pq.ParquetDataset( + base_path, filesystem=fs, filters=filters, + use_legacy_dataset=False) + assert dataset.read().num_rows == 0 + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_cutoff_exclusive_integer(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[ + ('integers', '<', 4), + ('integers', '>', 1), + ], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas() + .sort_values(by='index') + .reset_index(drop=True)) + + result_list = [x for x in map(int, result_df['integers'].values)] + assert result_list == [2, 3] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.xfail( + # different error with use_legacy_datasets because result_df is no longer + # categorical + raises=(TypeError, AssertionError), + reason='Loss of type information in creation of categoricals.' +) +def test_filters_cutoff_exclusive_datetime(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + date_keys = [ + datetime.date(2018, 4, 9), + datetime.date(2018, 4, 10), + datetime.date(2018, 4, 11), + datetime.date(2018, 4, 12), + datetime.date(2018, 4, 13) + ] + partition_spec = [ + ['dates', date_keys] + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'dates': np.array(date_keys, dtype='datetime64'), + }, columns=['index', 'dates']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[ + ('dates', '<', "2018-04-12"), + ('dates', '>', "2018-04-10") + ], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas() + .sort_values(by='index') + .reset_index(drop=True)) + + expected = pd.Categorical( + np.array([datetime.date(2018, 4, 11)], dtype='datetime64'), + categories=np.array(date_keys, dtype='datetime64')) + + assert result_df['dates'].values == expected + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_inclusive_integer(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[ + ('integers', '<=', 3), + ('integers', '>=', 2), + ], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas() + .sort_values(by='index') + .reset_index(drop=True)) + + result_list = [int(x) for x in map(int, 
result_df['integers'].values)] + assert result_list == [2, 3] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_inclusive_set(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1] + string_keys = ['a', 'b', 'c'] + boolean_keys = [True, False] + partition_spec = [ + ['integer', integer_keys], + ['string', string_keys], + ['boolean', boolean_keys] + ] + + df = pd.DataFrame({ + 'integer': np.array(integer_keys, dtype='i4').repeat(15), + 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2), + 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5), + 3), + }, columns=['integer', 'string', 'boolean']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[('integer', 'in', {1}), ('string', 'in', {'a', 'b'}), + ('boolean', 'in', {True})], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas().reset_index(drop=True)) + + assert 0 not in result_df['integer'].values + assert 'c' not in result_df['string'].values + assert False not in result_df['boolean'].values + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_invalid_pred_op(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + with pytest.raises(ValueError): + pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', '=<', 3), ], + use_legacy_dataset=use_legacy_dataset) + + if use_legacy_dataset: + with pytest.raises(ValueError): + pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', 'in', set()), ], + use_legacy_dataset=use_legacy_dataset) + else: + # Dataset API returns empty table instead + dataset = pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', 'in', set()), ], + use_legacy_dataset=use_legacy_dataset) + assert dataset.read().num_rows == 0 + + with pytest.raises(ValueError): + pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', '!=', {3})], + use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset_fixed +def test_filters_invalid_column(tempdir, use_legacy_dataset): + # ARROW-5572 - raise error on invalid name in filter specification + # works with new dataset / xfail with legacy implementation + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [['integers', integer_keys]] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + msg = r"No match for FieldRef.Name\(non_existent_column\)" + with pytest.raises(ValueError, match=msg): + pq.ParquetDataset(base_path, filesystem=fs, + filters=[('non_existent_column', '<', 3), ], + use_legacy_dataset=use_legacy_dataset).read() + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_read_table(tempdir, use_legacy_dataset): + # test that filters keyword is passed through in read_table + fs = LocalFileSystem._get_instance() + base_path = tempdir + + 
integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + table = pq.read_table( + base_path, filesystem=fs, filters=[('integers', '<', 3)], + use_legacy_dataset=use_legacy_dataset) + assert table.num_rows == 3 + + table = pq.read_table( + base_path, filesystem=fs, filters=[[('integers', '<', 3)]], + use_legacy_dataset=use_legacy_dataset) + assert table.num_rows == 3 + + table = pq.read_pandas( + base_path, filters=[('integers', '<', 3)], + use_legacy_dataset=use_legacy_dataset) + assert table.num_rows == 3 + + +@pytest.mark.pandas +@parametrize_legacy_dataset_fixed +def test_partition_keys_with_underscores(tempdir, use_legacy_dataset): + # ARROW-5666 - partition field values with underscores preserve underscores + # xfail with legacy dataset -> they get interpreted as integers + fs = LocalFileSystem._get_instance() + base_path = tempdir + + string_keys = ["2019_2", "2019_3"] + partition_spec = [ + ['year_week', string_keys], + ] + N = 2 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'year_week': np.array(string_keys, dtype='object'), + }, columns=['index', 'year_week']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, use_legacy_dataset=use_legacy_dataset) + result = dataset.read() + assert result.column("year_week").to_pylist() == string_keys + + +@pytest.fixture +def s3_bucket(request, s3_connection, s3_server): + boto3 = pytest.importorskip('boto3') + botocore = pytest.importorskip('botocore') + + host, port, access_key, secret_key = s3_connection + s3 = boto3.resource( + 's3', + endpoint_url='http://{}:{}'.format(host, port), + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + config=botocore.client.Config(signature_version='s3v4'), + region_name='us-east-1' + ) + bucket = s3.Bucket('test-s3fs') + try: + bucket.create() + except Exception: + # we get BucketAlreadyOwnedByYou error with fsspec handler + pass + return 'test-s3fs' + + +@pytest.fixture +def s3_example_s3fs(s3_connection, s3_server, s3_bucket): + s3fs = pytest.importorskip('s3fs') + + host, port, access_key, secret_key = s3_connection + fs = s3fs.S3FileSystem( + key=access_key, + secret=secret_key, + client_kwargs={ + 'endpoint_url': 'http://{}:{}'.format(host, port) + } + ) + + test_path = '{}/{}'.format(s3_bucket, guid()) + + fs.mkdir(test_path) + yield fs, test_path + try: + fs.rm(test_path, recursive=True) + except FileNotFoundError: + pass + + +@parametrize_legacy_dataset +def test_read_s3fs(s3_example_s3fs, use_legacy_dataset): + fs, path = s3_example_s3fs + path = path + "/test.parquet" + table = pa.table({"a": [1, 2, 3]}) + _write_table(table, path, filesystem=fs) + + result = _read_table( + path, filesystem=fs, use_legacy_dataset=use_legacy_dataset + ) + assert result.equals(table) + + +@parametrize_legacy_dataset +def test_read_directory_s3fs(s3_example_s3fs, use_legacy_dataset): + fs, directory = s3_example_s3fs + path = directory + "/test.parquet" + table = pa.table({"a": [1, 2, 3]}) + _write_table(table, path, filesystem=fs) + + result = _read_table( + directory, filesystem=fs, use_legacy_dataset=use_legacy_dataset + ) + assert result.equals(table) + + +@pytest.mark.pandas +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_read_partitioned_directory_s3fs_wrapper( + 
+    s3_example_s3fs, use_legacy_dataset
+):
+    from pyarrow.filesystem import S3FSWrapper
+    import s3fs
+
+    if LooseVersion(s3fs.__version__) >= LooseVersion("0.5"):
+        pytest.skip("S3FSWrapper no longer working for s3fs 0.5+")
+
+    fs, path = s3_example_s3fs
+    with pytest.warns(DeprecationWarning):
+        wrapper = S3FSWrapper(fs)
+    _partition_test_for_filesystem(wrapper, path)
+
+    # Check that we can auto-wrap
+    dataset = pq.ParquetDataset(
+        path, filesystem=fs, use_legacy_dataset=use_legacy_dataset
+    )
+    dataset.read()
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_read_partitioned_directory_s3fs(s3_example_s3fs, use_legacy_dataset):
+    fs, path = s3_example_s3fs
+    _partition_test_for_filesystem(
+        fs, path, use_legacy_dataset=use_legacy_dataset
+    )
+
+
+def _partition_test_for_filesystem(fs, base_path, use_legacy_dataset=True):
+    foo_keys = [0, 1]
+    bar_keys = ['a', 'b', 'c']
+    partition_spec = [
+        ['foo', foo_keys],
+        ['bar', bar_keys]
+    ]
+    N = 30
+
+    df = pd.DataFrame({
+        'index': np.arange(N),
+        'foo': np.array(foo_keys, dtype='i4').repeat(15),
+        'bar': np.tile(np.tile(np.array(bar_keys, dtype=object), 5), 2),
+        'values': np.random.randn(N)
+    }, columns=['index', 'foo', 'bar', 'values'])
+
+    _generate_partition_directories(fs, base_path, partition_spec, df)
+
+    dataset = pq.ParquetDataset(
+        base_path, filesystem=fs, use_legacy_dataset=use_legacy_dataset)
+    table = dataset.read()
+    result_df = (table.to_pandas()
+                 .sort_values(by='index')
+                 .reset_index(drop=True))
+
+    expected_df = (df.sort_values(by='index')
+                   .reset_index(drop=True)
+                   .reindex(columns=result_df.columns))
+
+    expected_df['foo'] = pd.Categorical(df['foo'], categories=foo_keys)
+    expected_df['bar'] = pd.Categorical(df['bar'], categories=bar_keys)
+
+    assert (result_df.columns == ['index', 'values', 'foo', 'bar']).all()
+
+    tm.assert_frame_equal(result_df, expected_df)
+
+
+def _generate_partition_directories(fs, base_dir, partition_spec, df):
+    # partition_spec : list of lists, e.g.
+    #     [['foo', [0, 1, 2]],
+    #      ['bar', ['a', 'b', 'c']]]
+    # df : a pandas DataFrame; for every partition, the matching rows are
+    #      written to that partition's directory as a pyarrow.Table
+    DEPTH = len(partition_spec)
+
+    pathsep = getattr(fs, "pathsep", getattr(fs, "sep", "/"))
+
+    def _visit_level(base_dir, level, part_keys):
+        name, values = partition_spec[level]
+        for value in values:
+            this_part_keys = part_keys + [(name, value)]
+
+            level_dir = pathsep.join([
+                str(base_dir),
+                '{}={}'.format(name, value)
+            ])
+            fs.mkdir(level_dir)
+
+            if level == DEPTH - 1:
+                # Generate example data
+                file_path = pathsep.join([level_dir, guid()])
+                filtered_df = _filter_partition(df, this_part_keys)
+                part_table = pa.Table.from_pandas(filtered_df)
+                with fs.open(file_path, 'wb') as f:
+                    _write_table(part_table, f)
+                assert fs.exists(file_path)
+
+                file_success = pathsep.join([level_dir, '_SUCCESS'])
+                with fs.open(file_success, 'wb') as f:
+                    pass
+            else:
+                _visit_level(level_dir, level + 1, this_part_keys)
+                file_success = pathsep.join([level_dir, '_SUCCESS'])
+                with fs.open(file_success, 'wb') as f:
+                    pass
+
+    _visit_level(base_dir, 0, [])
+
+
+def _test_read_common_metadata_files(fs, base_path):
+    N = 100
+    df = pd.DataFrame({
+        'index': np.arange(N),
+        'values': np.random.randn(N)
+    }, columns=['index', 'values'])
+
+    base_path = str(base_path)
+    data_path = os.path.join(base_path, 'data.parquet')
+
+    table = pa.Table.from_pandas(df)
+
+    with fs.open(data_path, 'wb') as f:
+        _write_table(table, f)
+
+    metadata_path = os.path.join(base_path, '_common_metadata')
+    with fs.open(metadata_path, 'wb') as f:
+        pq.write_metadata(table.schema, f)
+
+    dataset = pq.ParquetDataset(base_path, filesystem=fs)
+    assert dataset.common_metadata_path == str(metadata_path)
+
+    with fs.open(data_path) as f:
+        common_schema = pq.read_metadata(f).schema
+    assert dataset.schema.equals(common_schema)
+
+    # handle list of one directory
+    dataset2 = pq.ParquetDataset([base_path], filesystem=fs)
+    assert dataset2.schema.equals(dataset.schema)
+
+
+@pytest.mark.pandas
+def test_read_common_metadata_files(tempdir):
+    fs = LocalFileSystem._get_instance()
+    _test_read_common_metadata_files(fs, tempdir)
+
+
+@pytest.mark.pandas
+def test_read_metadata_files(tempdir):
+    fs = LocalFileSystem._get_instance()
+
+    N = 100
+    df = pd.DataFrame({
+        'index': np.arange(N),
+        'values': np.random.randn(N)
+    }, columns=['index', 'values'])
+
+    data_path = tempdir / 'data.parquet'
+
+    table = pa.Table.from_pandas(df)
+
+    with fs.open(data_path, 'wb') as f:
+        _write_table(table, f)
+
+    metadata_path = tempdir / '_metadata'
+    with fs.open(metadata_path, 'wb') as f:
+        pq.write_metadata(table.schema, f)
+
+    dataset = pq.ParquetDataset(tempdir, filesystem=fs)
+    assert dataset.metadata_path == str(metadata_path)
+
+    with fs.open(data_path) as f:
+        metadata_schema = pq.read_metadata(f).schema
+    assert dataset.schema.equals(metadata_schema)
+
+
+@pytest.mark.pandas
+def test_read_schema(tempdir):
+    N = 100
+    df = pd.DataFrame({
+        'index': np.arange(N),
+        'values': np.random.randn(N)
+    }, columns=['index', 'values'])
+
+    data_path = tempdir / 'test.parquet'
+
+    table = pa.Table.from_pandas(df)
+    _write_table(table, data_path)
+
+    read1 = pq.read_schema(data_path)
+    read2 = pq.read_schema(data_path, memory_map=True)
+    assert table.schema.equals(read1)
+    assert table.schema.equals(read2)
+
+    assert table.schema.metadata[b'pandas'] == read1.metadata[b'pandas']
+
+
+def _filter_partition(df, part_keys):
+    predicate = np.ones(len(df), dtype=bool)
+
+    to_drop = []
+    for name, value in part_keys:
+        to_drop.append(name)
+
+        # to avoid
pandas warning + if isinstance(value, (datetime.date, datetime.datetime)): + value = pd.Timestamp(value) + + predicate &= df[name] == value + + return df[predicate].drop(to_drop, axis=1) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_multiple_files(tempdir, use_legacy_dataset): + nfiles = 10 + size = 5 + + dirpath = tempdir / guid() + dirpath.mkdir() + + test_data = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(size, seed=i) + + # Hack so that we don't have a dtype cast in v1 files + df['uint32'] = df['uint32'].astype(np.int64) + + path = dirpath / '{}.parquet'.format(i) + + table = pa.Table.from_pandas(df) + _write_table(table, path) + + test_data.append(table) + paths.append(path) + + # Write a _SUCCESS.crc file + (dirpath / '_SUCCESS.crc').touch() + + def read_multiple_files(paths, columns=None, use_threads=True, **kwargs): + dataset = pq.ParquetDataset( + paths, use_legacy_dataset=use_legacy_dataset, **kwargs) + return dataset.read(columns=columns, use_threads=use_threads) + + result = read_multiple_files(paths) + expected = pa.concat_tables(test_data) + + assert result.equals(expected) + + # Read with provided metadata + # TODO(dataset) specifying metadata not yet supported + metadata = pq.read_metadata(paths[0]) + if use_legacy_dataset: + result2 = read_multiple_files(paths, metadata=metadata) + assert result2.equals(expected) + + result3 = pq.ParquetDataset(dirpath, schema=metadata.schema).read() + assert result3.equals(expected) + else: + with pytest.raises(ValueError, match="no longer supported"): + pq.read_table(paths, metadata=metadata, use_legacy_dataset=False) + + # Read column subset + to_read = [0, 2, 6, result.num_columns - 1] + + col_names = [result.field(i).name for i in to_read] + out = pq.read_table( + dirpath, columns=col_names, use_legacy_dataset=use_legacy_dataset + ) + expected = pa.Table.from_arrays([result.column(i) for i in to_read], + names=col_names, + metadata=result.schema.metadata) + assert out.equals(expected) + + # Read with multiple threads + pq.read_table( + dirpath, use_threads=True, use_legacy_dataset=use_legacy_dataset + ) + + # Test failure modes with non-uniform metadata + bad_apple = _test_dataframe(size, seed=i).iloc[:, :4] + bad_apple_path = tempdir / '{}.parquet'.format(guid()) + + t = pa.Table.from_pandas(bad_apple) + _write_table(t, bad_apple_path) + + if not use_legacy_dataset: + # TODO(dataset) Dataset API skips bad files + return + + bad_meta = pq.read_metadata(bad_apple_path) + + with pytest.raises(ValueError): + read_multiple_files(paths + [bad_apple_path]) + + with pytest.raises(ValueError): + read_multiple_files(paths, metadata=bad_meta) + + mixed_paths = [bad_apple_path, paths[0]] + + with pytest.raises(ValueError): + read_multiple_files(mixed_paths, schema=bad_meta.schema) + + with pytest.raises(ValueError): + read_multiple_files(mixed_paths) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_read_pandas(tempdir, use_legacy_dataset): + nfiles = 5 + size = 5 + + dirpath = tempdir / guid() + dirpath.mkdir() + + test_data = [] + frames = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(size, seed=i) + df.index = np.arange(i * size, (i + 1) * size) + df.index.name = 'index' + + path = dirpath / '{}.parquet'.format(i) + + table = pa.Table.from_pandas(df) + _write_table(table, path) + test_data.append(table) + frames.append(df) + paths.append(path) + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + columns = ['uint8', 'strings'] + 
result = dataset.read_pandas(columns=columns).to_pandas() + expected = pd.concat([x[columns] for x in frames]) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_memory_map(tempdir, use_legacy_dataset): + # ARROW-2627: Check that we can use ParquetDataset with memory-mapping + dirpath = tempdir / guid() + dirpath.mkdir() + + df = _test_dataframe(10, seed=0) + path = dirpath / '{}.parquet'.format(0) + table = pa.Table.from_pandas(df) + _write_table(table, path, version='2.0') + + dataset = pq.ParquetDataset( + dirpath, memory_map=True, use_legacy_dataset=use_legacy_dataset) + assert dataset.read().equals(table) + if use_legacy_dataset: + assert dataset.pieces[0].read().equals(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_enable_buffered_stream(tempdir, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + df = _test_dataframe(10, seed=0) + path = dirpath / '{}.parquet'.format(0) + table = pa.Table.from_pandas(df) + _write_table(table, path, version='2.0') + + with pytest.raises(ValueError): + pq.ParquetDataset( + dirpath, buffer_size=-64, + use_legacy_dataset=use_legacy_dataset) + + for buffer_size in [128, 1024]: + dataset = pq.ParquetDataset( + dirpath, buffer_size=buffer_size, + use_legacy_dataset=use_legacy_dataset) + assert dataset.read().equals(table) + + +@pytest.mark.pandas +@pytest.mark.parametrize('preserve_index', [True, False, None]) +def test_dataset_read_pandas_common_metadata(tempdir, preserve_index): + # ARROW-1103 + nfiles = 5 + size = 5 + + dirpath = tempdir / guid() + dirpath.mkdir() + + test_data = [] + frames = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(size, seed=i) + df.index = pd.Index(np.arange(i * size, (i + 1) * size), name='index') + + path = dirpath / '{}.parquet'.format(i) + + table = pa.Table.from_pandas(df, preserve_index=preserve_index) + + # Obliterate metadata + table = table.replace_schema_metadata(None) + assert table.schema.metadata is None + + _write_table(table, path) + test_data.append(table) + frames.append(df) + paths.append(path) + + # Write _metadata common file + table_for_metadata = pa.Table.from_pandas( + df, preserve_index=preserve_index + ) + pq.write_metadata(table_for_metadata.schema, dirpath / '_metadata') + + dataset = pq.ParquetDataset(dirpath) + columns = ['uint8', 'strings'] + result = dataset.read_pandas(columns=columns).to_pandas() + expected = pd.concat([x[columns] for x in frames]) + expected.index.name = ( + df.index.name if preserve_index is not False else None) + tm.assert_frame_equal(result, expected) + + +def _make_example_multifile_dataset(base_path, nfiles=10, file_nrows=5): + test_data = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(file_nrows, seed=i) + path = base_path / '{}.parquet'.format(i) + + test_data.append(_write_table(df, path)) + paths.append(path) + return paths + + +def _assert_dataset_paths(dataset, paths, use_legacy_dataset): + if use_legacy_dataset: + assert set(map(str, paths)) == {x.path for x in dataset.pieces} + else: + paths = [str(path.as_posix()) for path in paths] + assert set(paths) == set(dataset._dataset.files) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.parametrize('dir_prefix', ['_', '.']) +def test_ignore_private_directories(tempdir, dir_prefix, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + # private directory 
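+    # (Directory names starting with '_' or '.' are treated as private and
+    # skipped during dataset discovery, so the '{prefix}staging' directory
+    # created below must not show up among the discovered files/pieces.)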
+ (dirpath / '{}staging'.format(dir_prefix)).mkdir() + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_ignore_hidden_files_dot(tempdir, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + with (dirpath / '.DS_Store').open('wb') as f: + f.write(b'gibberish') + + with (dirpath / '.private').open('wb') as f: + f.write(b'gibberish') + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_ignore_hidden_files_underscore(tempdir, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + with (dirpath / '_committed_123').open('wb') as f: + f.write(b'abcd') + + with (dirpath / '_started_321').open('wb') as f: + f.write(b'abcd') + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.parametrize('dir_prefix', ['_', '.']) +def test_ignore_no_private_directories_in_base_path( + tempdir, dir_prefix, use_legacy_dataset +): + # ARROW-8427 - don't ignore explicitly listed files if parent directory + # is a private directory + dirpath = tempdir / "{0}data".format(dir_prefix) / guid() + dirpath.mkdir(parents=True) + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + dataset = pq.ParquetDataset(paths, use_legacy_dataset=use_legacy_dataset) + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + # ARROW-9644 - don't ignore full directory with underscore in base path + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset_fixed +def test_ignore_custom_prefixes(tempdir, use_legacy_dataset): + # ARROW-9573 - allow override of default ignore_prefixes + part = ["xxx"] * 3 + ["yyy"] * 3 + table = pa.table([ + pa.array(range(len(part))), + pa.array(part).dictionary_encode(), + ], names=['index', '_part']) + + # TODO use_legacy_dataset ARROW-10247 + pq.write_to_dataset(table, str(tempdir), partition_cols=['_part']) + + private_duplicate = tempdir / '_private_duplicate' + private_duplicate.mkdir() + pq.write_to_dataset(table, str(private_duplicate), + partition_cols=['_part']) + + read = pq.read_table( + tempdir, use_legacy_dataset=use_legacy_dataset, + ignore_prefixes=['_private']) + + assert read.equals(table) + + +@parametrize_legacy_dataset_fixed +def test_empty_directory(tempdir, use_legacy_dataset): + # ARROW-5310 - reading empty directory + # fails with legacy implementation + empty_dir = tempdir / 'dataset' + empty_dir.mkdir() + + dataset = pq.ParquetDataset( + empty_dir, use_legacy_dataset=use_legacy_dataset) + result = dataset.read() + assert result.num_rows == 0 + assert result.num_columns == 0 + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_multiindex_duplicate_values(tempdir, use_legacy_dataset): + num_rows = 3 + numbers = list(range(num_rows)) + index = pd.MultiIndex.from_arrays( + [['foo', 'foo', 'bar'], numbers], + names=['foobar', 'some_numbers'], + ) + + df = 
pd.DataFrame({'numbers': numbers}, index=index) + table = pa.Table.from_pandas(df) + + filename = tempdir / 'dup_multi_index_levels.parquet' + + _write_table(table, filename) + result_table = _read_table(filename, use_legacy_dataset=use_legacy_dataset) + assert table.equals(result_table) + + result_df = result_table.to_pandas() + tm.assert_frame_equal(result_df, df) + + +@pytest.mark.pandas +def test_write_error_deletes_incomplete_file(tempdir): + # ARROW-1285 + df = pd.DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, freq='ns')}) + + pdf = pa.Table.from_pandas(df) + + filename = tempdir / 'tmp_file' + try: + _write_table(pdf, filename) + except pa.ArrowException: + pass + + assert not filename.exists() + + +@pytest.mark.pandas +def test_noncoerced_nanoseconds_written_without_exception(tempdir): + # ARROW-1957: the Parquet version 2.0 writer preserves Arrow + # nanosecond timestamps by default + n = 9 + df = pd.DataFrame({'x': range(n)}, + index=pd.date_range('2017-01-01', freq='1n', periods=n)) + tb = pa.Table.from_pandas(df) + + filename = tempdir / 'written.parquet' + try: + pq.write_table(tb, filename, version='2.0') + except Exception: + pass + assert filename.exists() + + recovered_table = pq.read_table(filename) + assert tb.equals(recovered_table) + + # Loss of data through coercion (without explicit override) still an error + filename = tempdir / 'not_written.parquet' + with pytest.raises(ValueError): + pq.write_table(tb, filename, coerce_timestamps='ms', version='2.0') + + +@parametrize_legacy_dataset +def test_read_non_existent_file(tempdir, use_legacy_dataset): + path = 'non-existent-file.parquet' + try: + pq.read_table(path, use_legacy_dataset=use_legacy_dataset) + except Exception as e: + assert path in e.args[0] + + +@parametrize_legacy_dataset +def test_read_table_doesnt_warn(datadir, use_legacy_dataset): + with pytest.warns(None) as record: + pq.read_table(datadir / 'v0.7.1.parquet', + use_legacy_dataset=use_legacy_dataset) + + assert len(record) == 0 + + +def _test_write_to_dataset_with_partitions(base_path, + use_legacy_dataset=True, + filesystem=None, + schema=None, + index_name=None): + # ARROW-1400 + output_df = pd.DataFrame({'group1': list('aaabbbbccc'), + 'group2': list('eefeffgeee'), + 'num': list(range(10)), + 'nan': [np.nan] * 10, + 'date': np.arange('2017-01-01', '2017-01-11', + dtype='datetime64[D]')}) + cols = output_df.columns.tolist() + partition_by = ['group1', 'group2'] + output_table = pa.Table.from_pandas(output_df, schema=schema, safe=False, + preserve_index=False) + pq.write_to_dataset(output_table, base_path, partition_by, + filesystem=filesystem, + use_legacy_dataset=use_legacy_dataset) + + metadata_path = os.path.join(str(base_path), '_common_metadata') + + if filesystem is not None: + with filesystem.open(metadata_path, 'wb') as f: + pq.write_metadata(output_table.schema, f) + else: + pq.write_metadata(output_table.schema, metadata_path) + + # ARROW-2891: Ensure the output_schema is preserved when writing a + # partitioned dataset + dataset = pq.ParquetDataset(base_path, + filesystem=filesystem, + validate_schema=True, + use_legacy_dataset=use_legacy_dataset) + # ARROW-2209: Ensure the dataset schema also includes the partition columns + if 
use_legacy_dataset: + dataset_cols = set(dataset.schema.to_arrow_schema().names) + else: + # NB schema property is an arrow and not parquet schema + dataset_cols = set(dataset.schema.names) + + assert dataset_cols == set(output_table.schema.names) + + input_table = dataset.read() + input_df = input_table.to_pandas() + + # Read data back in and compare with original DataFrame + # Partitioned columns added to the end of the DataFrame when read + input_df_cols = input_df.columns.tolist() + assert partition_by == input_df_cols[-1 * len(partition_by):] + + input_df = input_df[cols] + # Partitioned columns become 'categorical' dtypes + for col in partition_by: + output_df[col] = output_df[col].astype('category') + tm.assert_frame_equal(output_df, input_df) + + +def _test_write_to_dataset_no_partitions(base_path, + use_legacy_dataset=True, + filesystem=None): + # ARROW-1400 + output_df = pd.DataFrame({'group1': list('aaabbbbccc'), + 'group2': list('eefeffgeee'), + 'num': list(range(10)), + 'date': np.arange('2017-01-01', '2017-01-11', + dtype='datetime64[D]')}) + cols = output_df.columns.tolist() + output_table = pa.Table.from_pandas(output_df) + + if filesystem is None: + filesystem = LocalFileSystem._get_instance() + + # Without partitions, append files to root_path + n = 5 + for i in range(n): + pq.write_to_dataset(output_table, base_path, + filesystem=filesystem) + output_files = [file for file in filesystem.ls(str(base_path)) + if file.endswith(".parquet")] + assert len(output_files) == n + + # Deduplicated incoming DataFrame should match + # original outgoing Dataframe + input_table = pq.ParquetDataset( + base_path, filesystem=filesystem, + use_legacy_dataset=use_legacy_dataset + ).read() + input_df = input_table.to_pandas() + input_df = input_df.drop_duplicates() + input_df = input_df[cols] + assert output_df.equals(input_df) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions(tempdir, use_legacy_dataset): + _test_write_to_dataset_with_partitions(str(tempdir), use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions_and_schema( + tempdir, use_legacy_dataset +): + schema = pa.schema([pa.field('group1', type=pa.string()), + pa.field('group2', type=pa.string()), + pa.field('num', type=pa.int64()), + pa.field('nan', type=pa.int32()), + pa.field('date', type=pa.timestamp(unit='us'))]) + _test_write_to_dataset_with_partitions( + str(tempdir), use_legacy_dataset, schema=schema) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions_and_index_name( + tempdir, use_legacy_dataset +): + _test_write_to_dataset_with_partitions( + str(tempdir), use_legacy_dataset, index_name='index_name') + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_no_partitions(tempdir, use_legacy_dataset): + _test_write_to_dataset_no_partitions(str(tempdir), use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_pathlib(tempdir, use_legacy_dataset): + _test_write_to_dataset_with_partitions( + tempdir / "test1", use_legacy_dataset) + _test_write_to_dataset_no_partitions( + tempdir / "test2", use_legacy_dataset) + + +# Those tests are failing - see ARROW-10370 +# @pytest.mark.pandas +# @pytest.mark.s3 +# @parametrize_legacy_dataset +# def test_write_to_dataset_pathlib_nonlocal( +# tempdir, s3_example_s3fs, use_legacy_dataset +# ): +# # pathlib paths are only accepted for local files +# fs, _ = s3_example_s3fs + 
+# with pytest.raises(TypeError, match="path-like objects are only allowed"): +# _test_write_to_dataset_with_partitions( +# tempdir / "test1", use_legacy_dataset, filesystem=fs) + +# with pytest.raises(TypeError, match="path-like objects are only allowed"): +# _test_write_to_dataset_no_partitions( +# tempdir / "test2", use_legacy_dataset, filesystem=fs) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions_s3fs( + s3_example_s3fs, use_legacy_dataset +): + fs, path = s3_example_s3fs + + _test_write_to_dataset_with_partitions( + path, use_legacy_dataset, filesystem=fs) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_no_partitions_s3fs( + s3_example_s3fs, use_legacy_dataset +): + fs, path = s3_example_s3fs + + _test_write_to_dataset_no_partitions( + path, use_legacy_dataset, filesystem=fs) + + +@pytest.mark.pandas +@parametrize_legacy_dataset_not_supported +def test_write_to_dataset_with_partitions_and_custom_filenames( + tempdir, use_legacy_dataset +): + output_df = pd.DataFrame({'group1': list('aaabbbbccc'), + 'group2': list('eefeffgeee'), + 'num': list(range(10)), + 'nan': [np.nan] * 10, + 'date': np.arange('2017-01-01', '2017-01-11', + dtype='datetime64[D]')}) + partition_by = ['group1', 'group2'] + output_table = pa.Table.from_pandas(output_df) + path = str(tempdir) + + def partition_filename_callback(keys): + return "{}-{}.parquet".format(*keys) + + pq.write_to_dataset(output_table, path, + partition_by, partition_filename_callback, + use_legacy_dataset=use_legacy_dataset) + + dataset = pq.ParquetDataset(path) + + # ARROW-3538: Ensure partition filenames match the given pattern + # defined in the local function partition_filename_callback + expected_basenames = [ + 'a-e.parquet', 'a-f.parquet', + 'b-e.parquet', 'b-f.parquet', + 'b-g.parquet', 'c-e.parquet' + ] + output_basenames = [os.path.basename(p.path) for p in dataset.pieces] + + assert sorted(expected_basenames) == sorted(output_basenames) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_pandas_preserve_extensiondtypes( + tempdir, use_legacy_dataset +): + # ARROW-8251 - preserve pandas extension dtypes in roundtrip + if LooseVersion(pd.__version__) < "1.0.0": + pytest.skip("__arrow_array__ added to pandas in 1.0.0") + + df = pd.DataFrame({'part': 'a', "col": [1, 2, 3]}) + df['col'] = df['col'].astype("Int64") + table = pa.table(df) + + pq.write_to_dataset( + table, str(tempdir / "case1"), partition_cols=['part'], + use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result[["col"]], df[["col"]]) + + pq.write_to_dataset( + table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result[["col"]], df[["col"]]) + + pq.write_table(table, str(tempdir / "data.parquet")) + result = pq.read_table( + str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result[["col"]], df[["col"]]) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_pandas_preserve_index(tempdir, use_legacy_dataset): + # ARROW-8251 - preserve pandas index in roundtrip + + df = pd.DataFrame({'part': ['a', 'a', 'b'], "col": [1, 2, 3]}) + df.index = pd.Index(['a', 'b', 'c'], name="idx") + table = pa.table(df) + df_cat = 
df[["col", "part"]].copy() + df_cat["part"] = df_cat["part"].astype("category") + + pq.write_to_dataset( + table, str(tempdir / "case1"), partition_cols=['part'], + use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result, df_cat) + + pq.write_to_dataset( + table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result, df) + + pq.write_table(table, str(tempdir / "data.parquet")) + result = pq.read_table( + str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.large_memory +def test_large_table_int32_overflow(): + size = np.iinfo('int32').max + 1 + + arr = np.ones(size, dtype='uint8') + + parr = pa.array(arr, type=pa.uint8()) + + table = pa.Table.from_arrays([parr], names=['one']) + f = io.BytesIO() + _write_table(table, f) + + +def _simple_table_roundtrip(table, use_legacy_dataset=False, **write_kwargs): + stream = pa.BufferOutputStream() + _write_table(table, stream, **write_kwargs) + buf = stream.getvalue() + return _read_table(buf, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.large_memory +@parametrize_legacy_dataset +def test_byte_array_exactly_2gb(use_legacy_dataset): + # Test edge case reported in ARROW-3762 + val = b'x' * (1 << 10) + + base = pa.array([val] * ((1 << 21) - 1)) + cases = [ + [b'x' * 1023], # 2^31 - 1 + [b'x' * 1024], # 2^31 + [b'x' * 1025] # 2^31 + 1 + ] + for case in cases: + values = pa.chunked_array([base, pa.array(case)]) + t = pa.table([values], names=['f0']) + result = _simple_table_roundtrip( + t, use_legacy_dataset=use_legacy_dataset, use_dictionary=False) + assert t.equals(result) + + +@pytest.mark.pandas +@pytest.mark.large_memory +@parametrize_legacy_dataset +def test_binary_array_overflow_to_chunked(use_legacy_dataset): + # ARROW-3762 + + # 2^31 + 1 bytes + values = [b'x'] + [ + b'x' * (1 << 20) + ] * 2 * (1 << 10) + df = pd.DataFrame({'byte_col': values}) + + tbl = pa.Table.from_pandas(df, preserve_index=False) + read_tbl = _simple_table_roundtrip( + tbl, use_legacy_dataset=use_legacy_dataset) + + col0_data = read_tbl[0] + assert isinstance(col0_data, pa.ChunkedArray) + + # Split up into 2GB chunks + assert col0_data.num_chunks == 2 + + assert tbl.equals(read_tbl) + + +@pytest.mark.pandas +@pytest.mark.large_memory +@parametrize_legacy_dataset +def test_list_of_binary_large_cell(use_legacy_dataset): + # ARROW-4688 + data = [] + + # TODO(wesm): handle chunked children + # 2^31 - 1 bytes in a single cell + # data.append([b'x' * (1 << 20)] * 2047 + [b'x' * ((1 << 20) - 1)]) + + # A little under 2GB in cell each containing approximately 10MB each + data.extend([[b'x' * 1000000] * 10] * 214) + + arr = pa.array(data) + table = pa.Table.from_arrays([arr], ['chunky_cells']) + read_table = _simple_table_roundtrip( + table, use_legacy_dataset=use_legacy_dataset) + assert table.equals(read_table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_index_column_name_duplicate(tempdir, use_legacy_dataset): + data = { + 'close': { + pd.Timestamp('2017-06-30 01:31:00'): 154.99958999999998, + pd.Timestamp('2017-06-30 01:32:00'): 154.99958999999998, + }, + 'time': { + pd.Timestamp('2017-06-30 01:31:00'): pd.Timestamp( + '2017-06-30 01:31:00' + ), + pd.Timestamp('2017-06-30 01:32:00'): pd.Timestamp( + 
'2017-06-30 01:32:00' + ), + } + } + path = str(tempdir / 'data.parquet') + dfx = pd.DataFrame(data).set_index('time', drop=False) + tdfx = pa.Table.from_pandas(dfx) + _write_table(tdfx, path) + arrow_table = _read_table(path, use_legacy_dataset=use_legacy_dataset) + result_df = arrow_table.to_pandas() + tm.assert_frame_equal(result_df, dfx) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_nested_convenience(tempdir, use_legacy_dataset): + # ARROW-1684 + df = pd.DataFrame({ + 'a': [[1, 2, 3], None, [4, 5], []], + 'b': [[1.], None, None, [6., 7.]], + }) + + path = str(tempdir / 'nested_convenience.parquet') + + table = pa.Table.from_pandas(df, preserve_index=False) + _write_table(table, path) + + read = pq.read_table( + path, columns=['a'], use_legacy_dataset=use_legacy_dataset) + tm.assert_frame_equal(read.to_pandas(), df[['a']]) + + read = pq.read_table( + path, columns=['a', 'b'], use_legacy_dataset=use_legacy_dataset) + tm.assert_frame_equal(read.to_pandas(), df) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_index_naming(datadir, use_legacy_dataset): + expected_string = b"""\ +carat cut color clarity depth table price x y z + 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 + 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 + 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 + 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 + 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 + 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48 + 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 + 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 + 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 + 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" + expected = pd.read_csv(io.BytesIO(expected_string), sep=r'\s{2,}', + index_col=None, header=0, engine='python') + table = _read_table( + datadir / 'v0.7.1.parquet', use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_index_multi_level_named( + datadir, use_legacy_dataset +): + expected_string = b"""\ +carat cut color clarity depth table price x y z + 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 + 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 + 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 + 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 + 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 + 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48 + 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 + 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 + 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 + 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" + expected = pd.read_csv( + io.BytesIO(expected_string), sep=r'\s{2,}', + index_col=['cut', 'color', 'clarity'], + header=0, engine='python' + ).sort_index() + + table = _read_table(datadir / 'v0.7.1.all-named-index.parquet', + use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_index_multi_level_some_named( + datadir, use_legacy_dataset +): + expected_string = b"""\ +carat cut color clarity depth table price x y z + 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 + 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 + 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 + 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 + 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 + 0.24 Very Good J VVS2 62.8 57.0 336 3.94 
3.96 2.48 + 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 + 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 + 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 + 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" + expected = pd.read_csv( + io.BytesIO(expected_string), + sep=r'\s{2,}', index_col=['cut', 'color', 'clarity'], + header=0, engine='python' + ).sort_index() + expected.index = expected.index.set_names(['cut', None, 'clarity']) + + table = _read_table(datadir / 'v0.7.1.some-named-index.parquet', + use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_column_metadata_handling( + datadir, use_legacy_dataset +): + expected = pd.DataFrame( + {'a': [1, 2, 3], 'b': [.1, .2, .3], + 'c': pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')}) + expected.index = pd.MultiIndex.from_arrays( + [['a', 'b', 'c'], + pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')], + names=['index', None]) + + path = datadir / 'v0.7.1.column-metadata-handling.parquet' + table = _read_table(path, use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + table = _read_table( + path, columns=['a'], use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected[['a']].reset_index(drop=True)) + + +# TODO(dataset) support pickling +def _make_dataset_for_pickling(tempdir, N=100): + path = tempdir / 'data.parquet' + fs = LocalFileSystem._get_instance() + + df = pd.DataFrame({ + 'index': np.arange(N), + 'values': np.random.randn(N) + }, columns=['index', 'values']) + table = pa.Table.from_pandas(df) + + num_groups = 3 + with pq.ParquetWriter(path, table.schema) as writer: + for i in range(num_groups): + writer.write_table(table) + + reader = pq.ParquetFile(path) + assert reader.metadata.num_row_groups == num_groups + + metadata_path = tempdir / '_metadata' + with fs.open(metadata_path, 'wb') as f: + pq.write_metadata(table.schema, f) + + dataset = pq.ParquetDataset(tempdir, filesystem=fs) + assert dataset.metadata_path == str(metadata_path) + + return dataset + + +def _assert_dataset_is_picklable(dataset, pickler): + def is_pickleable(obj): + return obj == pickler.loads(pickler.dumps(obj)) + + assert is_pickleable(dataset) + assert is_pickleable(dataset.metadata) + assert is_pickleable(dataset.metadata.schema) + assert len(dataset.metadata.schema) + for column in dataset.metadata.schema: + assert is_pickleable(column) + + for piece in dataset.pieces: + assert is_pickleable(piece) + metadata = piece.get_metadata() + assert metadata.num_row_groups + for i in range(metadata.num_row_groups): + assert is_pickleable(metadata.row_group(i)) + + +@pytest.mark.pandas +def test_builtin_pickle_dataset(tempdir, datadir): + import pickle + dataset = _make_dataset_for_pickling(tempdir) + _assert_dataset_is_picklable(dataset, pickler=pickle) + + +@pytest.mark.pandas +def test_cloudpickle_dataset(tempdir, datadir): + cp = pytest.importorskip('cloudpickle') + dataset = _make_dataset_for_pickling(tempdir) + _assert_dataset_is_picklable(dataset, pickler=cp) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_decimal_roundtrip(tempdir, use_legacy_dataset): + num_values = 10 + + columns = {} + for precision in range(1, 39): + for scale in range(0, precision + 1): + with util.random_seed(0): + random_decimal_values = [ + util.randdecimal(precision, scale) + for _ in 
range(num_values) + ] + column_name = ('dec_precision_{:d}_scale_{:d}' + .format(precision, scale)) + columns[column_name] = random_decimal_values + + expected = pd.DataFrame(columns) + filename = tempdir / 'decimals.parquet' + string_filename = str(filename) + table = pa.Table.from_pandas(expected) + _write_table(table, string_filename) + result_table = _read_table( + string_filename, use_legacy_dataset=use_legacy_dataset) + result = result_table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@pytest.mark.xfail( + raises=pa.ArrowException, reason='Parquet does not support negative scale' +) +def test_decimal_roundtrip_negative_scale(tempdir): + expected = pd.DataFrame({'decimal_num': [decimal.Decimal('1.23E4')]}) + filename = tempdir / 'decimals.parquet' + string_filename = str(filename) + t = pa.Table.from_pandas(expected) + _write_table(t, string_filename) + result_table = _read_table(string_filename) + result = result_table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_writer_context_obj(tempdir, use_legacy_dataset): + df = _test_dataframe(100) + df['unique_id'] = 0 + + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + out = pa.BufferOutputStream() + + with pq.ParquetWriter(out, arrow_table.schema, version='2.0') as writer: + + frames = [] + for i in range(10): + df['unique_id'] = i + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + writer.write_table(arrow_table) + + frames.append(df.copy()) + + buf = out.getvalue() + result = _read_table( + pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) + + expected = pd.concat(frames, ignore_index=True) + tm.assert_frame_equal(result.to_pandas(), expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_writer_context_obj_with_exception( + tempdir, use_legacy_dataset +): + df = _test_dataframe(100) + df['unique_id'] = 0 + + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + out = pa.BufferOutputStream() + error_text = 'Artificial Error' + + try: + with pq.ParquetWriter(out, + arrow_table.schema, + version='2.0') as writer: + + frames = [] + for i in range(10): + df['unique_id'] = i + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + writer.write_table(arrow_table) + frames.append(df.copy()) + if i == 5: + raise ValueError(error_text) + except Exception as e: + assert str(e) == error_text + + buf = out.getvalue() + result = _read_table( + pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) + + expected = pd.concat(frames, ignore_index=True) + tm.assert_frame_equal(result.to_pandas(), expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_zlib_compression_bug(use_legacy_dataset): + # ARROW-3514: "zlib deflate failed, output buffer too small" + table = pa.Table.from_arrays([pa.array(['abc', 'def'])], ['some_col']) + f = io.BytesIO() + pq.write_table(table, f, compression='gzip') + + f.seek(0) + roundtrip = pq.read_table(f, use_legacy_dataset=use_legacy_dataset) + tm.assert_frame_equal(roundtrip.to_pandas(), table.to_pandas()) + + +@pytest.mark.pandas +def test_merging_parquet_tables_with_different_pandas_metadata(tempdir): + # ARROW-3728: Merging Parquet Files - Pandas Meta in Schema Mismatch + schema = pa.schema([ + pa.field('int', pa.int16()), + pa.field('float', pa.float32()), + pa.field('string', pa.string()) + ]) + df1 = pd.DataFrame({ + 'int': np.arange(3, dtype=np.uint8), + 'float': np.arange(3, dtype=np.float32), + 
'string': ['ABBA', 'EDDA', 'ACDC'] + }) + df2 = pd.DataFrame({ + 'int': [4, 5], + 'float': [1.1, None], + 'string': [None, None] + }) + table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False) + table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False) + + assert not table1.schema.equals(table2.schema, check_metadata=True) + assert table1.schema.equals(table2.schema) + + writer = pq.ParquetWriter(tempdir / 'merged.parquet', schema=schema) + writer.write_table(table1) + writer.write_table(table2) + + +def test_empty_row_groups(tempdir): + # ARROW-3020 + table = pa.Table.from_arrays([pa.array([], type='int32')], ['f0']) + + path = tempdir / 'empty_row_groups.parquet' + + num_groups = 3 + with pq.ParquetWriter(path, table.schema) as writer: + for i in range(num_groups): + writer.write_table(table) + + reader = pq.ParquetFile(path) + assert reader.metadata.num_row_groups == num_groups + + for i in range(num_groups): + assert reader.read_row_group(i).equals(table) + + +def test_parquet_file_pass_directory_instead_of_file(tempdir): + # ARROW-7208 + path = tempdir / 'directory' + os.mkdir(str(path)) + + with pytest.raises(IOError, match="Expected file path"): + pq.ParquetFile(path) + + +@pytest.mark.pandas +@pytest.mark.parametrize("filesystem", [ + None, + LocalFileSystem._get_instance(), + fs.LocalFileSystem(), +]) +def test_parquet_writer_filesystem_local(tempdir, filesystem): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + path = str(tempdir / 'data.parquet') + + with pq.ParquetWriter( + path, table.schema, filesystem=filesystem, version='2.0' + ) as writer: + writer.write_table(table) + + result = _read_table(path).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.fixture +def s3_example_fs(s3_connection, s3_server): + from pyarrow.fs import FileSystem + + host, port, access_key, secret_key = s3_connection + uri = ( + "s3://{}:{}@mybucket/data.parquet?scheme=http&endpoint_override={}:{}" + .format(access_key, secret_key, host, port) + ) + fs, path = FileSystem.from_uri(uri) + + fs.create_dir("mybucket") + + yield fs, uri, path + + +@pytest.mark.pandas +@pytest.mark.s3 +def test_parquet_writer_filesystem_s3(s3_example_fs): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + + fs, uri, path = s3_example_fs + + with pq.ParquetWriter( + path, table.schema, filesystem=fs, version='2.0' + ) as writer: + writer.write_table(table) + + result = _read_table(uri).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.s3 +def test_parquet_writer_filesystem_s3_uri(s3_example_fs): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + + fs, uri, path = s3_example_fs + + with pq.ParquetWriter(uri, table.schema, version='2.0') as writer: + writer.write_table(table) + + result = _read_table(path, filesystem=fs).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_parquet_writer_filesystem_s3fs(s3_example_s3fs): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + + fs, directory = s3_example_s3fs + path = directory + "/test.parquet" + + with pq.ParquetWriter( + path, table.schema, filesystem=fs, version='2.0' + ) as writer: + writer.write_table(table) + + result = _read_table(path, filesystem=fs).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_parquet_writer_filesystem_buffer_raises(): + df = _test_dataframe(100) + table = 
pa.Table.from_pandas(df, preserve_index=False)
+    filesystem = fs.LocalFileSystem()
+
+    # Should raise ValueError when filesystem is passed with file-like object
+    with pytest.raises(ValueError, match="specified path is file-like"):
+        pq.ParquetWriter(
+            pa.BufferOutputStream(), table.schema, filesystem=filesystem
+        )
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_parquet_writer_with_caller_provided_filesystem(use_legacy_dataset):
+    out = pa.BufferOutputStream()
+
+    class CustomFS(FileSystem):
+        def __init__(self):
+            self.path = None
+            self.mode = None
+
+        def open(self, path, mode='rb'):
+            self.path = path
+            self.mode = mode
+            return out
+
+    fs = CustomFS()
+    fname = 'expected_fname.parquet'
+    df = _test_dataframe(100)
+    table = pa.Table.from_pandas(df, preserve_index=False)
+
+    with pq.ParquetWriter(fname, table.schema, filesystem=fs, version='2.0') \
+            as writer:
+        writer.write_table(table)
+
+    assert fs.path == fname
+    assert fs.mode == 'wb'
+    assert out.closed
+
+    buf = out.getvalue()
+    table_read = _read_table(
+        pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
+    df_read = table_read.to_pandas()
+    tm.assert_frame_equal(df_read, df)
+
+    # Should raise ValueError when filesystem is passed with file-like object
+    with pytest.raises(ValueError) as err_info:
+        pq.ParquetWriter(pa.BufferOutputStream(), table.schema, filesystem=fs)
+    expected_msg = ("filesystem passed but where is file-like, so"
+                    " there is nothing to open with filesystem.")
+    assert str(err_info.value) == expected_msg
+
+
+def test_writing_empty_lists():
+    # ARROW-2591: [Python] Segmentation fault issue in pq.write_table
+    arr1 = pa.array([[], []], pa.list_(pa.int32()))
+    table = pa.Table.from_arrays([arr1], ['list(int32)'])
+    _check_roundtrip(table)
+
+
+@parametrize_legacy_dataset
+def test_write_nested_zero_length_array_chunk_failure(use_legacy_dataset):
+    # Bug report in ARROW-3792
+    cols = OrderedDict(
+        int32=pa.int32(),
+        list_string=pa.list_(pa.string())
+    )
+    data = [[], [OrderedDict(int32=1, list_string=('G',)), ]]
+
+    # This produces a table with a column like
+    # <Column name='list_string' type=ListType(list<item: string>)>
+    # [
+    #   [],
+    #   [
+    #     [
+    #       "G"
+    #     ]
+    #   ]
+    # ]
+    #
+    # Each column is a ChunkedArray with 2 elements
+    my_arrays = [pa.array(batch, type=pa.struct(cols)).flatten()
+                 for batch in data]
+    my_batches = [pa.RecordBatch.from_arrays(batch, schema=pa.schema(cols))
+                  for batch in my_arrays]
+    tbl = pa.Table.from_batches(my_batches, pa.schema(cols))
+    _check_roundtrip(tbl, use_legacy_dataset=use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_partitioned_dataset(tempdir, use_legacy_dataset):
+    # ARROW-3208: Segmentation fault when reading a partitioned Parquet
+    # dataset and writing it back out to a single Parquet file
+    path = tempdir / "ARROW-3208"
+    df = pd.DataFrame({
+        'one': [-1, 10, 2.5, 100, 1000, 1, 29.2],
+        'two': [-1, 10, 2, 100, 1000, 1, 11],
+        'three': [0, 0, 0, 0, 0, 0, 0]
+    })
+    table = pa.Table.from_pandas(df)
+    pq.write_to_dataset(table, root_path=str(path),
+                        partition_cols=['one', 'two'])
+    table = pq.ParquetDataset(
+        path, use_legacy_dataset=use_legacy_dataset).read()
+    pq.write_table(table, path / "output.parquet")
+
+
+def test_read_column_invalid_index():
+    table = pa.table([pa.array([4, 5]), pa.array(["foo", "bar"])],
+                     names=['ints', 'strs'])
+    bio = pa.BufferOutputStream()
+    pq.write_table(table, bio)
+    f = pq.ParquetFile(bio.getvalue())
+    assert f.reader.read_column(0).to_pylist() == [4, 5]
+    assert f.reader.read_column(1).to_pylist() == ["foo", "bar"]
+    for index in (-1, 2):
+        with
pytest.raises((ValueError, IndexError)): + f.reader.read_column(index) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_direct_read_dictionary(use_legacy_dataset): + # ARROW-3325 + repeats = 10 + nunique = 5 + + data = [ + [util.rands(10) for i in range(nunique)] * repeats, + + ] + table = pa.table(data, names=['f0']) + + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + + result = pq.read_table(pa.BufferReader(contents), + read_dictionary=['f0'], + use_legacy_dataset=use_legacy_dataset) + + # Compute dictionary-encoded subfield + expected = pa.table([table[0].dictionary_encode()], names=['f0']) + assert result.equals(expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_read_dictionary(tempdir, use_legacy_dataset): + path = tempdir / "ARROW-3325-dataset" + t1 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0']) + t2 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0']) + # TODO pass use_legacy_dataset (need to fix unique names) + pq.write_to_dataset(t1, root_path=str(path)) + pq.write_to_dataset(t2, root_path=str(path)) + + result = pq.ParquetDataset( + path, read_dictionary=['f0'], + use_legacy_dataset=use_legacy_dataset).read() + + # The order of the chunks is non-deterministic + ex_chunks = [t1[0].chunk(0).dictionary_encode(), + t2[0].chunk(0).dictionary_encode()] + + assert result[0].num_chunks == 2 + c0, c1 = result[0].chunk(0), result[0].chunk(1) + if c0.equals(ex_chunks[0]): + assert c1.equals(ex_chunks[1]) + else: + assert c0.equals(ex_chunks[1]) + assert c1.equals(ex_chunks[0]) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_direct_read_dictionary_subfield(use_legacy_dataset): + repeats = 10 + nunique = 5 + + data = [ + [[util.rands(10)] for i in range(nunique)] * repeats, + ] + table = pa.table(data, names=['f0']) + + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + result = pq.read_table(pa.BufferReader(contents), + read_dictionary=['f0.list.item'], + use_legacy_dataset=use_legacy_dataset) + + arr = pa.array(data[0]) + values_as_dict = arr.values.dictionary_encode() + + inner_indices = values_as_dict.indices.cast('int32') + new_values = pa.DictionaryArray.from_arrays(inner_indices, + values_as_dict.dictionary) + + offsets = pa.array(range(51), type='int32') + expected_arr = pa.ListArray.from_arrays(offsets, new_values) + expected = pa.table([expected_arr], names=['f0']) + + assert result.equals(expected) + assert result[0].num_chunks == 1 + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_metadata(tempdir, use_legacy_dataset): + path = tempdir / "ARROW-1983-dataset" + + # create and write a test dataset + df = pd.DataFrame({ + 'one': [1, 2, 3], + 'two': [-1, -2, -3], + 'three': [[1, 2], [2, 3], [3, 4]], + }) + table = pa.Table.from_pandas(df) + + metadata_list = [] + if not use_legacy_dataset: + # New dataset implementation does not yet support metadata_collector + with pytest.raises(ValueError): + pq.write_to_dataset(table, root_path=str(path), + partition_cols=['one', 'two'], + metadata_collector=metadata_list, + use_legacy_dataset=use_legacy_dataset) + return + pq.write_to_dataset(table, root_path=str(path), + partition_cols=['one', 'two'], + metadata_collector=metadata_list, + use_legacy_dataset=use_legacy_dataset) + + # open the dataset and collect metadata from pieces: + dataset = pq.ParquetDataset(path) + metadata_list2 = [p.get_metadata() for p in dataset.pieces] + + 
collected_paths = [] + # compare metadata list content: + assert len(metadata_list) == len(metadata_list2) + for md, md2 in zip(metadata_list, metadata_list2): + d = md.to_dict() + d2 = md2.to_dict() + # serialized_size is initialized in the reader: + assert d.pop('serialized_size') == 0 + assert d2.pop('serialized_size') > 0 + # file_path is different (not set for in-file metadata) + assert d["row_groups"][0]["columns"][0]["file_path"] != "" + assert d2["row_groups"][0]["columns"][0]["file_path"] == "" + # collect file paths to check afterwards, ignore here + collected_paths.append(d["row_groups"][0]["columns"][0]["file_path"]) + d["row_groups"][0]["columns"][0]["file_path"] = "" + assert d == d2 + + # ARROW-8244 - check the file paths in the collected metadata + n_root = len(path.parts) + file_paths = ["/".join(p.parts[n_root:]) for p in path.rglob("*.parquet")] + assert sorted(collected_paths) == sorted(file_paths) + + # writing to single file (not partitioned) + metadata_list = [] + pq.write_to_dataset(pa.table({'a': [1, 2, 3]}), root_path=str(path), + metadata_collector=metadata_list) + + # compare metadata content + file_paths = list(path.glob("*.parquet")) + assert len(file_paths) == 1 + file_path = file_paths[0] + file_metadata = pq.read_metadata(file_path) + d1 = metadata_list[0].to_dict() + d2 = file_metadata.to_dict() + # serialized_size is initialized in the reader: + assert d1.pop('serialized_size') == 0 + assert d2.pop('serialized_size') > 0 + # file_path is different (not set for in-file metadata) + assert d1["row_groups"][0]["columns"][0]["file_path"] == file_path.name + assert d2["row_groups"][0]["columns"][0]["file_path"] == "" + d1["row_groups"][0]["columns"][0]["file_path"] = "" + assert d1 == d2 + + +@parametrize_legacy_dataset +def test_parquet_file_too_small(tempdir, use_legacy_dataset): + path = str(tempdir / "test.parquet") + # TODO(dataset) with datasets API it raises OSError instead + with pytest.raises((pa.ArrowInvalid, OSError), + match='size is 0 bytes'): + with open(path, 'wb') as f: + pass + pq.read_table(path, use_legacy_dataset=use_legacy_dataset) + + with pytest.raises((pa.ArrowInvalid, OSError), + match='size is 4 bytes'): + with open(path, 'wb') as f: + f.write(b'ffff') + pq.read_table(path, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_categorical_index_survives_roundtrip(use_legacy_dataset): + # ARROW-3652, addressed by ARROW-3246 + df = pd.DataFrame([['a', 'b'], ['c', 'd']], columns=['c1', 'c2']) + df['c1'] = df['c1'].astype('category') + df = df.set_index(['c1']) + + table = pa.Table.from_pandas(df) + bos = pa.BufferOutputStream() + pq.write_table(table, bos) + ref_df = pq.read_pandas( + bos.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas() + assert isinstance(ref_df.index, pd.CategoricalIndex) + assert ref_df.index.equals(df.index) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_categorical_order_survives_roundtrip(use_legacy_dataset): + # ARROW-6302 + df = pd.DataFrame({"a": pd.Categorical( + ["a", "b", "c", "a"], categories=["b", "c", "d"], ordered=True)}) + + table = pa.Table.from_pandas(df) + bos = pa.BufferOutputStream() + pq.write_table(table, bos) + + contents = bos.getvalue() + result = pq.read_pandas( + contents, use_legacy_dataset=use_legacy_dataset).to_pandas() + + tm.assert_frame_equal(result, df) + + +def _simple_table_write_read(table, use_legacy_dataset): + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + 
return pq.read_table( + pa.BufferReader(contents), use_legacy_dataset=use_legacy_dataset + ) + + +@parametrize_legacy_dataset +def test_dictionary_array_automatically_read(use_legacy_dataset): + # ARROW-3246 + + # Make a large dictionary, a little over 4MB of data + dict_length = 4000 + dict_values = pa.array([('x' * 1000 + '_{}'.format(i)) + for i in range(dict_length)]) + + num_chunks = 10 + chunk_size = 100 + chunks = [] + for i in range(num_chunks): + indices = np.random.randint(0, dict_length, + size=chunk_size).astype(np.int32) + chunks.append(pa.DictionaryArray.from_arrays(pa.array(indices), + dict_values)) + + table = pa.table([pa.chunked_array(chunks)], names=['f0']) + result = _simple_table_write_read(table, use_legacy_dataset) + + assert result.equals(table) + + # The only key in the metadata was the Arrow schema key + assert result.schema.metadata is None + + +def test_field_id_metadata(): + # ARROW-7080 + table = pa.table([pa.array([1], type='int32'), + pa.array([[]], type=pa.list_(pa.int32())), + pa.array([b'boo'], type='binary')], + ['f0', 'f1', 'f2']) + + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + + pf = pq.ParquetFile(pa.BufferReader(contents)) + schema = pf.schema_arrow + + # Expected Parquet schema for reference + # + # required group field_id=0 schema { + # optional int32 field_id=1 f0; + # optional group field_id=2 f1 (List) { + # repeated group field_id=3 list { + # optional int32 field_id=4 item; + # } + # } + # optional binary field_id=5 f2; + # } + + field_name = b'PARQUET:field_id' + assert schema[0].metadata[field_name] == b'1' + + list_field = schema[1] + assert list_field.metadata[field_name] == b'2' + + list_item_field = list_field.type.value_field + assert list_item_field.metadata[field_name] == b'4' + + assert schema[2].metadata[field_name] == b'5' + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_categorical_na_type_row_groups(use_legacy_dataset): + # ARROW-5085 + df = pd.DataFrame({"col": [None] * 100, "int": [1.0] * 100}) + df_category = df.astype({"col": "category", "int": "category"}) + table = pa.Table.from_pandas(df) + table_cat = pa.Table.from_pandas(df_category) + buf = pa.BufferOutputStream() + + # it works + pq.write_table(table_cat, buf, version="2.0", chunk_size=10) + result = pq.read_table( + buf.getvalue(), use_legacy_dataset=use_legacy_dataset) + + # Result is non-categorical + assert result[0].equals(table[0]) + assert result[1].equals(table[1]) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_categorical_roundtrip(use_legacy_dataset): + # ARROW-5480, this was enabled by ARROW-3246 + + # Have one of the categories unobserved and include a null (-1) + codes = np.array([2, 0, 0, 2, 0, -1, 2], dtype='int32') + categories = ['foo', 'bar', 'baz'] + df = pd.DataFrame({'x': pd.Categorical.from_codes( + codes, categories=categories)}) + + buf = pa.BufferOutputStream() + pq.write_table(pa.table(df), buf) + + result = pq.read_table( + buf.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas() + assert result.x.dtype == 'category' + assert (result.x.cat.categories == categories).all() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_multi_dataset_metadata(tempdir): + filenames = ["ARROW-1983-dataset.0", "ARROW-1983-dataset.1"] + metapath = str(tempdir / "_metadata") + + # create a test dataset + df = pd.DataFrame({ + 'one': [1, 2, 3], + 'two': [-1, -2, -3], + 'three': [[1, 2], [2, 3], [3, 4]], + }) + table = pa.Table.from_pandas(df) + 
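# [Editor's aside, not part of the patch: the loop below demonstrates the
# standard "_metadata" sidecar convention: per-file FileMetaData is captured
# with metadata_collector, stamped with its relative path via set_file_path(),
# merged with append_row_groups(), and written out with write_metadata_file().
# A reader can then do, for example:
#
#     meta = pq.read_metadata(str(tempdir / "_metadata"))
#
# to plan scans over all member files without opening each footer.]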
+    # write dataset twice and collect/merge metadata
+    _meta = None
+    for filename in filenames:
+        meta = []
+        pq.write_table(table, str(tempdir / filename),
+                       metadata_collector=meta)
+        meta[0].set_file_path(filename)
+        if _meta is None:
+            _meta = meta[0]
+        else:
+            _meta.append_row_groups(meta[0])
+
+    # Write merged metadata-only file
+    with open(metapath, "wb") as f:
+        _meta.write_metadata_file(f)
+
+    # Read back the metadata
+    meta = pq.read_metadata(metapath)
+    md = meta.to_dict()
+    _md = _meta.to_dict()
+    for key in _md:
+        if key != 'serialized_size':
+            assert _md[key] == md[key]
+    assert _md['num_columns'] == 3
+    assert _md['num_rows'] == 6
+    assert _md['num_row_groups'] == 2
+    assert _md['serialized_size'] == 0
+    assert md['serialized_size'] > 0
+
+
+def test_write_metadata(tempdir):
+    path = str(tempdir / "metadata")
+    schema = pa.schema([("a", "int64"), ("b", "float64")])
+
+    # write a pyarrow schema
+    pq.write_metadata(schema, path)
+    parquet_meta = pq.read_metadata(path)
+    schema_as_arrow = parquet_meta.schema.to_arrow_schema()
+    assert schema_as_arrow.equals(schema)
+
+    # ARROW-8980: Check that the ARROW:schema metadata key was removed
+    if schema_as_arrow.metadata:
+        assert b'ARROW:schema' not in schema_as_arrow.metadata
+
+    # pass through writer keyword arguments
+    for version in ["1.0", "2.0"]:
+        pq.write_metadata(schema, path, version=version)
+        parquet_meta = pq.read_metadata(path)
+        assert parquet_meta.format_version == version
+
+    # metadata_collector: list of FileMetaData objects
+    table = pa.table({'a': [1, 2], 'b': [.1, .2]}, schema=schema)
+    pq.write_table(table, tempdir / "data.parquet")
+    parquet_meta = pq.read_metadata(str(tempdir / "data.parquet"))
+    pq.write_metadata(
+        schema, path, metadata_collector=[parquet_meta, parquet_meta]
+    )
+    parquet_meta_mult = pq.read_metadata(path)
+    assert parquet_meta_mult.num_row_groups == 2
+
+    # append metadata with different schema raises an error
+    with pytest.raises(RuntimeError, match="requires equal schemas"):
+        pq.write_metadata(
+            pa.schema([("a", "int32"), ("b", "null")]),
+            path, metadata_collector=[parquet_meta, parquet_meta]
+        )
+
+
+@parametrize_legacy_dataset
+@pytest.mark.pandas
+def test_filter_before_validate_schema(tempdir, use_legacy_dataset):
+    # ARROW-4076 apply filter before schema validation
+    # to avoid checking unneeded schemas
+
+    # create partitioned dataset with mismatching schemas which would
+    # otherwise raise if validating all schemas first
+    dir1 = tempdir / 'A=0'
+    dir1.mkdir()
+    table1 = pa.Table.from_pandas(pd.DataFrame({'B': [1, 2, 3]}))
+    pq.write_table(table1, dir1 / 'data.parquet')
+
+    dir2 = tempdir / 'A=1'
+    dir2.mkdir()
+    table2 = pa.Table.from_pandas(pd.DataFrame({'B': ['a', 'b', 'c']}))
+    pq.write_table(table2, dir2 / 'data.parquet')
+
+    # read single file using filter
+    table = pq.read_table(tempdir, filters=[[('A', '==', 0)]],
+                          use_legacy_dataset=use_legacy_dataset)
+    assert table.column('B').equals(pa.chunked_array([[1, 2, 3]]))
+
+
+@pytest.mark.pandas
+@pytest.mark.fastparquet
+@pytest.mark.filterwarnings("ignore:RangeIndex:FutureWarning")
+@pytest.mark.filterwarnings("ignore:tostring:DeprecationWarning:fastparquet")
+def test_fastparquet_cross_compatibility(tempdir):
+    fp = pytest.importorskip('fastparquet')
+
+    df = pd.DataFrame(
+        {
+            "a": list("abc"),
+            "b": list(range(1, 4)),
+            "c": np.arange(4.0, 7.0, dtype="float64"),
+            "d": [True, False, True],
+            "e": pd.date_range("20130101", periods=3),
+            "f": pd.Categorical(["a", "b", "a"]),
+            # fastparquet writes list as BYTE_ARRAY JSON, so no roundtrip
+            # "g": [[1, 2], None, [1, 2, 3]],
+        }
+    )
+    table = pa.table(df)
+
+    # Arrow -> fastparquet
+    file_arrow = str(tempdir / "cross_compat_arrow.parquet")
+    pq.write_table(table, file_arrow, compression=None)
+
+    fp_file = fp.ParquetFile(file_arrow)
+    df_fp = fp_file.to_pandas()
+    tm.assert_frame_equal(df, df_fp)
+
+    # Fastparquet -> arrow
+    file_fastparquet = str(tempdir / "cross_compat_fastparquet.parquet")
+    fp.write(file_fastparquet, df)
+
+    table_fp = pq.read_pandas(file_fastparquet)
+    # for fastparquet written files, categoricals come back as strings
+    # (no arrow schema in parquet metadata)
+    df['f'] = df['f'].astype(object)
+    tm.assert_frame_equal(table_fp.to_pandas(), df)
+
+
+def test_table_large_metadata():
+    # ARROW-8694
+    my_schema = pa.schema([pa.field('f0', 'double')],
+                          metadata={'large': 'x' * 10000000})
+
+    table = pa.table([np.arange(10)], schema=my_schema)
+    _check_roundtrip(table)
+
+
+@parametrize_legacy_dataset
+@pytest.mark.parametrize('array_factory', [
+    lambda: pa.array([0, None] * 10),
+    lambda: pa.array([0, None] * 10).dictionary_encode(),
+    lambda: pa.array(["", None] * 10),
+    lambda: pa.array(["", None] * 10).dictionary_encode(),
+])
+@pytest.mark.parametrize('use_dictionary', [False, True])
+@pytest.mark.parametrize('read_dictionary', [False, True])
+def test_buffer_contents(
+        array_factory, use_dictionary, read_dictionary, use_legacy_dataset
+):
+    # Test that null values are deterministically initialized to zero
+    # after a roundtrip through Parquet.
+    # See ARROW-8006 and ARROW-8011.
+    orig_table = pa.Table.from_pydict({"col": array_factory()})
+    bio = io.BytesIO()
+    pq.write_table(orig_table, bio, use_dictionary=use_dictionary)
+    bio.seek(0)
+    read_dictionary = ['col'] if read_dictionary else None
+    table = pq.read_table(bio, use_threads=False,
+                          read_dictionary=read_dictionary,
+                          use_legacy_dataset=use_legacy_dataset)
+
+    for col in table.columns:
+        [chunk] = col.chunks
+        buf = chunk.buffers()[1]
+        assert buf.to_pybytes() == buf.size * b"\0"
+
+
+@pytest.mark.dataset
+def test_dataset_unsupported_keywords():
+
+    with pytest.raises(ValueError, match="not yet supported with the new"):
+        pq.ParquetDataset("", use_legacy_dataset=False, schema=pa.schema([]))
+
+    with pytest.raises(ValueError, match="not yet supported with the new"):
+        pq.ParquetDataset("", use_legacy_dataset=False, metadata=pa.schema([]))
+
+    with pytest.raises(ValueError, match="not yet supported with the new"):
+        pq.ParquetDataset("", use_legacy_dataset=False, validate_schema=False)
+
+    with pytest.raises(ValueError, match="not yet supported with the new"):
+        pq.ParquetDataset("", use_legacy_dataset=False, split_row_groups=True)
+
+    with pytest.raises(ValueError, match="not yet supported with the new"):
+        pq.ParquetDataset("", use_legacy_dataset=False, metadata_nthreads=4)
+
+    with pytest.raises(ValueError, match="no longer supported"):
+        pq.read_table("", use_legacy_dataset=False, metadata=pa.schema([]))
+
+
+@pytest.mark.dataset
+def test_dataset_partitioning(tempdir):
+    import pyarrow.dataset as ds
+
+    # create small dataset with directory partitioning
+    root_path = tempdir / "test_partitioning"
+    (root_path / "2012" / "10" / "01").mkdir(parents=True)
+
+    table = pa.table({'a': [1, 2, 3]})
+    pq.write_table(
+        table, str(root_path / "2012" / "10" / "01" / "data.parquet"))
+
+    # This works with new dataset API
+
+    # read_table
+    part = ds.partitioning(field_names=["year", "month", "day"])
+    result = pq.read_table(
+        str(root_path), partitioning=part, use_legacy_dataset=False)
+
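# [Editor's aside, not part of the patch: with directory partitioning, path
# segments under the root are matched positionally against the declared field
# names, so the file written above at test_partitioning/2012/10/01/data.parquet
# contributes partition columns year/month/day parsed from "2012"/"10"/"01"
# (value types are inferred during discovery). A sketch of the equivalent
# explicit-schema spelling, with illustrative types:
#
#     part = ds.partitioning(pa.schema([("year", pa.int32()),
#                                       ("month", pa.int32()),
#                                       ("day", pa.int32())]))
# ]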
assert result.column_names == ["a", "year", "month", "day"] + + result = pq.ParquetDataset( + str(root_path), partitioning=part, use_legacy_dataset=False).read() + assert result.column_names == ["a", "year", "month", "day"] + + # This raises an error for legacy dataset + with pytest.raises(ValueError): + pq.read_table( + str(root_path), partitioning=part, use_legacy_dataset=True) + + with pytest.raises(ValueError): + pq.ParquetDataset( + str(root_path), partitioning=part, use_legacy_dataset=True) + + +@pytest.mark.dataset +def test_parquet_dataset_new_filesystem(tempdir): + # Ensure we can pass new FileSystem object to ParquetDataset + # (use new implementation automatically without specifying + # use_legacy_dataset=False) + table = pa.table({'a': [1, 2, 3]}) + pq.write_table(table, tempdir / 'data.parquet') + # don't use simple LocalFileSystem (as that gets mapped to legacy one) + filesystem = fs.SubTreeFileSystem(str(tempdir), fs.LocalFileSystem()) + dataset = pq.ParquetDataset('.', filesystem=filesystem) + result = dataset.read() + assert result.equals(table) + + +def test_parquet_dataset_partitions_piece_path_with_fsspec(tempdir): + # ARROW-10462 ensure that on Windows we properly use posix-style paths + # as used by fsspec + fsspec = pytest.importorskip("fsspec") + filesystem = fsspec.filesystem('file') + table = pa.table({'a': [1, 2, 3]}) + pq.write_table(table, tempdir / 'data.parquet') + + # pass a posix-style path (using "/" also on Windows) + path = str(tempdir).replace("\\", "/") + dataset = pq.ParquetDataset(path, filesystem=filesystem) + # ensure the piece path is also posix-style + expected = path + "/data.parquet" + assert dataset.pieces[0].path == expected + + +def test_parquet_compression_roundtrip(tempdir): + # ARROW-10480: ensure even with nonstandard Parquet file naming + # conventions, writing and then reading a file works. 
In
+    # particular, ensure that we don't automatically double-compress
+    # the stream due to auto-detecting the extension in the filename
+    table = pa.table([pa.array(range(4))], names=["ints"])
+    path = tempdir / "arrow-10480.pyarrow.gz"
+    pq.write_table(table, path, compression="GZIP")
+    result = pq.read_table(path)
+    assert result.equals(table)

From 615efdb274edc2d27cfa2898d6418805af9fe79d Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Fri, 18 Dec 2020 13:05:20 -0500
Subject: [PATCH 23/38] move Identical() impl to .cc

---
 cpp/src/arrow/dataset/expression.cc         | 2 ++
 cpp/src/arrow/dataset/expression_internal.h | 2 --
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc
index 848055ac5aa..7f08788de29 100644
--- a/cpp/src/arrow/dataset/expression.cc
+++ b/cpp/src/arrow/dataset/expression.cc
@@ -267,6 +267,8 @@ bool Expression::Equals(const Expression& other) const {
   return false;
 }
 
+bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; }
+
 size_t Expression::hash() const {
   if (auto lit = literal()) {
     if (lit->is_scalar()) {
diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h
index 77e555aa7a6..c3fd49dc347 100644
--- a/cpp/src/arrow/dataset/expression_internal.h
+++ b/cpp/src/arrow/dataset/expression_internal.h
@@ -34,8 +34,6 @@ using internal::checked_cast;
 
 namespace dataset {
 
-bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; }
-
 const Expression::Call* CallNotNull(const Expression& expr) {
   auto call = expr.call();
   DCHECK_NE(call, nullptr);

From c220823a4e4c771f73f81a8fcdf10410cf22b100 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Fri, 18 Dec 2020 13:52:07 -0500
Subject: [PATCH 24/38] export ExecuteScalarExpression

---
 cpp/src/arrow/dataset/expression.h       | 1 +
 cpp/src/arrow/dataset/expression_test.cc | 3 +++
 2 files changed, 4 insertions(+)

diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h
index 4b414d4d142..e597ffb7fcd 100644
--- a/cpp/src/arrow/dataset/expression.h
+++ b/cpp/src/arrow/dataset/expression.h
@@ -193,6 +193,7 @@ Result<Expression> SimplifyWithGuarantee(Expression,
 
 /// Execute a scalar expression against the provided state and input Datum. This
 /// expression must be bound.
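// [Editor's aside, not part of the patch: a sketch of typical usage of the
// declaration below, assuming the expression factories (field_ref, literal,
// call) and Expression::Bind declared earlier in this header; error handling
// elided and names like `schema` and `batch` are illustrative.]
//
//   ARROW_ASSIGN_OR_RAISE(auto expr,
//                         call("equal", {field_ref("a"), literal(3)})
//                             .Bind(*schema));
//   ARROW_ASSIGN_OR_RAISE(Datum mask,
//                         ExecuteScalarExpression(expr, Datum(batch)));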
+ARROW_DS_EXPORT
 Result<Datum> ExecuteScalarExpression(const Expression&, const Datum& input,
                                       compute::ExecContext* = NULLPTR);
diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc
index 175e045918d..e245b0d7093 100644
--- a/cpp/src/arrow/dataset/expression_test.cc
+++ b/cpp/src/arrow/dataset/expression_test.cc
@@ -1006,6 +1006,9 @@ TEST(Expression, SerializationRoundTrips) {
 
   ExpectRoundTrips(not_(field_ref("alpha")));
 
+  ExpectRoundTrips(call("is_in", {literal(1)},
+                        compute::SetLookupOptions{ArrayFromJSON(int32(), "[1, 2, 3]")}));
+
   ExpectRoundTrips(
       call("is_in",
            {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))},

From 38389e7c1f36a81709d9006874dd55cdd9fc29ed Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Mon, 4 Jan 2021 16:09:21 -0500
Subject: [PATCH 25/38] remove test_parquet.py

---
 python/pyarrow/tests/parquet/test_dataset.py |    2 +-
 python/pyarrow/tests/test_parquet.py         | 4528 ------------------
 2 files changed, 1 insertion(+), 4529 deletions(-)
 delete mode 100644 python/pyarrow/tests/test_parquet.py

diff --git a/python/pyarrow/tests/parquet/test_dataset.py b/python/pyarrow/tests/parquet/test_dataset.py
index ac1dd279724..3307608aefa 100644
--- a/python/pyarrow/tests/parquet/test_dataset.py
+++ b/python/pyarrow/tests/parquet/test_dataset.py
@@ -509,7 +509,7 @@ def test_filters_invalid_column(tempdir, use_legacy_dataset):
 
     _generate_partition_directories(fs, base_path, partition_spec, df)
 
-    msg = "Field named 'non_existent_column' not found"
+    msg = r"No match for FieldRef.Name\(non_existent_column\)"
     with pytest.raises(ValueError, match=msg):
         pq.ParquetDataset(base_path, filesystem=fs,
                           filters=[('non_existent_column', '<', 3), ],
diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py
deleted file mode 100644
index 6fa001cc758..00000000000
--- a/python/pyarrow/tests/test_parquet.py
+++ /dev/null
@@ -1,4528 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from collections import OrderedDict
-import datetime
-import decimal
-from distutils.version import LooseVersion
-import io
-import json
-import os
-import pytest
-
-import numpy as np
-
-import pyarrow as pa
-from pyarrow.pandas_compat import _pandas_api
-from pyarrow.tests import util
-from pyarrow.util import guid
-from pyarrow.filesystem import LocalFileSystem, FileSystem
-from pyarrow import fs
-
-
-try:
-    import pyarrow.parquet as pq
-except ImportError:
-    pq = None
-
-
-try:
-    import pandas as pd
-    import pandas.testing as tm
-    from .pandas_examples import dataframe_with_arrays, dataframe_with_lists
-except ImportError:
-    pd = tm = None
-
-
-# Marks all of the tests in this module
-# Ignore these with pytest ...
-m 'not parquet' -pytestmark = pytest.mark.parquet - - -@pytest.fixture(scope='module') -def datadir(datadir): - return datadir / 'parquet' - - -parametrize_legacy_dataset = pytest.mark.parametrize( - "use_legacy_dataset", - [True, pytest.param(False, marks=pytest.mark.dataset)]) -parametrize_legacy_dataset_not_supported = pytest.mark.parametrize( - "use_legacy_dataset", [True, pytest.param(False, marks=pytest.mark.skip)]) -parametrize_legacy_dataset_fixed = pytest.mark.parametrize( - "use_legacy_dataset", [pytest.param(True, marks=pytest.mark.xfail), - pytest.param(False, marks=pytest.mark.dataset)]) - - -def _write_table(table, path, **kwargs): - # So we see the ImportError somewhere - import pyarrow.parquet as pq - - if _pandas_api.is_data_frame(table): - table = pa.Table.from_pandas(table) - - pq.write_table(table, path, **kwargs) - return table - - -def _read_table(*args, **kwargs): - table = pq.read_table(*args, **kwargs) - table.validate(full=True) - return table - - -def _roundtrip_table(table, read_table_kwargs=None, - write_table_kwargs=None, use_legacy_dataset=True): - read_table_kwargs = read_table_kwargs or {} - write_table_kwargs = write_table_kwargs or {} - - writer = pa.BufferOutputStream() - _write_table(table, writer, **write_table_kwargs) - reader = pa.BufferReader(writer.getvalue()) - return _read_table(reader, use_legacy_dataset=use_legacy_dataset, - **read_table_kwargs) - - -def _check_roundtrip(table, expected=None, read_table_kwargs=None, - use_legacy_dataset=True, **write_table_kwargs): - if expected is None: - expected = table - - read_table_kwargs = read_table_kwargs or {} - - # intentionally check twice - result = _roundtrip_table(table, read_table_kwargs=read_table_kwargs, - write_table_kwargs=write_table_kwargs, - use_legacy_dataset=use_legacy_dataset) - assert result.equals(expected) - result = _roundtrip_table(result, read_table_kwargs=read_table_kwargs, - write_table_kwargs=write_table_kwargs, - use_legacy_dataset=use_legacy_dataset) - assert result.equals(expected) - - -def _roundtrip_pandas_dataframe(df, write_kwargs, use_legacy_dataset=True): - table = pa.Table.from_pandas(df) - result = _roundtrip_table( - table, write_table_kwargs=write_kwargs, - use_legacy_dataset=use_legacy_dataset) - return result.to_pandas() - - -def test_large_binary(): - data = [b'foo', b'bar'] * 50 - for type in [pa.large_binary(), pa.large_string()]: - arr = pa.array(data, type=type) - table = pa.Table.from_arrays([arr], names=['strs']) - for use_dictionary in [False, True]: - _check_roundtrip(table, use_dictionary=use_dictionary) - - -@pytest.mark.large_memory -def test_large_binary_huge(): - s = b'xy' * 997 - data = [s] * ((1 << 33) // len(s)) - for type in [pa.large_binary(), pa.large_string()]: - arr = pa.array(data, type=type) - table = pa.Table.from_arrays([arr], names=['strs']) - for use_dictionary in [False, True]: - _check_roundtrip(table, use_dictionary=use_dictionary) - del arr, table - - -@pytest.mark.large_memory -def test_large_binary_overflow(): - s = b'x' * (1 << 31) - arr = pa.array([s], type=pa.large_binary()) - table = pa.Table.from_arrays([arr], names=['strs']) - for use_dictionary in [False, True]: - writer = pa.BufferOutputStream() - with pytest.raises( - pa.ArrowInvalid, - match="Parquet cannot store strings with size 2GB or more"): - _write_table(table, writer, use_dictionary=use_dictionary) - - -@parametrize_legacy_dataset -@pytest.mark.parametrize('dtype', [int, float]) -def test_single_pylist_column_roundtrip(tempdir, dtype, use_legacy_dataset): - 
filename = tempdir / 'single_{}_column.parquet'.format(dtype.__name__) - data = [pa.array(list(map(dtype, range(5))))] - table = pa.Table.from_arrays(data, names=['a']) - _write_table(table, filename) - table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset) - for i in range(table.num_columns): - col_written = table[i] - col_read = table_read[i] - assert table.field(i).name == table_read.field(i).name - assert col_read.num_chunks == 1 - data_written = col_written.chunk(0) - data_read = col_read.chunk(0) - assert data_written.equals(data_read) - - -def alltypes_sample(size=10000, seed=0, categorical=False): - np.random.seed(seed) - arrays = { - 'uint8': np.arange(size, dtype=np.uint8), - 'uint16': np.arange(size, dtype=np.uint16), - 'uint32': np.arange(size, dtype=np.uint32), - 'uint64': np.arange(size, dtype=np.uint64), - 'int8': np.arange(size, dtype=np.int16), - 'int16': np.arange(size, dtype=np.int16), - 'int32': np.arange(size, dtype=np.int32), - 'int64': np.arange(size, dtype=np.int64), - 'float32': np.arange(size, dtype=np.float32), - 'float64': np.arange(size, dtype=np.float64), - 'bool': np.random.randn(size) > 0, - # TODO(wesm): Test other timestamp resolutions now that arrow supports - # them - 'datetime': np.arange("2016-01-01T00:00:00.001", size, - dtype='datetime64[ms]'), - 'str': pd.Series([str(x) for x in range(size)]), - 'empty_str': [''] * size, - 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None], - 'null': [None] * size, - 'null_list': [None] * 2 + [[None] * (x % 4) for x in range(size - 2)], - } - if categorical: - arrays['str_category'] = arrays['str'].astype('category') - return pd.DataFrame(arrays) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -@pytest.mark.parametrize('chunk_size', [None, 1000]) -def test_pandas_parquet_2_0_roundtrip(tempdir, chunk_size, use_legacy_dataset): - df = alltypes_sample(size=10000, categorical=True) - - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df) - assert arrow_table.schema.pandas_metadata is not None - - _write_table(arrow_table, filename, version="2.0", - coerce_timestamps='ms', chunk_size=chunk_size) - table_read = pq.read_pandas( - filename, use_legacy_dataset=use_legacy_dataset) - assert table_read.schema.pandas_metadata is not None - - read_metadata = table_read.schema.metadata - assert arrow_table.schema.metadata == read_metadata - - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -def test_parquet_invalid_version(tempdir): - table = pa.table({'a': [1, 2, 3]}) - with pytest.raises(ValueError, match="Unsupported Parquet format version"): - _write_table(table, tempdir / 'test_version.parquet', version="2.2") - with pytest.raises(ValueError, match="Unsupported Parquet data page " + - "version"): - _write_table(table, tempdir / 'test_version.parquet', - data_page_version="2.2") - - -@parametrize_legacy_dataset -def test_set_data_page_size(use_legacy_dataset): - arr = pa.array([1, 2, 3] * 100000) - t = pa.Table.from_arrays([arr], names=['f0']) - - # 128K, 512K - page_sizes = [2 << 16, 2 << 18] - for target_page_size in page_sizes: - _check_roundtrip(t, data_page_size=target_page_size, - use_legacy_dataset=use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_chunked_table_write(use_legacy_dataset): - # ARROW-232 - tables = [] - batch = pa.RecordBatch.from_pandas(alltypes_sample(size=10)) - tables.append(pa.Table.from_batches([batch] * 3)) - df, _ = dataframe_with_lists() - batch = 
pa.RecordBatch.from_pandas(df) - tables.append(pa.Table.from_batches([batch] * 3)) - - for data_page_version in ['1.0', '2.0']: - for use_dictionary in [True, False]: - for table in tables: - _check_roundtrip( - table, version='2.0', - use_legacy_dataset=use_legacy_dataset, - data_page_version=data_page_version, - use_dictionary=use_dictionary) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_memory_map(tempdir, use_legacy_dataset): - df = alltypes_sample(size=10) - - table = pa.Table.from_pandas(df) - _check_roundtrip(table, read_table_kwargs={'memory_map': True}, - version='2.0', use_legacy_dataset=use_legacy_dataset) - - filename = str(tempdir / 'tmp_file') - with open(filename, 'wb') as f: - _write_table(table, f, version='2.0') - table_read = pq.read_pandas(filename, memory_map=True, - use_legacy_dataset=use_legacy_dataset) - assert table_read.equals(table) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_enable_buffered_stream(tempdir, use_legacy_dataset): - df = alltypes_sample(size=10) - - table = pa.Table.from_pandas(df) - _check_roundtrip(table, read_table_kwargs={'buffer_size': 1025}, - version='2.0', use_legacy_dataset=use_legacy_dataset) - - filename = str(tempdir / 'tmp_file') - with open(filename, 'wb') as f: - _write_table(table, f, version='2.0') - table_read = pq.read_pandas(filename, buffer_size=4096, - use_legacy_dataset=use_legacy_dataset) - assert table_read.equals(table) - - -@parametrize_legacy_dataset -def test_special_chars_filename(tempdir, use_legacy_dataset): - table = pa.Table.from_arrays([pa.array([42])], ["ints"]) - filename = "foo # bar" - path = tempdir / filename - assert not path.exists() - _write_table(table, str(path)) - assert path.exists() - table_read = _read_table(str(path), use_legacy_dataset=use_legacy_dataset) - assert table_read.equals(table) - - -@pytest.mark.slow -def test_file_with_over_int16_max_row_groups(): - # PARQUET-1857: Parquet encryption support introduced a INT16_MAX upper - # limit on the number of row groups, but this limit only impacts files with - # encrypted row group metadata because of the int16 row group ordinal used - # in the Parquet Thrift metadata. 
Unencrypted files are not impacted, so - # this test checks that it works (even if it isn't a good idea) - t = pa.table([list(range(40000))], names=['f0']) - _check_roundtrip(t, row_group_size=1) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_empty_table_roundtrip(use_legacy_dataset): - df = alltypes_sample(size=10) - - # Create a non-empty table to infer the types correctly, then slice to 0 - table = pa.Table.from_pandas(df) - table = pa.Table.from_arrays( - [col.chunk(0)[:0] for col in table.itercolumns()], - names=table.schema.names) - - assert table.schema.field('null').type == pa.null() - assert table.schema.field('null_list').type == pa.list_(pa.null()) - _check_roundtrip( - table, version='2.0', use_legacy_dataset=use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_empty_table_no_columns(use_legacy_dataset): - df = pd.DataFrame() - empty = pa.Table.from_pandas(df, preserve_index=False) - _check_roundtrip(empty, use_legacy_dataset=use_legacy_dataset) - - -@parametrize_legacy_dataset -def test_empty_lists_table_roundtrip(use_legacy_dataset): - # ARROW-2744: Shouldn't crash when writing an array of empty lists - arr = pa.array([[], []], type=pa.list_(pa.int32())) - table = pa.Table.from_arrays([arr], ["A"]) - _check_roundtrip(table, use_legacy_dataset=use_legacy_dataset) - - -@parametrize_legacy_dataset -def test_nested_list_nonnullable_roundtrip_bug(use_legacy_dataset): - # Reproduce failure in ARROW-5630 - typ = pa.list_(pa.field("item", pa.float32(), False)) - num_rows = 10000 - t = pa.table([ - pa.array(([[0] * ((i + 5) % 10) for i in range(0, 10)] * - (num_rows // 10)), type=typ) - ], ['a']) - _check_roundtrip( - t, data_page_size=4096, use_legacy_dataset=use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_datetime_tz(use_legacy_dataset): - s = pd.Series([datetime.datetime(2017, 9, 6)]) - s = s.dt.tz_localize('utc') - - s.index = s - - # Both a column and an index to hit both use cases - df = pd.DataFrame({'tz_aware': s, - 'tz_eastern': s.dt.tz_convert('US/Eastern')}, - index=s) - - f = io.BytesIO() - - arrow_table = pa.Table.from_pandas(df) - - _write_table(arrow_table, f, coerce_timestamps='ms') - f.seek(0) - - table_read = pq.read_pandas(f, use_legacy_dataset=use_legacy_dataset) - - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_datetime_timezone_tzinfo(use_legacy_dataset): - value = datetime.datetime(2018, 1, 1, 1, 23, 45, - tzinfo=datetime.timezone.utc) - df = pd.DataFrame({'foo': [value]}) - - _roundtrip_pandas_dataframe( - df, write_kwargs={}, use_legacy_dataset=use_legacy_dataset) - - -@pytest.mark.pandas -def test_pandas_parquet_custom_metadata(tempdir): - df = alltypes_sample(size=10000) - - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df) - assert b'pandas' in arrow_table.schema.metadata - - _write_table(arrow_table, filename, version='2.0', coerce_timestamps='ms') - - metadata = pq.read_metadata(filename).metadata - assert b'pandas' in metadata - - js = json.loads(metadata[b'pandas'].decode('utf8')) - assert js['index_columns'] == [{'kind': 'range', - 'name': None, - 'start': 0, 'stop': 10000, - 'step': 1}] - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_column_multiindex(tempdir, use_legacy_dataset): - df = alltypes_sample(size=10) - df.columns = pd.MultiIndex.from_tuples( - list(zip(df.columns, df.columns[::-1])), - 
names=['level_1', 'level_2'] - ) - - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df) - assert arrow_table.schema.pandas_metadata is not None - - _write_table(arrow_table, filename, version='2.0', coerce_timestamps='ms') - - table_read = pq.read_pandas( - filename, use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_2_0_roundtrip_read_pandas_no_index_written( - tempdir, use_legacy_dataset -): - df = alltypes_sample(size=10000) - - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df, preserve_index=False) - js = arrow_table.schema.pandas_metadata - assert not js['index_columns'] - # ARROW-2170 - # While index_columns should be empty, columns needs to be filled still. - assert js['columns'] - - _write_table(arrow_table, filename, version='2.0', coerce_timestamps='ms') - table_read = pq.read_pandas( - filename, use_legacy_dataset=use_legacy_dataset) - - js = table_read.schema.pandas_metadata - assert not js['index_columns'] - - read_metadata = table_read.schema.metadata - assert arrow_table.schema.metadata == read_metadata - - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_1_0_roundtrip(tempdir, use_legacy_dataset): - size = 10000 - np.random.seed(0) - df = pd.DataFrame({ - 'uint8': np.arange(size, dtype=np.uint8), - 'uint16': np.arange(size, dtype=np.uint16), - 'uint32': np.arange(size, dtype=np.uint32), - 'uint64': np.arange(size, dtype=np.uint64), - 'int8': np.arange(size, dtype=np.int16), - 'int16': np.arange(size, dtype=np.int16), - 'int32': np.arange(size, dtype=np.int32), - 'int64': np.arange(size, dtype=np.int64), - 'float32': np.arange(size, dtype=np.float32), - 'float64': np.arange(size, dtype=np.float64), - 'bool': np.random.randn(size) > 0, - 'str': [str(x) for x in range(size)], - 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None], - 'empty_str': [''] * size - }) - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df) - _write_table(arrow_table, filename, version='1.0') - table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - - # We pass uint32_t as int64_t if we write Parquet version 1.0 - df['uint32'] = df['uint32'].values.astype(np.int64) - - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_multiple_path_types(tempdir, use_legacy_dataset): - # Test compatibility with PEP 519 path-like objects - path = tempdir / 'zzz.parquet' - df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)}) - _write_table(df, path) - table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - # Test compatibility with plain string paths - path = str(tempdir) + 'zzz.parquet' - df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)}) - _write_table(df, path) - table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.dataset -@parametrize_legacy_dataset -@pytest.mark.parametrize("filesystem", [ - None, fs.LocalFileSystem(), LocalFileSystem._get_instance() -]) -def test_relative_paths(tempdir, use_legacy_dataset, filesystem): - # reading and writing from 
relative paths - table = pa.table({"a": [1, 2, 3]}) - - # reading - pq.write_table(table, str(tempdir / "data.parquet")) - with util.change_cwd(tempdir): - result = pq.read_table("data.parquet", filesystem=filesystem, - use_legacy_dataset=use_legacy_dataset) - assert result.equals(table) - - # writing - with util.change_cwd(tempdir): - pq.write_table(table, "data2.parquet", filesystem=filesystem) - result = pq.read_table(tempdir / "data2.parquet") - assert result.equals(table) - - -@parametrize_legacy_dataset_fixed -def test_filesystem_uri(tempdir, use_legacy_dataset): - table = pa.table({"a": [1, 2, 3]}) - - directory = tempdir / "data_dir" - directory.mkdir() - path = directory / "data.parquet" - pq.write_table(table, str(path)) - - # filesystem object - result = pq.read_table( - path, filesystem=fs.LocalFileSystem(), - use_legacy_dataset=use_legacy_dataset) - assert result.equals(table) - - # filesystem URI - result = pq.read_table( - "data_dir/data.parquet", filesystem=util._filesystem_uri(tempdir), - use_legacy_dataset=use_legacy_dataset) - assert result.equals(table) - - -@parametrize_legacy_dataset -def test_read_non_existing_file(use_legacy_dataset): - # ensure we have a proper error message - with pytest.raises(FileNotFoundError): - pq.read_table('i-am-not-existing.parquet') - - -# TODO(dataset) duplicate column selection actually gives duplicate columns now -@pytest.mark.pandas -@parametrize_legacy_dataset_not_supported -def test_pandas_column_selection(tempdir, use_legacy_dataset): - size = 10000 - np.random.seed(0) - df = pd.DataFrame({ - 'uint8': np.arange(size, dtype=np.uint8), - 'uint16': np.arange(size, dtype=np.uint16) - }) - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df) - _write_table(arrow_table, filename) - table_read = _read_table( - filename, columns=['uint8'], use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - - tm.assert_frame_equal(df[['uint8']], df_read) - - # ARROW-4267: Selection of duplicate columns still leads to these columns - # being read uniquely. 
- table_read = _read_table( - filename, columns=['uint8', 'uint8'], - use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - - tm.assert_frame_equal(df[['uint8']], df_read) - - -def _random_integers(size, dtype): - # We do not generate integers outside the int64 range - platform_int_info = np.iinfo('int_') - iinfo = np.iinfo(dtype) - return np.random.randint(max(iinfo.min, platform_int_info.min), - min(iinfo.max, platform_int_info.max), - size=size).astype(dtype) - - -def _test_dataframe(size=10000, seed=0): - np.random.seed(seed) - df = pd.DataFrame({ - 'uint8': _random_integers(size, np.uint8), - 'uint16': _random_integers(size, np.uint16), - 'uint32': _random_integers(size, np.uint32), - 'uint64': _random_integers(size, np.uint64), - 'int8': _random_integers(size, np.int8), - 'int16': _random_integers(size, np.int16), - 'int32': _random_integers(size, np.int32), - 'int64': _random_integers(size, np.int64), - 'float32': np.random.randn(size).astype(np.float32), - 'float64': np.arange(size, dtype=np.float64), - 'bool': np.random.randn(size) > 0, - 'strings': [util.rands(10) for i in range(size)], - 'all_none': [None] * size, - 'all_none_category': [None] * size - }) - # TODO(PARQUET-1015) - # df['all_none_category'] = df['all_none_category'].astype('category') - return df - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_native_file_roundtrip(tempdir, use_legacy_dataset): - df = _test_dataframe(10000) - arrow_table = pa.Table.from_pandas(df) - imos = pa.BufferOutputStream() - _write_table(arrow_table, imos, version="2.0") - buf = imos.getvalue() - reader = pa.BufferReader(buf) - df_read = _read_table( - reader, use_legacy_dataset=use_legacy_dataset).to_pandas() - tm.assert_frame_equal(df, df_read) - - -@parametrize_legacy_dataset -def test_parquet_read_from_buffer(tempdir, use_legacy_dataset): - # reading from a buffer from python's open() - table = pa.table({"a": [1, 2, 3]}) - pq.write_table(table, str(tempdir / "data.parquet")) - - with open(str(tempdir / "data.parquet"), "rb") as f: - result = pq.read_table(f, use_legacy_dataset=use_legacy_dataset) - assert result.equals(table) - - with open(str(tempdir / "data.parquet"), "rb") as f: - result = pq.read_table(pa.PythonFile(f), - use_legacy_dataset=use_legacy_dataset) - assert result.equals(table) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_parquet_incremental_file_build(tempdir, use_legacy_dataset): - df = _test_dataframe(100) - df['unique_id'] = 0 - - arrow_table = pa.Table.from_pandas(df, preserve_index=False) - out = pa.BufferOutputStream() - - writer = pq.ParquetWriter(out, arrow_table.schema, version='2.0') - - frames = [] - for i in range(10): - df['unique_id'] = i - arrow_table = pa.Table.from_pandas(df, preserve_index=False) - writer.write_table(arrow_table) - - frames.append(df.copy()) - - writer.close() - - buf = out.getvalue() - result = _read_table( - pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) - - expected = pd.concat(frames, ignore_index=True) - tm.assert_frame_equal(result.to_pandas(), expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_read_pandas_column_subset(tempdir, use_legacy_dataset): - df = _test_dataframe(10000) - arrow_table = pa.Table.from_pandas(df) - imos = pa.BufferOutputStream() - _write_table(arrow_table, imos, version="2.0") - buf = imos.getvalue() - reader = pa.BufferReader(buf) - df_read = pq.read_pandas( - reader, columns=['strings', 'uint8'], - use_legacy_dataset=use_legacy_dataset - 
).to_pandas() - tm.assert_frame_equal(df[['strings', 'uint8']], df_read) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_empty_roundtrip(tempdir, use_legacy_dataset): - df = _test_dataframe(0) - arrow_table = pa.Table.from_pandas(df) - imos = pa.BufferOutputStream() - _write_table(arrow_table, imos, version="2.0") - buf = imos.getvalue() - reader = pa.BufferReader(buf) - df_read = _read_table( - reader, use_legacy_dataset=use_legacy_dataset).to_pandas() - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -def test_pandas_can_write_nested_data(tempdir): - data = { - "agg_col": [ - {"page_type": 1}, - {"record_type": 1}, - {"non_consecutive_home": 0}, - ], - "uid_first": "1001" - } - df = pd.DataFrame(data=data) - arrow_table = pa.Table.from_pandas(df) - imos = pa.BufferOutputStream() - # This succeeds under V2 - _write_table(arrow_table, imos) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_pyfile_roundtrip(tempdir, use_legacy_dataset): - filename = tempdir / 'pandas_pyfile_roundtrip.parquet' - size = 5 - df = pd.DataFrame({ - 'int64': np.arange(size, dtype=np.int64), - 'float32': np.arange(size, dtype=np.float32), - 'float64': np.arange(size, dtype=np.float64), - 'bool': np.random.randn(size) > 0, - 'strings': ['foo', 'bar', None, 'baz', 'qux'] - }) - - arrow_table = pa.Table.from_pandas(df) - - with filename.open('wb') as f: - _write_table(arrow_table, f, version="1.0") - - data = io.BytesIO(filename.read_bytes()) - - table_read = _read_table(data, use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_parquet_configuration_options(tempdir, use_legacy_dataset): - size = 10000 - np.random.seed(0) - df = pd.DataFrame({ - 'uint8': np.arange(size, dtype=np.uint8), - 'uint16': np.arange(size, dtype=np.uint16), - 'uint32': np.arange(size, dtype=np.uint32), - 'uint64': np.arange(size, dtype=np.uint64), - 'int8': np.arange(size, dtype=np.int16), - 'int16': np.arange(size, dtype=np.int16), - 'int32': np.arange(size, dtype=np.int32), - 'int64': np.arange(size, dtype=np.int64), - 'float32': np.arange(size, dtype=np.float32), - 'float64': np.arange(size, dtype=np.float64), - 'bool': np.random.randn(size) > 0 - }) - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df) - - for use_dictionary in [True, False]: - _write_table(arrow_table, filename, version='2.0', - use_dictionary=use_dictionary) - table_read = _read_table( - filename, use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - for write_statistics in [True, False]: - _write_table(arrow_table, filename, version='2.0', - write_statistics=write_statistics) - table_read = _read_table(filename, - use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - for compression in ['NONE', 'SNAPPY', 'GZIP', 'LZ4', 'ZSTD']: - if (compression != 'NONE' and - not pa.lib.Codec.is_available(compression)): - continue - _write_table(arrow_table, filename, version='2.0', - compression=compression) - table_read = _read_table( - filename, use_legacy_dataset=use_legacy_dataset) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -def make_sample_file(table_or_df): - if isinstance(table_or_df, pa.Table): - a_table = table_or_df - else: - a_table = pa.Table.from_pandas(table_or_df) - - buf = 
io.BytesIO() - _write_table(a_table, buf, compression='SNAPPY', version='2.0', - coerce_timestamps='ms') - - buf.seek(0) - return pq.ParquetFile(buf) - - -@parametrize_legacy_dataset -def test_byte_stream_split(use_legacy_dataset): - # This is only a smoke test. - arr_float = pa.array(list(map(float, range(100)))) - arr_int = pa.array(list(map(int, range(100)))) - data_float = [arr_float, arr_float] - table = pa.Table.from_arrays(data_float, names=['a', 'b']) - - # Check with byte_stream_split for both columns. - _check_roundtrip(table, expected=table, compression="gzip", - use_dictionary=False, use_byte_stream_split=True) - - # Check with byte_stream_split for column 'b' and dictionary - # for column 'a'. - _check_roundtrip(table, expected=table, compression="gzip", - use_dictionary=['a'], - use_byte_stream_split=['b']) - - # Check with a collision for both columns. - _check_roundtrip(table, expected=table, compression="gzip", - use_dictionary=['a', 'b'], - use_byte_stream_split=['a', 'b']) - - # Check with mixed column types. - mixed_table = pa.Table.from_arrays([arr_float, arr_int], - names=['a', 'b']) - _check_roundtrip(mixed_table, expected=mixed_table, - use_dictionary=['b'], - use_byte_stream_split=['a']) - - # Try to use the wrong data type with the byte_stream_split encoding. - # This should throw an exception. - table = pa.Table.from_arrays([arr_int], names=['tmp']) - with pytest.raises(IOError): - _check_roundtrip(table, expected=table, use_byte_stream_split=True, - use_dictionary=False, - use_legacy_dataset=use_legacy_dataset) - - -@parametrize_legacy_dataset -def test_compression_level(use_legacy_dataset): - arr = pa.array(list(map(int, range(1000)))) - data = [arr, arr] - table = pa.Table.from_arrays(data, names=['a', 'b']) - - # Check one compression level. - _check_roundtrip(table, expected=table, compression="gzip", - compression_level=1, - use_legacy_dataset=use_legacy_dataset) - - # Check another one to make sure that compression_level=1 does not - # coincide with the default one in Arrow. - _check_roundtrip(table, expected=table, compression="gzip", - compression_level=5, - use_legacy_dataset=use_legacy_dataset) - - # Check that the user can provide a compression per column - _check_roundtrip(table, expected=table, - compression={'a': "gzip", 'b': "snappy"}, - use_legacy_dataset=use_legacy_dataset) - - # Check that the user can provide a compression level per column - _check_roundtrip(table, expected=table, compression="gzip", - compression_level={'a': 2, 'b': 3}, - use_legacy_dataset=use_legacy_dataset) - - # Check that specifying a compression level for a codec which does allow - # specifying one, results into an error. - # Uncompressed, snappy, lz4 and lzo do not support specifying a compression - # level. - # GZIP (zlib) allows for specifying a compression level but as of up - # to version 1.2.11 the valid range is [-1, 9]. 
- invalid_combinations = [("snappy", 4), ("lz4", 5), ("gzip", -1337), - ("None", 444), ("lzo", 14)] - buf = io.BytesIO() - for (codec, level) in invalid_combinations: - with pytest.raises((ValueError, OSError)): - _write_table(table, buf, compression=codec, - compression_level=level) - - -@pytest.mark.pandas -def test_parquet_metadata_api(): - df = alltypes_sample(size=10000) - df = df.reindex(columns=sorted(df.columns)) - df.index = np.random.randint(0, 1000000, size=len(df)) - - fileh = make_sample_file(df) - ncols = len(df.columns) - - # Series of sniff tests - meta = fileh.metadata - repr(meta) - assert meta.num_rows == len(df) - assert meta.num_columns == ncols + 1 # +1 for index - assert meta.num_row_groups == 1 - assert meta.format_version == '2.0' - assert 'parquet-cpp' in meta.created_by - assert isinstance(meta.serialized_size, int) - assert isinstance(meta.metadata, dict) - - # Schema - schema = fileh.schema - assert meta.schema is schema - assert len(schema) == ncols + 1 # +1 for index - repr(schema) - - col = schema[0] - repr(col) - assert col.name == df.columns[0] - assert col.max_definition_level == 1 - assert col.max_repetition_level == 0 - assert col.max_repetition_level == 0 - - assert col.physical_type == 'BOOLEAN' - assert col.converted_type == 'NONE' - - with pytest.raises(IndexError): - schema[ncols + 1] # +1 for index - - with pytest.raises(IndexError): - schema[-1] - - # Row group - for rg in range(meta.num_row_groups): - rg_meta = meta.row_group(rg) - assert isinstance(rg_meta, pq.RowGroupMetaData) - repr(rg_meta) - - for col in range(rg_meta.num_columns): - col_meta = rg_meta.column(col) - assert isinstance(col_meta, pq.ColumnChunkMetaData) - repr(col_meta) - - with pytest.raises(IndexError): - meta.row_group(-1) - - with pytest.raises(IndexError): - meta.row_group(meta.num_row_groups + 1) - - rg_meta = meta.row_group(0) - assert rg_meta.num_rows == len(df) - assert rg_meta.num_columns == ncols + 1 # +1 for index - assert rg_meta.total_byte_size > 0 - - with pytest.raises(IndexError): - col_meta = rg_meta.column(-1) - - with pytest.raises(IndexError): - col_meta = rg_meta.column(ncols + 2) - - col_meta = rg_meta.column(0) - assert col_meta.file_offset > 0 - assert col_meta.file_path == '' # created from BytesIO - assert col_meta.physical_type == 'BOOLEAN' - assert col_meta.num_values == 10000 - assert col_meta.path_in_schema == 'bool' - assert col_meta.is_stats_set is True - assert isinstance(col_meta.statistics, pq.Statistics) - assert col_meta.compression == 'SNAPPY' - assert col_meta.encodings == ('PLAIN', 'RLE') - assert col_meta.has_dictionary_page is False - assert col_meta.dictionary_page_offset is None - assert col_meta.data_page_offset > 0 - assert col_meta.total_compressed_size > 0 - assert col_meta.total_uncompressed_size > 0 - with pytest.raises(NotImplementedError): - col_meta.has_index_page - with pytest.raises(NotImplementedError): - col_meta.index_page_offset - - -def test_parquet_metadata_lifetime(tempdir): - # ARROW-6642 - ensure that chained access keeps parent objects alive - table = pa.table({'a': [1, 2, 3]}) - pq.write_table(table, tempdir / 'test_metadata_segfault.parquet') - dataset = pq.ParquetDataset(tempdir / 'test_metadata_segfault.parquet') - dataset.pieces[0].get_metadata().row_group(0).column(0).statistics - - -@pytest.mark.pandas -@pytest.mark.parametrize( - ( - 'data', - 'type', - 'physical_type', - 'min_value', - 'max_value', - 'null_count', - 'num_values', - 'distinct_count' - ), - [ - ([1, 2, 2, None, 4], pa.uint8(), 'INT32', 1, 
4, 1, 4, 0), - ([1, 2, 2, None, 4], pa.uint16(), 'INT32', 1, 4, 1, 4, 0), - ([1, 2, 2, None, 4], pa.uint32(), 'INT32', 1, 4, 1, 4, 0), - ([1, 2, 2, None, 4], pa.uint64(), 'INT64', 1, 4, 1, 4, 0), - ([-1, 2, 2, None, 4], pa.int8(), 'INT32', -1, 4, 1, 4, 0), - ([-1, 2, 2, None, 4], pa.int16(), 'INT32', -1, 4, 1, 4, 0), - ([-1, 2, 2, None, 4], pa.int32(), 'INT32', -1, 4, 1, 4, 0), - ([-1, 2, 2, None, 4], pa.int64(), 'INT64', -1, 4, 1, 4, 0), - ( - [-1.1, 2.2, 2.3, None, 4.4], pa.float32(), - 'FLOAT', -1.1, 4.4, 1, 4, 0 - ), - ( - [-1.1, 2.2, 2.3, None, 4.4], pa.float64(), - 'DOUBLE', -1.1, 4.4, 1, 4, 0 - ), - ( - ['', 'b', chr(1000), None, 'aaa'], pa.binary(), - 'BYTE_ARRAY', b'', chr(1000).encode('utf-8'), 1, 4, 0 - ), - ( - [True, False, False, True, True], pa.bool_(), - 'BOOLEAN', False, True, 0, 5, 0 - ), - ( - [b'\x00', b'b', b'12', None, b'aaa'], pa.binary(), - 'BYTE_ARRAY', b'\x00', b'b', 1, 4, 0 - ), - ] -) -def test_parquet_column_statistics_api(data, type, physical_type, min_value, - max_value, null_count, num_values, - distinct_count): - df = pd.DataFrame({'data': data}) - schema = pa.schema([pa.field('data', type)]) - table = pa.Table.from_pandas(df, schema=schema, safe=False) - fileh = make_sample_file(table) - - meta = fileh.metadata - - rg_meta = meta.row_group(0) - col_meta = rg_meta.column(0) - - stat = col_meta.statistics - assert stat.has_min_max - assert _close(type, stat.min, min_value) - assert _close(type, stat.max, max_value) - assert stat.null_count == null_count - assert stat.num_values == num_values - # TODO(kszucs) until parquet-cpp API doesn't expose HasDistinctCount - # method, missing distinct_count is represented as zero instead of None - assert stat.distinct_count == distinct_count - assert stat.physical_type == physical_type - - -# ARROW-6339 -@pytest.mark.pandas -def test_parquet_raise_on_unset_statistics(): - df = pd.DataFrame({"t": pd.Series([pd.NaT], dtype="datetime64[ns]")}) - meta = make_sample_file(pa.Table.from_pandas(df)).metadata - - assert not meta.row_group(0).column(0).statistics.has_min_max - assert meta.row_group(0).column(0).statistics.max is None - - -def _close(type, left, right): - if type == pa.float32(): - return abs(left - right) < 1E-7 - elif type == pa.float64(): - return abs(left - right) < 1E-13 - else: - return left == right - - -def test_statistics_convert_logical_types(tempdir): - # ARROW-5166, ARROW-4139 - - # (min, max, type) - cases = [(10, 11164359321221007157, pa.uint64()), - (10, 4294967295, pa.uint32()), - ("ähnlich", "öffentlich", pa.utf8()), - (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000), - pa.time32('ms')), - (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000), - pa.time64('us')), - (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000), - datetime.datetime(2019, 6, 25, 0, 0, 0, 1000), - pa.timestamp('ms')), - (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000), - datetime.datetime(2019, 6, 25, 0, 0, 0, 1000), - pa.timestamp('us'))] - - for i, (min_val, max_val, typ) in enumerate(cases): - t = pa.Table.from_arrays([pa.array([min_val, max_val], type=typ)], - ['col']) - path = str(tempdir / ('example{}.parquet'.format(i))) - pq.write_table(t, path, version='2.0') - pf = pq.ParquetFile(path) - stats = pf.metadata.row_group(0).column(0).statistics - assert stats.min == min_val - assert stats.max == max_val - - -def test_parquet_metadata_empty_to_dict(tempdir): - # https://issues.apache.org/jira/browse/ARROW-10146 - table = pa.table({"a": pa.array([], type="int64")}) - pq.write_table(table, tempdir / 
"data.parquet") - metadata = pq.read_metadata(tempdir / "data.parquet") - # ensure this doesn't error / statistics set to None - metadata_dict = metadata.to_dict() - assert len(metadata_dict["row_groups"]) == 1 - assert len(metadata_dict["row_groups"][0]["columns"]) == 1 - assert metadata_dict["row_groups"][0]["columns"][0]["statistics"] is None - - -def test_parquet_write_disable_statistics(tempdir): - table = pa.Table.from_pydict( - OrderedDict([ - ('a', pa.array([1, 2, 3])), - ('b', pa.array(['a', 'b', 'c'])) - ]) - ) - _write_table(table, tempdir / 'data.parquet') - meta = pq.read_metadata(tempdir / 'data.parquet') - for col in [0, 1]: - cc = meta.row_group(0).column(col) - assert cc.is_stats_set is True - assert cc.statistics is not None - - _write_table(table, tempdir / 'data2.parquet', write_statistics=False) - meta = pq.read_metadata(tempdir / 'data2.parquet') - for col in [0, 1]: - cc = meta.row_group(0).column(col) - assert cc.is_stats_set is False - assert cc.statistics is None - - _write_table(table, tempdir / 'data3.parquet', write_statistics=['a']) - meta = pq.read_metadata(tempdir / 'data3.parquet') - cc_a = meta.row_group(0).column(0) - cc_b = meta.row_group(0).column(1) - assert cc_a.is_stats_set is True - assert cc_b.is_stats_set is False - assert cc_a.statistics is not None - assert cc_b.statistics is None - - -@pytest.mark.pandas -def test_compare_schemas(): - df = alltypes_sample(size=10000) - - fileh = make_sample_file(df) - fileh2 = make_sample_file(df) - fileh3 = make_sample_file(df[df.columns[::2]]) - - # ParquetSchema - assert isinstance(fileh.schema, pq.ParquetSchema) - assert fileh.schema.equals(fileh.schema) - assert fileh.schema == fileh.schema - assert fileh.schema.equals(fileh2.schema) - assert fileh.schema == fileh2.schema - assert fileh.schema != 'arbitrary object' - assert not fileh.schema.equals(fileh3.schema) - assert fileh.schema != fileh3.schema - - # ColumnSchema - assert isinstance(fileh.schema[0], pq.ColumnSchema) - assert fileh.schema[0].equals(fileh.schema[0]) - assert fileh.schema[0] == fileh.schema[0] - assert not fileh.schema[0].equals(fileh.schema[1]) - assert fileh.schema[0] != fileh.schema[1] - assert fileh.schema[0] != 'arbitrary object' - - -def test_validate_schema_write_table(tempdir): - # ARROW-2926 - simple_fields = [ - pa.field('POS', pa.uint32()), - pa.field('desc', pa.string()) - ] - - simple_schema = pa.schema(simple_fields) - - # simple_table schema does not match simple_schema - simple_from_array = [pa.array([1]), pa.array(['bla'])] - simple_table = pa.Table.from_arrays(simple_from_array, ['POS', 'desc']) - - path = tempdir / 'simple_validate_schema.parquet' - - with pq.ParquetWriter(path, simple_schema, - version='2.0', - compression='snappy', flavor='spark') as w: - with pytest.raises(ValueError): - w.write_table(simple_table) - - -@pytest.mark.pandas -def test_column_of_arrays(tempdir): - df, schema = dataframe_with_arrays() - - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df, schema=schema) - _write_table(arrow_table, filename, version="2.0", coerce_timestamps='ms') - table_read = _read_table(filename) - df_read = table_read.to_pandas() - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -def test_coerce_timestamps(tempdir): - from collections import OrderedDict - # ARROW-622 - arrays = OrderedDict() - fields = [pa.field('datetime64', - pa.list_(pa.timestamp('ms')))] - arrays['datetime64'] = [ - np.array(['2007-07-13T01:23:34.123456789', - None, - 
'2010-08-13T05:46:57.437699912'], - dtype='datetime64[ms]'), - None, - None, - np.array(['2007-07-13T02', - None, - '2010-08-13T05:46:57.437699912'], - dtype='datetime64[ms]'), - ] - - df = pd.DataFrame(arrays) - schema = pa.schema(fields) - - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df, schema=schema) - - _write_table(arrow_table, filename, version="2.0", coerce_timestamps='us') - table_read = _read_table(filename) - df_read = table_read.to_pandas() - - df_expected = df.copy() - for i, x in enumerate(df_expected['datetime64']): - if isinstance(x, np.ndarray): - df_expected['datetime64'][i] = x.astype('M8[us]') - - tm.assert_frame_equal(df_expected, df_read) - - with pytest.raises(ValueError): - _write_table(arrow_table, filename, version='2.0', - coerce_timestamps='unknown') - - -@pytest.mark.pandas -def test_coerce_timestamps_truncated(tempdir): - """ - ARROW-2555: Test that we can truncate timestamps when coercing if - explicitly allowed. - """ - dt_us = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1, - second=1, microsecond=1) - dt_ms = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1, - second=1) - - fields_us = [pa.field('datetime64', pa.timestamp('us'))] - arrays_us = {'datetime64': [dt_us, dt_ms]} - - df_us = pd.DataFrame(arrays_us) - schema_us = pa.schema(fields_us) - - filename = tempdir / 'pandas_truncated.parquet' - table_us = pa.Table.from_pandas(df_us, schema=schema_us) - - _write_table(table_us, filename, version="2.0", coerce_timestamps='ms', - allow_truncated_timestamps=True) - table_ms = _read_table(filename) - df_ms = table_ms.to_pandas() - - arrays_expected = {'datetime64': [dt_ms, dt_ms]} - df_expected = pd.DataFrame(arrays_expected) - tm.assert_frame_equal(df_expected, df_ms) - - -@pytest.mark.pandas -def test_column_of_lists(tempdir): - df, schema = dataframe_with_lists(parquet_compatible=True) - - filename = tempdir / 'pandas_roundtrip.parquet' - arrow_table = pa.Table.from_pandas(df, schema=schema) - _write_table(arrow_table, filename, version='2.0') - table_read = _read_table(filename) - df_read = table_read.to_pandas() - - tm.assert_frame_equal(df, df_read) - - -@pytest.mark.pandas -def test_date_time_types(tempdir): - t1 = pa.date32() - data1 = np.array([17259, 17260, 17261], dtype='int32') - a1 = pa.array(data1, type=t1) - - t2 = pa.date64() - data2 = data1.astype('int64') * 86400000 - a2 = pa.array(data2, type=t2) - - t3 = pa.timestamp('us') - start = pd.Timestamp('2001-01-01').value / 1000 - data3 = np.array([start, start + 1, start + 2], dtype='int64') - a3 = pa.array(data3, type=t3) - - t4 = pa.time32('ms') - data4 = np.arange(3, dtype='i4') - a4 = pa.array(data4, type=t4) - - t5 = pa.time64('us') - a5 = pa.array(data4.astype('int64'), type=t5) - - t6 = pa.time32('s') - a6 = pa.array(data4, type=t6) - - ex_t6 = pa.time32('ms') - ex_a6 = pa.array(data4 * 1000, type=ex_t6) - - t7 = pa.timestamp('ns') - start = pd.Timestamp('2001-01-01').value - data7 = np.array([start, start + 1000, start + 2000], - dtype='int64') - a7 = pa.array(data7, type=t7) - - table = pa.Table.from_arrays([a1, a2, a3, a4, a5, a6, a7], - ['date32', 'date64', 'timestamp[us]', - 'time32[s]', 'time64[us]', - 'time32_from64[s]', - 'timestamp[ns]']) - - # date64 as date32 - # time32[s] to time32[ms] - expected = pa.Table.from_arrays([a1, a1, a3, a4, a5, ex_a6, a7], - ['date32', 'date64', 'timestamp[us]', - 'time32[s]', 'time64[us]', - 'time32_from64[s]', - 'timestamp[ns]']) - - _check_roundtrip(table, expected=expected, 
version='2.0') - - t0 = pa.timestamp('ms') - data0 = np.arange(4, dtype='int64') - a0 = pa.array(data0, type=t0) - - t1 = pa.timestamp('us') - data1 = np.arange(4, dtype='int64') - a1 = pa.array(data1, type=t1) - - t2 = pa.timestamp('ns') - data2 = np.arange(4, dtype='int64') - a2 = pa.array(data2, type=t2) - - table = pa.Table.from_arrays([a0, a1, a2], - ['ts[ms]', 'ts[us]', 'ts[ns]']) - expected = pa.Table.from_arrays([a0, a1, a2], - ['ts[ms]', 'ts[us]', 'ts[ns]']) - - # int64 for all timestamps supported by default - filename = tempdir / 'int64_timestamps.parquet' - _write_table(table, filename, version='2.0') - parquet_schema = pq.ParquetFile(filename).schema - for i in range(3): - assert parquet_schema.column(i).physical_type == 'INT64' - read_table = _read_table(filename) - assert read_table.equals(expected) - - t0_ns = pa.timestamp('ns') - data0_ns = np.array(data0 * 1000000, dtype='int64') - a0_ns = pa.array(data0_ns, type=t0_ns) - - t1_ns = pa.timestamp('ns') - data1_ns = np.array(data1 * 1000, dtype='int64') - a1_ns = pa.array(data1_ns, type=t1_ns) - - expected = pa.Table.from_arrays([a0_ns, a1_ns, a2], - ['ts[ms]', 'ts[us]', 'ts[ns]']) - - # int96 nanosecond timestamps produced upon request - filename = tempdir / 'explicit_int96_timestamps.parquet' - _write_table(table, filename, version='2.0', - use_deprecated_int96_timestamps=True) - parquet_schema = pq.ParquetFile(filename).schema - for i in range(3): - assert parquet_schema.column(i).physical_type == 'INT96' - read_table = _read_table(filename) - assert read_table.equals(expected) - - # int96 nanosecond timestamps implied by flavor 'spark' - filename = tempdir / 'spark_int96_timestamps.parquet' - _write_table(table, filename, version='2.0', - flavor='spark') - parquet_schema = pq.ParquetFile(filename).schema - for i in range(3): - assert parquet_schema.column(i).physical_type == 'INT96' - read_table = _read_table(filename) - assert read_table.equals(expected) - - -def test_timestamp_restore_timezone(): - # ARROW-5888, restore timezone from serialized metadata - ty = pa.timestamp('ms', tz='America/New_York') - arr = pa.array([1, 2, 3], type=ty) - t = pa.table([arr], names=['f0']) - _check_roundtrip(t) - - -@pytest.mark.pandas -def test_list_of_datetime_time_roundtrip(): - # ARROW-4135 - times = pd.to_datetime(['09:00', '09:30', '10:00', '10:30', '11:00', - '11:30', '12:00']) - df = pd.DataFrame({'time': [times.time]}) - _roundtrip_pandas_dataframe(df, write_kwargs={}) - - -@pytest.mark.pandas -def test_parquet_version_timestamp_differences(): - i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000 - - d_s = np.arange(i_s, i_s + 10, 1, dtype='int64') - d_ms = d_s * 1000 - d_us = d_ms * 1000 - d_ns = d_us * 1000 - - a_s = pa.array(d_s, type=pa.timestamp('s')) - a_ms = pa.array(d_ms, type=pa.timestamp('ms')) - a_us = pa.array(d_us, type=pa.timestamp('us')) - a_ns = pa.array(d_ns, type=pa.timestamp('ns')) - - names = ['ts:s', 'ts:ms', 'ts:us', 'ts:ns'] - table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names) - - # Using Parquet version 1.0, seconds should be coerced to milliseconds - # and nanoseconds should be coerced to microseconds by default - expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_us], names) - _check_roundtrip(table, expected) - - # Using Parquet version 2.0, seconds should be coerced to milliseconds - # and nanoseconds should be retained by default - expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_ns], names) - _check_roundtrip(table, expected, version='2.0') - - # Using Parquet 
version 1.0, coercing to milliseconds or microseconds - # is allowed - expected = pa.Table.from_arrays([a_ms, a_ms, a_ms, a_ms], names) - _check_roundtrip(table, expected, coerce_timestamps='ms') - - # Using Parquet version 2.0, coercing to milliseconds or microseconds - # is allowed - expected = pa.Table.from_arrays([a_us, a_us, a_us, a_us], names) - _check_roundtrip(table, expected, version='2.0', coerce_timestamps='us') - - # TODO: after pyarrow allows coerce_timestamps='ns', tests like the - # following should pass ... - - # Using Parquet version 1.0, coercing to nanoseconds is not allowed - # expected = None - # with pytest.raises(NotImplementedError): - # _roundtrip_table(table, coerce_timestamps='ns') - - # Using Parquet version 2.0, coercing to nanoseconds is allowed - # expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names) - # _check_roundtrip(table, expected, version='2.0', coerce_timestamps='ns') - - # For either Parquet version, coercing to nanoseconds is allowed - # if Int96 storage is used - expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names) - _check_roundtrip(table, expected, - use_deprecated_int96_timestamps=True) - _check_roundtrip(table, expected, version='2.0', - use_deprecated_int96_timestamps=True) - - -def test_large_list_records(): - # This was fixed in PARQUET-1100 - - list_lengths = np.random.randint(0, 500, size=50) - list_lengths[::10] = 0 - - list_values = [list(map(int, np.random.randint(0, 100, size=x))) - if i % 8 else None - for i, x in enumerate(list_lengths)] - - a1 = pa.array(list_values) - - table = pa.Table.from_arrays([a1], ['int_lists']) - _check_roundtrip(table) - - -def test_sanitized_spark_field_names(): - a0 = pa.array([0, 1, 2, 3, 4]) - name = 'prohib; ,\t{}' - table = pa.Table.from_arrays([a0], [name]) - - result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'}) - - expected_name = 'prohib______' - assert result.schema[0].name == expected_name - - -@pytest.mark.pandas -def test_spark_flavor_preserves_pandas_metadata(): - df = _test_dataframe(size=100) - df.index = np.arange(0, 10 * len(df), 10) - df.index.name = 'foo' - - result = _roundtrip_pandas_dataframe(df, {'version': '2.0', - 'flavor': 'spark'}) - tm.assert_frame_equal(result, df) - - -def test_fixed_size_binary(): - t0 = pa.binary(10) - data = [b'fooooooooo', None, b'barooooooo', b'quxooooooo'] - a0 = pa.array(data, type=t0) - - table = pa.Table.from_arrays([a0], - ['binary[10]']) - _check_roundtrip(table) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_multithreaded_read(use_legacy_dataset): - df = alltypes_sample(size=10000) - - table = pa.Table.from_pandas(df) - - buf = io.BytesIO() - _write_table(table, buf, compression='SNAPPY', version='2.0') - - buf.seek(0) - table1 = _read_table( - buf, use_threads=True, use_legacy_dataset=use_legacy_dataset) - - buf.seek(0) - table2 = _read_table( - buf, use_threads=False, use_legacy_dataset=use_legacy_dataset) - - assert table1.equals(table2) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_min_chunksize(use_legacy_dataset): - data = pd.DataFrame([np.arange(4)], columns=['A', 'B', 'C', 'D']) - table = pa.Table.from_pandas(data.reset_index()) - - buf = io.BytesIO() - _write_table(table, buf, chunk_size=-1) - - buf.seek(0) - result = _read_table(buf, use_legacy_dataset=use_legacy_dataset) - - assert result.equals(table) - - with pytest.raises(ValueError): - _write_table(table, buf, chunk_size=0) - - -@pytest.mark.pandas -def test_pass_separate_metadata(): - # ARROW-471 - df = 
alltypes_sample(size=10000) - - a_table = pa.Table.from_pandas(df) - - buf = io.BytesIO() - _write_table(a_table, buf, compression='snappy', version='2.0') - - buf.seek(0) - metadata = pq.read_metadata(buf) - - buf.seek(0) - - fileh = pq.ParquetFile(buf, metadata=metadata) - - tm.assert_frame_equal(df, fileh.read().to_pandas()) - - -@pytest.mark.pandas -def test_read_single_row_group(): - # ARROW-471 - N, K = 10000, 4 - df = alltypes_sample(size=N) - - a_table = pa.Table.from_pandas(df) - - buf = io.BytesIO() - _write_table(a_table, buf, row_group_size=N / K, - compression='snappy', version='2.0') - - buf.seek(0) - - pf = pq.ParquetFile(buf) - - assert pf.num_row_groups == K - - row_groups = [pf.read_row_group(i) for i in range(K)] - result = pa.concat_tables(row_groups) - tm.assert_frame_equal(df, result.to_pandas()) - - -@pytest.mark.pandas -def test_read_single_row_group_with_column_subset(): - N, K = 10000, 4 - df = alltypes_sample(size=N) - a_table = pa.Table.from_pandas(df) - - buf = io.BytesIO() - _write_table(a_table, buf, row_group_size=N / K, - compression='snappy', version='2.0') - - buf.seek(0) - pf = pq.ParquetFile(buf) - - cols = list(df.columns[:2]) - row_groups = [pf.read_row_group(i, columns=cols) for i in range(K)] - result = pa.concat_tables(row_groups) - tm.assert_frame_equal(df[cols], result.to_pandas()) - - # ARROW-4267: Selection of duplicate columns still leads to these columns - # being read uniquely. - row_groups = [pf.read_row_group(i, columns=cols + cols) for i in range(K)] - result = pa.concat_tables(row_groups) - tm.assert_frame_equal(df[cols], result.to_pandas()) - - -@pytest.mark.pandas -def test_read_multiple_row_groups(): - N, K = 10000, 4 - df = alltypes_sample(size=N) - - a_table = pa.Table.from_pandas(df) - - buf = io.BytesIO() - _write_table(a_table, buf, row_group_size=N / K, - compression='snappy', version='2.0') - - buf.seek(0) - - pf = pq.ParquetFile(buf) - - assert pf.num_row_groups == K - - result = pf.read_row_groups(range(K)) - tm.assert_frame_equal(df, result.to_pandas()) - - -@pytest.mark.pandas -def test_read_multiple_row_groups_with_column_subset(): - N, K = 10000, 4 - df = alltypes_sample(size=N) - a_table = pa.Table.from_pandas(df) - - buf = io.BytesIO() - _write_table(a_table, buf, row_group_size=N / K, - compression='snappy', version='2.0') - - buf.seek(0) - pf = pq.ParquetFile(buf) - - cols = list(df.columns[:2]) - result = pf.read_row_groups(range(K), columns=cols) - tm.assert_frame_equal(df[cols], result.to_pandas()) - - # ARROW-4267: Selection of duplicate columns still leads to these columns - # being read uniquely. 
- result = pf.read_row_groups(range(K), columns=cols + cols) - tm.assert_frame_equal(df[cols], result.to_pandas()) - - -@pytest.mark.pandas -def test_scan_contents(): - N, K = 10000, 4 - df = alltypes_sample(size=N) - a_table = pa.Table.from_pandas(df) - - buf = io.BytesIO() - _write_table(a_table, buf, row_group_size=N / K, - compression='snappy', version='2.0') - - buf.seek(0) - pf = pq.ParquetFile(buf) - - assert pf.scan_contents() == 10000 - assert pf.scan_contents(df.columns[:4]) == 10000 - - -@pytest.mark.pandas -def test_parquet_piece_read(tempdir): - df = _test_dataframe(1000) - table = pa.Table.from_pandas(df) - - path = tempdir / 'parquet_piece_read.parquet' - _write_table(table, path, version='2.0') - - piece1 = pq.ParquetDatasetPiece(path) - - result = piece1.read() - assert result.equals(table) - - -@pytest.mark.pandas -def test_parquet_piece_open_and_get_metadata(tempdir): - df = _test_dataframe(100) - table = pa.Table.from_pandas(df) - - path = tempdir / 'parquet_piece_read.parquet' - _write_table(table, path, version='2.0') - - piece = pq.ParquetDatasetPiece(path) - table1 = piece.read() - assert isinstance(table1, pa.Table) - meta1 = piece.get_metadata() - assert isinstance(meta1, pq.FileMetaData) - - assert table.equals(table1) - - -def test_parquet_piece_basics(): - path = '/baz.parq' - - piece1 = pq.ParquetDatasetPiece(path) - piece2 = pq.ParquetDatasetPiece(path, row_group=1) - piece3 = pq.ParquetDatasetPiece( - path, row_group=1, partition_keys=[('foo', 0), ('bar', 1)]) - - assert str(piece1) == path - assert str(piece2) == '/baz.parq | row_group=1' - assert str(piece3) == 'partition[foo=0, bar=1] /baz.parq | row_group=1' - - assert piece1 == piece1 - assert piece2 == piece2 - assert piece3 == piece3 - assert piece1 != piece3 - - -def test_partition_set_dictionary_type(): - set1 = pq.PartitionSet('key1', ['foo', 'bar', 'baz']) - set2 = pq.PartitionSet('key2', [2007, 2008, 2009]) - - assert isinstance(set1.dictionary, pa.StringArray) - assert isinstance(set2.dictionary, pa.IntegerArray) - - set3 = pq.PartitionSet('key2', [datetime.datetime(2007, 1, 1)]) - with pytest.raises(TypeError): - set3.dictionary - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_read_partitioned_directory(tempdir, use_legacy_dataset): - fs = LocalFileSystem._get_instance() - _partition_test_for_filesystem(fs, tempdir, use_legacy_dataset) - - -@pytest.mark.pandas -def test_create_parquet_dataset_multi_threaded(tempdir): - fs = LocalFileSystem._get_instance() - base_path = tempdir - - _partition_test_for_filesystem(fs, base_path) - - manifest = pq.ParquetManifest(base_path, filesystem=fs, - metadata_nthreads=1) - dataset = pq.ParquetDataset(base_path, filesystem=fs, metadata_nthreads=16) - assert len(dataset.pieces) > 0 - partitions = dataset.partitions - assert len(partitions.partition_names) > 0 - assert partitions.partition_names == manifest.partitions.partition_names - assert len(partitions.levels) == len(manifest.partitions.levels) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_read_partitioned_columns_selection(tempdir, use_legacy_dataset): - # ARROW-3861 - do not include partition columns in resulting table when - # `columns` keyword was passed without those columns - fs = LocalFileSystem._get_instance() - base_path = tempdir - _partition_test_for_filesystem(fs, base_path) - - dataset = pq.ParquetDataset( - base_path, use_legacy_dataset=use_legacy_dataset) - result = dataset.read(columns=["values"]) - if use_legacy_dataset: - # ParquetDataset implementation 
always includes the partition columns - # automatically, and we can't easily "fix" this since dask relies on - # this behaviour (ARROW-8644) - assert result.column_names == ["values", "foo", "bar"] - else: - assert result.column_names == ["values"] - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_filters_equivalency(tempdir, use_legacy_dataset): - fs = LocalFileSystem._get_instance() - base_path = tempdir - - integer_keys = [0, 1] - string_keys = ['a', 'b', 'c'] - boolean_keys = [True, False] - partition_spec = [ - ['integer', integer_keys], - ['string', string_keys], - ['boolean', boolean_keys] - ] - - df = pd.DataFrame({ - 'integer': np.array(integer_keys, dtype='i4').repeat(15), - 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2), - 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5), - 3), - }, columns=['integer', 'string', 'boolean']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - # Old filters syntax: - # integer == 1 AND string != b AND boolean == True - dataset = pq.ParquetDataset( - base_path, filesystem=fs, - filters=[('integer', '=', 1), ('string', '!=', 'b'), - ('boolean', '==', True)], - use_legacy_dataset=use_legacy_dataset, - ) - table = dataset.read() - result_df = (table.to_pandas().reset_index(drop=True)) - - assert 0 not in result_df['integer'].values - assert 'b' not in result_df['string'].values - assert False not in result_df['boolean'].values - - # filters in disjunctive normal form: - # (integer == 1 AND string != b AND boolean == True) OR - # (integer == 2 AND boolean == False) - # TODO(ARROW-3388): boolean columns are reconstructed as string - filters = [ - [ - ('integer', '=', 1), - ('string', '!=', 'b'), - ('boolean', '==', 'True') - ], - [('integer', '=', 0), ('boolean', '==', 'False')] - ] - dataset = pq.ParquetDataset( - base_path, filesystem=fs, filters=filters, - use_legacy_dataset=use_legacy_dataset) - table = dataset.read() - result_df = table.to_pandas().reset_index(drop=True) - - # Check that all rows in the DF fulfill the filter - # Pandas 0.23.x has problems with indexing constant memoryviews in - # categoricals. Thus we need to make an explicit copy here with np.array. - df_filter_1 = (np.array(result_df['integer']) == 1) \ - & (np.array(result_df['string']) != 'b') \ - & (np.array(result_df['boolean']) == 'True') - df_filter_2 = (np.array(result_df['integer']) == 0) \ - & (np.array(result_df['boolean']) == 'False') - assert df_filter_1.sum() > 0 - assert df_filter_2.sum() > 0 - assert result_df.shape[0] == (df_filter_1.sum() + df_filter_2.sum()) - - if use_legacy_dataset: - # Check for \0 in predicate values. Until they are correctly - # implemented in ARROW-3391, they would otherwise lead to weird - # results with the current code. 
- with pytest.raises(NotImplementedError): - filters = [[('string', '==', b'1\0a')]] - pq.ParquetDataset(base_path, filesystem=fs, filters=filters) - with pytest.raises(NotImplementedError): - filters = [[('string', '==', '1\0a')]] - pq.ParquetDataset(base_path, filesystem=fs, filters=filters) - else: - for filters in [[[('string', '==', b'1\0a')]], - [[('string', '==', '1\0a')]]]: - dataset = pq.ParquetDataset( - base_path, filesystem=fs, filters=filters, - use_legacy_dataset=False) - assert dataset.read().num_rows == 0 - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_filters_cutoff_exclusive_integer(tempdir, use_legacy_dataset): - fs = LocalFileSystem._get_instance() - base_path = tempdir - - integer_keys = [0, 1, 2, 3, 4] - partition_spec = [ - ['integers', integer_keys], - ] - N = 5 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'integers': np.array(integer_keys, dtype='i4'), - }, columns=['index', 'integers']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - dataset = pq.ParquetDataset( - base_path, filesystem=fs, - filters=[ - ('integers', '<', 4), - ('integers', '>', 1), - ], - use_legacy_dataset=use_legacy_dataset - ) - table = dataset.read() - result_df = (table.to_pandas() - .sort_values(by='index') - .reset_index(drop=True)) - - result_list = [x for x in map(int, result_df['integers'].values)] - assert result_list == [2, 3] - - -@pytest.mark.pandas -@parametrize_legacy_dataset -@pytest.mark.xfail( - # different error with use_legacy_datasets because result_df is no longer - # categorical - raises=(TypeError, AssertionError), - reason='Loss of type information in creation of categoricals.' -) -def test_filters_cutoff_exclusive_datetime(tempdir, use_legacy_dataset): - fs = LocalFileSystem._get_instance() - base_path = tempdir - - date_keys = [ - datetime.date(2018, 4, 9), - datetime.date(2018, 4, 10), - datetime.date(2018, 4, 11), - datetime.date(2018, 4, 12), - datetime.date(2018, 4, 13) - ] - partition_spec = [ - ['dates', date_keys] - ] - N = 5 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'dates': np.array(date_keys, dtype='datetime64'), - }, columns=['index', 'dates']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - dataset = pq.ParquetDataset( - base_path, filesystem=fs, - filters=[ - ('dates', '<', "2018-04-12"), - ('dates', '>', "2018-04-10") - ], - use_legacy_dataset=use_legacy_dataset - ) - table = dataset.read() - result_df = (table.to_pandas() - .sort_values(by='index') - .reset_index(drop=True)) - - expected = pd.Categorical( - np.array([datetime.date(2018, 4, 11)], dtype='datetime64'), - categories=np.array(date_keys, dtype='datetime64')) - - assert result_df['dates'].values == expected - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_filters_inclusive_integer(tempdir, use_legacy_dataset): - fs = LocalFileSystem._get_instance() - base_path = tempdir - - integer_keys = [0, 1, 2, 3, 4] - partition_spec = [ - ['integers', integer_keys], - ] - N = 5 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'integers': np.array(integer_keys, dtype='i4'), - }, columns=['index', 'integers']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - dataset = pq.ParquetDataset( - base_path, filesystem=fs, - filters=[ - ('integers', '<=', 3), - ('integers', '>=', 2), - ], - use_legacy_dataset=use_legacy_dataset - ) - table = dataset.read() - result_df = (table.to_pandas() - .sort_values(by='index') - .reset_index(drop=True)) - - result_list = [int(x) for x in map(int, 
result_df['integers'].values)] - assert result_list == [2, 3] - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_filters_inclusive_set(tempdir, use_legacy_dataset): - fs = LocalFileSystem._get_instance() - base_path = tempdir - - integer_keys = [0, 1] - string_keys = ['a', 'b', 'c'] - boolean_keys = [True, False] - partition_spec = [ - ['integer', integer_keys], - ['string', string_keys], - ['boolean', boolean_keys] - ] - - df = pd.DataFrame({ - 'integer': np.array(integer_keys, dtype='i4').repeat(15), - 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2), - 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5), - 3), - }, columns=['integer', 'string', 'boolean']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - dataset = pq.ParquetDataset( - base_path, filesystem=fs, - filters=[('integer', 'in', {1}), ('string', 'in', {'a', 'b'}), - ('boolean', 'in', {True})], - use_legacy_dataset=use_legacy_dataset - ) - table = dataset.read() - result_df = (table.to_pandas().reset_index(drop=True)) - - assert 0 not in result_df['integer'].values - assert 'c' not in result_df['string'].values - assert False not in result_df['boolean'].values - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_filters_invalid_pred_op(tempdir, use_legacy_dataset): - fs = LocalFileSystem._get_instance() - base_path = tempdir - - integer_keys = [0, 1, 2, 3, 4] - partition_spec = [ - ['integers', integer_keys], - ] - N = 5 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'integers': np.array(integer_keys, dtype='i4'), - }, columns=['index', 'integers']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - with pytest.raises(ValueError): - pq.ParquetDataset(base_path, - filesystem=fs, - filters=[('integers', '=<', 3), ], - use_legacy_dataset=use_legacy_dataset) - - if use_legacy_dataset: - with pytest.raises(ValueError): - pq.ParquetDataset(base_path, - filesystem=fs, - filters=[('integers', 'in', set()), ], - use_legacy_dataset=use_legacy_dataset) - else: - # Dataset API returns empty table instead - dataset = pq.ParquetDataset(base_path, - filesystem=fs, - filters=[('integers', 'in', set()), ], - use_legacy_dataset=use_legacy_dataset) - assert dataset.read().num_rows == 0 - - with pytest.raises(ValueError): - pq.ParquetDataset(base_path, - filesystem=fs, - filters=[('integers', '!=', {3})], - use_legacy_dataset=use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset_fixed -def test_filters_invalid_column(tempdir, use_legacy_dataset): - # ARROW-5572 - raise error on invalid name in filter specification - # works with new dataset / xfail with legacy implementation - fs = LocalFileSystem._get_instance() - base_path = tempdir - - integer_keys = [0, 1, 2, 3, 4] - partition_spec = [['integers', integer_keys]] - N = 5 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'integers': np.array(integer_keys, dtype='i4'), - }, columns=['index', 'integers']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - msg = r"No match for FieldRef.Name\(non_existent_column\)" - with pytest.raises(ValueError, match=msg): - pq.ParquetDataset(base_path, filesystem=fs, - filters=[('non_existent_column', '<', 3), ], - use_legacy_dataset=use_legacy_dataset).read() - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_filters_read_table(tempdir, use_legacy_dataset): - # test that filters keyword is passed through in read_table - fs = LocalFileSystem._get_instance() - base_path = tempdir - - 
integer_keys = [0, 1, 2, 3, 4] - partition_spec = [ - ['integers', integer_keys], - ] - N = 5 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'integers': np.array(integer_keys, dtype='i4'), - }, columns=['index', 'integers']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - table = pq.read_table( - base_path, filesystem=fs, filters=[('integers', '<', 3)], - use_legacy_dataset=use_legacy_dataset) - assert table.num_rows == 3 - - table = pq.read_table( - base_path, filesystem=fs, filters=[[('integers', '<', 3)]], - use_legacy_dataset=use_legacy_dataset) - assert table.num_rows == 3 - - table = pq.read_pandas( - base_path, filters=[('integers', '<', 3)], - use_legacy_dataset=use_legacy_dataset) - assert table.num_rows == 3 - - -@pytest.mark.pandas -@parametrize_legacy_dataset_fixed -def test_partition_keys_with_underscores(tempdir, use_legacy_dataset): - # ARROW-5666 - partition field values with underscores preserve underscores - # xfail with legacy dataset -> they get interpreted as integers - fs = LocalFileSystem._get_instance() - base_path = tempdir - - string_keys = ["2019_2", "2019_3"] - partition_spec = [ - ['year_week', string_keys], - ] - N = 2 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'year_week': np.array(string_keys, dtype='object'), - }, columns=['index', 'year_week']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - dataset = pq.ParquetDataset( - base_path, use_legacy_dataset=use_legacy_dataset) - result = dataset.read() - assert result.column("year_week").to_pylist() == string_keys - - -@pytest.fixture -def s3_bucket(request, s3_connection, s3_server): - boto3 = pytest.importorskip('boto3') - botocore = pytest.importorskip('botocore') - - host, port, access_key, secret_key = s3_connection - s3 = boto3.resource( - 's3', - endpoint_url='http://{}:{}'.format(host, port), - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - config=botocore.client.Config(signature_version='s3v4'), - region_name='us-east-1' - ) - bucket = s3.Bucket('test-s3fs') - try: - bucket.create() - except Exception: - # we get BucketAlreadyOwnedByYou error with fsspec handler - pass - return 'test-s3fs' - - -@pytest.fixture -def s3_example_s3fs(s3_connection, s3_server, s3_bucket): - s3fs = pytest.importorskip('s3fs') - - host, port, access_key, secret_key = s3_connection - fs = s3fs.S3FileSystem( - key=access_key, - secret=secret_key, - client_kwargs={ - 'endpoint_url': 'http://{}:{}'.format(host, port) - } - ) - - test_path = '{}/{}'.format(s3_bucket, guid()) - - fs.mkdir(test_path) - yield fs, test_path - try: - fs.rm(test_path, recursive=True) - except FileNotFoundError: - pass - - -@parametrize_legacy_dataset -def test_read_s3fs(s3_example_s3fs, use_legacy_dataset): - fs, path = s3_example_s3fs - path = path + "/test.parquet" - table = pa.table({"a": [1, 2, 3]}) - _write_table(table, path, filesystem=fs) - - result = _read_table( - path, filesystem=fs, use_legacy_dataset=use_legacy_dataset - ) - assert result.equals(table) - - -@parametrize_legacy_dataset -def test_read_directory_s3fs(s3_example_s3fs, use_legacy_dataset): - fs, directory = s3_example_s3fs - path = directory + "/test.parquet" - table = pa.table({"a": [1, 2, 3]}) - _write_table(table, path, filesystem=fs) - - result = _read_table( - directory, filesystem=fs, use_legacy_dataset=use_legacy_dataset - ) - assert result.equals(table) - - -@pytest.mark.pandas -@pytest.mark.s3 -@parametrize_legacy_dataset -def test_read_partitioned_directory_s3fs_wrapper( - 
s3_example_s3fs, use_legacy_dataset -): - from pyarrow.filesystem import S3FSWrapper - import s3fs - - if s3fs.__version__ >= LooseVersion("0.5"): - pytest.skip("S3FSWrapper no longer working for s3fs 0.5+") - - fs, path = s3_example_s3fs - with pytest.warns(DeprecationWarning): - wrapper = S3FSWrapper(fs) - _partition_test_for_filesystem(wrapper, path) - - # Check that we can auto-wrap - dataset = pq.ParquetDataset( - path, filesystem=fs, use_legacy_dataset=use_legacy_dataset - ) - dataset.read() - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_read_partitioned_directory_s3fs(s3_example_s3fs, use_legacy_dataset): - fs, path = s3_example_s3fs - _partition_test_for_filesystem( - fs, path, use_legacy_dataset=use_legacy_dataset - ) - - -def _partition_test_for_filesystem(fs, base_path, use_legacy_dataset=True): - foo_keys = [0, 1] - bar_keys = ['a', 'b', 'c'] - partition_spec = [ - ['foo', foo_keys], - ['bar', bar_keys] - ] - N = 30 - - df = pd.DataFrame({ - 'index': np.arange(N), - 'foo': np.array(foo_keys, dtype='i4').repeat(15), - 'bar': np.tile(np.tile(np.array(bar_keys, dtype=object), 5), 2), - 'values': np.random.randn(N) - }, columns=['index', 'foo', 'bar', 'values']) - - _generate_partition_directories(fs, base_path, partition_spec, df) - - dataset = pq.ParquetDataset( - base_path, filesystem=fs, use_legacy_dataset=use_legacy_dataset) - table = dataset.read() - result_df = (table.to_pandas() - .sort_values(by='index') - .reset_index(drop=True)) - - expected_df = (df.sort_values(by='index') - .reset_index(drop=True) - .reindex(columns=result_df.columns)) - - expected_df['foo'] = pd.Categorical(df['foo'], categories=foo_keys) - expected_df['bar'] = pd.Categorical(df['bar'], categories=bar_keys) - - assert (result_df.columns == ['index', 'values', 'foo', 'bar']).all() - - tm.assert_frame_equal(result_df, expected_df) - - -def _generate_partition_directories(fs, base_dir, partition_spec, df): - # partition_spec : list of lists, e.g. 
[['foo', [0, 1, 2]],
-    #                                      ['bar', ['a', 'b', 'c']]]
-    # df : a pandas DataFrame from which a table is written to each partition
-    DEPTH = len(partition_spec)
-
-    pathsep = getattr(fs, "pathsep", getattr(fs, "sep", "/"))
-
-    def _visit_level(base_dir, level, part_keys):
-        name, values = partition_spec[level]
-        for value in values:
-            this_part_keys = part_keys + [(name, value)]
-
-            level_dir = pathsep.join([
-                str(base_dir),
-                '{}={}'.format(name, value)
-            ])
-            fs.mkdir(level_dir)
-
-            if level == DEPTH - 1:
-                # Generate example data
-                file_path = pathsep.join([level_dir, guid()])
-                filtered_df = _filter_partition(df, this_part_keys)
-                part_table = pa.Table.from_pandas(filtered_df)
-                with fs.open(file_path, 'wb') as f:
-                    _write_table(part_table, f)
-                assert fs.exists(file_path)
-
-                file_success = pathsep.join([level_dir, '_SUCCESS'])
-                with fs.open(file_success, 'wb') as f:
-                    pass
-            else:
-                _visit_level(level_dir, level + 1, this_part_keys)
-                file_success = pathsep.join([level_dir, '_SUCCESS'])
-                with fs.open(file_success, 'wb') as f:
-                    pass
-
-    _visit_level(base_dir, 0, [])
-
-
-def _test_read_common_metadata_files(fs, base_path):
-    N = 100
-    df = pd.DataFrame({
-        'index': np.arange(N),
-        'values': np.random.randn(N)
-    }, columns=['index', 'values'])
-
-    base_path = str(base_path)
-    data_path = os.path.join(base_path, 'data.parquet')
-
-    table = pa.Table.from_pandas(df)
-
-    with fs.open(data_path, 'wb') as f:
-        _write_table(table, f)
-
-    metadata_path = os.path.join(base_path, '_common_metadata')
-    with fs.open(metadata_path, 'wb') as f:
-        pq.write_metadata(table.schema, f)
-
-    dataset = pq.ParquetDataset(base_path, filesystem=fs)
-    assert dataset.common_metadata_path == str(metadata_path)
-
-    with fs.open(data_path) as f:
-        common_schema = pq.read_metadata(f).schema
-    assert dataset.schema.equals(common_schema)
-
-    # handle list of one directory
-    dataset2 = pq.ParquetDataset([base_path], filesystem=fs)
-    assert dataset2.schema.equals(dataset.schema)
-
-
-@pytest.mark.pandas
-def test_read_common_metadata_files(tempdir):
-    fs = LocalFileSystem._get_instance()
-    _test_read_common_metadata_files(fs, tempdir)
-
-
-@pytest.mark.pandas
-def test_read_metadata_files(tempdir):
-    fs = LocalFileSystem._get_instance()
-
-    N = 100
-    df = pd.DataFrame({
-        'index': np.arange(N),
-        'values': np.random.randn(N)
-    }, columns=['index', 'values'])
-
-    data_path = tempdir / 'data.parquet'
-
-    table = pa.Table.from_pandas(df)
-
-    with fs.open(data_path, 'wb') as f:
-        _write_table(table, f)
-
-    metadata_path = tempdir / '_metadata'
-    with fs.open(metadata_path, 'wb') as f:
-        pq.write_metadata(table.schema, f)
-
-    dataset = pq.ParquetDataset(tempdir, filesystem=fs)
-    assert dataset.metadata_path == str(metadata_path)
-
-    with fs.open(data_path) as f:
-        metadata_schema = pq.read_metadata(f).schema
-    assert dataset.schema.equals(metadata_schema)
-
-
-@pytest.mark.pandas
-def test_read_schema(tempdir):
-    N = 100
-    df = pd.DataFrame({
-        'index': np.arange(N),
-        'values': np.random.randn(N)
-    }, columns=['index', 'values'])
-
-    data_path = tempdir / 'test.parquet'
-
-    table = pa.Table.from_pandas(df)
-    _write_table(table, data_path)
-
-    read1 = pq.read_schema(data_path)
-    read2 = pq.read_schema(data_path, memory_map=True)
-    assert table.schema.equals(read1)
-    assert table.schema.equals(read2)
-
-    assert table.schema.metadata[b'pandas'] == read1.metadata[b'pandas']
-
-
-def _filter_partition(df, part_keys):
-    predicate = np.ones(len(df), dtype=bool)
-
-    to_drop = []
-    for name, value in part_keys:
-        to_drop.append(name)
-
-        # to avoid a
pandas warning - if isinstance(value, (datetime.date, datetime.datetime)): - value = pd.Timestamp(value) - - predicate &= df[name] == value - - return df[predicate].drop(to_drop, axis=1) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_read_multiple_files(tempdir, use_legacy_dataset): - nfiles = 10 - size = 5 - - dirpath = tempdir / guid() - dirpath.mkdir() - - test_data = [] - paths = [] - for i in range(nfiles): - df = _test_dataframe(size, seed=i) - - # Hack so that we don't have a dtype cast in v1 files - df['uint32'] = df['uint32'].astype(np.int64) - - path = dirpath / '{}.parquet'.format(i) - - table = pa.Table.from_pandas(df) - _write_table(table, path) - - test_data.append(table) - paths.append(path) - - # Write a _SUCCESS.crc file - (dirpath / '_SUCCESS.crc').touch() - - def read_multiple_files(paths, columns=None, use_threads=True, **kwargs): - dataset = pq.ParquetDataset( - paths, use_legacy_dataset=use_legacy_dataset, **kwargs) - return dataset.read(columns=columns, use_threads=use_threads) - - result = read_multiple_files(paths) - expected = pa.concat_tables(test_data) - - assert result.equals(expected) - - # Read with provided metadata - # TODO(dataset) specifying metadata not yet supported - metadata = pq.read_metadata(paths[0]) - if use_legacy_dataset: - result2 = read_multiple_files(paths, metadata=metadata) - assert result2.equals(expected) - - result3 = pq.ParquetDataset(dirpath, schema=metadata.schema).read() - assert result3.equals(expected) - else: - with pytest.raises(ValueError, match="no longer supported"): - pq.read_table(paths, metadata=metadata, use_legacy_dataset=False) - - # Read column subset - to_read = [0, 2, 6, result.num_columns - 1] - - col_names = [result.field(i).name for i in to_read] - out = pq.read_table( - dirpath, columns=col_names, use_legacy_dataset=use_legacy_dataset - ) - expected = pa.Table.from_arrays([result.column(i) for i in to_read], - names=col_names, - metadata=result.schema.metadata) - assert out.equals(expected) - - # Read with multiple threads - pq.read_table( - dirpath, use_threads=True, use_legacy_dataset=use_legacy_dataset - ) - - # Test failure modes with non-uniform metadata - bad_apple = _test_dataframe(size, seed=i).iloc[:, :4] - bad_apple_path = tempdir / '{}.parquet'.format(guid()) - - t = pa.Table.from_pandas(bad_apple) - _write_table(t, bad_apple_path) - - if not use_legacy_dataset: - # TODO(dataset) Dataset API skips bad files - return - - bad_meta = pq.read_metadata(bad_apple_path) - - with pytest.raises(ValueError): - read_multiple_files(paths + [bad_apple_path]) - - with pytest.raises(ValueError): - read_multiple_files(paths, metadata=bad_meta) - - mixed_paths = [bad_apple_path, paths[0]] - - with pytest.raises(ValueError): - read_multiple_files(mixed_paths, schema=bad_meta.schema) - - with pytest.raises(ValueError): - read_multiple_files(mixed_paths) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_dataset_read_pandas(tempdir, use_legacy_dataset): - nfiles = 5 - size = 5 - - dirpath = tempdir / guid() - dirpath.mkdir() - - test_data = [] - frames = [] - paths = [] - for i in range(nfiles): - df = _test_dataframe(size, seed=i) - df.index = np.arange(i * size, (i + 1) * size) - df.index.name = 'index' - - path = dirpath / '{}.parquet'.format(i) - - table = pa.Table.from_pandas(df) - _write_table(table, path) - test_data.append(table) - frames.append(df) - paths.append(path) - - dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) - columns = ['uint8', 'strings'] - 
result = dataset.read_pandas(columns=columns).to_pandas() - expected = pd.concat([x[columns] for x in frames]) - - tm.assert_frame_equal(result, expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_dataset_memory_map(tempdir, use_legacy_dataset): - # ARROW-2627: Check that we can use ParquetDataset with memory-mapping - dirpath = tempdir / guid() - dirpath.mkdir() - - df = _test_dataframe(10, seed=0) - path = dirpath / '{}.parquet'.format(0) - table = pa.Table.from_pandas(df) - _write_table(table, path, version='2.0') - - dataset = pq.ParquetDataset( - dirpath, memory_map=True, use_legacy_dataset=use_legacy_dataset) - assert dataset.read().equals(table) - if use_legacy_dataset: - assert dataset.pieces[0].read().equals(table) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_dataset_enable_buffered_stream(tempdir, use_legacy_dataset): - dirpath = tempdir / guid() - dirpath.mkdir() - - df = _test_dataframe(10, seed=0) - path = dirpath / '{}.parquet'.format(0) - table = pa.Table.from_pandas(df) - _write_table(table, path, version='2.0') - - with pytest.raises(ValueError): - pq.ParquetDataset( - dirpath, buffer_size=-64, - use_legacy_dataset=use_legacy_dataset) - - for buffer_size in [128, 1024]: - dataset = pq.ParquetDataset( - dirpath, buffer_size=buffer_size, - use_legacy_dataset=use_legacy_dataset) - assert dataset.read().equals(table) - - -@pytest.mark.pandas -@pytest.mark.parametrize('preserve_index', [True, False, None]) -def test_dataset_read_pandas_common_metadata(tempdir, preserve_index): - # ARROW-1103 - nfiles = 5 - size = 5 - - dirpath = tempdir / guid() - dirpath.mkdir() - - test_data = [] - frames = [] - paths = [] - for i in range(nfiles): - df = _test_dataframe(size, seed=i) - df.index = pd.Index(np.arange(i * size, (i + 1) * size), name='index') - - path = dirpath / '{}.parquet'.format(i) - - table = pa.Table.from_pandas(df, preserve_index=preserve_index) - - # Obliterate metadata - table = table.replace_schema_metadata(None) - assert table.schema.metadata is None - - _write_table(table, path) - test_data.append(table) - frames.append(df) - paths.append(path) - - # Write _metadata common file - table_for_metadata = pa.Table.from_pandas( - df, preserve_index=preserve_index - ) - pq.write_metadata(table_for_metadata.schema, dirpath / '_metadata') - - dataset = pq.ParquetDataset(dirpath) - columns = ['uint8', 'strings'] - result = dataset.read_pandas(columns=columns).to_pandas() - expected = pd.concat([x[columns] for x in frames]) - expected.index.name = ( - df.index.name if preserve_index is not False else None) - tm.assert_frame_equal(result, expected) - - -def _make_example_multifile_dataset(base_path, nfiles=10, file_nrows=5): - test_data = [] - paths = [] - for i in range(nfiles): - df = _test_dataframe(file_nrows, seed=i) - path = base_path / '{}.parquet'.format(i) - - test_data.append(_write_table(df, path)) - paths.append(path) - return paths - - -def _assert_dataset_paths(dataset, paths, use_legacy_dataset): - if use_legacy_dataset: - assert set(map(str, paths)) == {x.path for x in dataset.pieces} - else: - paths = [str(path.as_posix()) for path in paths] - assert set(paths) == set(dataset._dataset.files) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -@pytest.mark.parametrize('dir_prefix', ['_', '.']) -def test_ignore_private_directories(tempdir, dir_prefix, use_legacy_dataset): - dirpath = tempdir / guid() - dirpath.mkdir() - - paths = _make_example_multifile_dataset(dirpath, nfiles=10, - file_nrows=5) - - # private directory 
- (dirpath / '{}staging'.format(dir_prefix)).mkdir() - - dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) - - _assert_dataset_paths(dataset, paths, use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_ignore_hidden_files_dot(tempdir, use_legacy_dataset): - dirpath = tempdir / guid() - dirpath.mkdir() - - paths = _make_example_multifile_dataset(dirpath, nfiles=10, - file_nrows=5) - - with (dirpath / '.DS_Store').open('wb') as f: - f.write(b'gibberish') - - with (dirpath / '.private').open('wb') as f: - f.write(b'gibberish') - - dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) - - _assert_dataset_paths(dataset, paths, use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_ignore_hidden_files_underscore(tempdir, use_legacy_dataset): - dirpath = tempdir / guid() - dirpath.mkdir() - - paths = _make_example_multifile_dataset(dirpath, nfiles=10, - file_nrows=5) - - with (dirpath / '_committed_123').open('wb') as f: - f.write(b'abcd') - - with (dirpath / '_started_321').open('wb') as f: - f.write(b'abcd') - - dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) - - _assert_dataset_paths(dataset, paths, use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -@pytest.mark.parametrize('dir_prefix', ['_', '.']) -def test_ignore_no_private_directories_in_base_path( - tempdir, dir_prefix, use_legacy_dataset -): - # ARROW-8427 - don't ignore explicitly listed files if parent directory - # is a private directory - dirpath = tempdir / "{0}data".format(dir_prefix) / guid() - dirpath.mkdir(parents=True) - - paths = _make_example_multifile_dataset(dirpath, nfiles=10, - file_nrows=5) - - dataset = pq.ParquetDataset(paths, use_legacy_dataset=use_legacy_dataset) - _assert_dataset_paths(dataset, paths, use_legacy_dataset) - - # ARROW-9644 - don't ignore full directory with underscore in base path - dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) - _assert_dataset_paths(dataset, paths, use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset_fixed -def test_ignore_custom_prefixes(tempdir, use_legacy_dataset): - # ARROW-9573 - allow override of default ignore_prefixes - part = ["xxx"] * 3 + ["yyy"] * 3 - table = pa.table([ - pa.array(range(len(part))), - pa.array(part).dictionary_encode(), - ], names=['index', '_part']) - - # TODO use_legacy_dataset ARROW-10247 - pq.write_to_dataset(table, str(tempdir), partition_cols=['_part']) - - private_duplicate = tempdir / '_private_duplicate' - private_duplicate.mkdir() - pq.write_to_dataset(table, str(private_duplicate), - partition_cols=['_part']) - - read = pq.read_table( - tempdir, use_legacy_dataset=use_legacy_dataset, - ignore_prefixes=['_private']) - - assert read.equals(table) - - -@parametrize_legacy_dataset_fixed -def test_empty_directory(tempdir, use_legacy_dataset): - # ARROW-5310 - reading empty directory - # fails with legacy implementation - empty_dir = tempdir / 'dataset' - empty_dir.mkdir() - - dataset = pq.ParquetDataset( - empty_dir, use_legacy_dataset=use_legacy_dataset) - result = dataset.read() - assert result.num_rows == 0 - assert result.num_columns == 0 - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_multiindex_duplicate_values(tempdir, use_legacy_dataset): - num_rows = 3 - numbers = list(range(num_rows)) - index = pd.MultiIndex.from_arrays( - [['foo', 'foo', 'bar'], numbers], - names=['foobar', 'some_numbers'], - ) - - df = 
pd.DataFrame({'numbers': numbers}, index=index) - table = pa.Table.from_pandas(df) - - filename = tempdir / 'dup_multi_index_levels.parquet' - - _write_table(table, filename) - result_table = _read_table(filename, use_legacy_dataset=use_legacy_dataset) - assert table.equals(result_table) - - result_df = result_table.to_pandas() - tm.assert_frame_equal(result_df, df) - - -@pytest.mark.pandas -def test_write_error_deletes_incomplete_file(tempdir): - # ARROW-1285 - df = pd.DataFrame({'a': list('abc'), - 'b': list(range(1, 4)), - 'c': np.arange(3, 6).astype('u1'), - 'd': np.arange(4.0, 7.0, dtype='float64'), - 'e': [True, False, True], - 'f': pd.Categorical(list('abc')), - 'g': pd.date_range('20130101', periods=3), - 'h': pd.date_range('20130101', periods=3, - tz='US/Eastern'), - 'i': pd.date_range('20130101', periods=3, freq='ns')}) - - pdf = pa.Table.from_pandas(df) - - filename = tempdir / 'tmp_file' - try: - _write_table(pdf, filename) - except pa.ArrowException: - pass - - assert not filename.exists() - - -@pytest.mark.pandas -def test_noncoerced_nanoseconds_written_without_exception(tempdir): - # ARROW-1957: the Parquet version 2.0 writer preserves Arrow - # nanosecond timestamps by default - n = 9 - df = pd.DataFrame({'x': range(n)}, - index=pd.date_range('2017-01-01', freq='1n', periods=n)) - tb = pa.Table.from_pandas(df) - - filename = tempdir / 'written.parquet' - try: - pq.write_table(tb, filename, version='2.0') - except Exception: - pass - assert filename.exists() - - recovered_table = pq.read_table(filename) - assert tb.equals(recovered_table) - - # Loss of data through coercion (without explicit override) still an error - filename = tempdir / 'not_written.parquet' - with pytest.raises(ValueError): - pq.write_table(tb, filename, coerce_timestamps='ms', version='2.0') - - -@parametrize_legacy_dataset -def test_read_non_existent_file(tempdir, use_legacy_dataset): - path = 'non-existent-file.parquet' - try: - pq.read_table(path, use_legacy_dataset=use_legacy_dataset) - except Exception as e: - assert path in e.args[0] - - -@parametrize_legacy_dataset -def test_read_table_doesnt_warn(datadir, use_legacy_dataset): - with pytest.warns(None) as record: - pq.read_table(datadir / 'v0.7.1.parquet', - use_legacy_dataset=use_legacy_dataset) - - assert len(record) == 0 - - -def _test_write_to_dataset_with_partitions(base_path, - use_legacy_dataset=True, - filesystem=None, - schema=None, - index_name=None): - # ARROW-1400 - output_df = pd.DataFrame({'group1': list('aaabbbbccc'), - 'group2': list('eefeffgeee'), - 'num': list(range(10)), - 'nan': [np.nan] * 10, - 'date': np.arange('2017-01-01', '2017-01-11', - dtype='datetime64[D]')}) - cols = output_df.columns.tolist() - partition_by = ['group1', 'group2'] - output_table = pa.Table.from_pandas(output_df, schema=schema, safe=False, - preserve_index=False) - pq.write_to_dataset(output_table, base_path, partition_by, - filesystem=filesystem, - use_legacy_dataset=use_legacy_dataset) - - metadata_path = os.path.join(str(base_path), '_common_metadata') - - if filesystem is not None: - with filesystem.open(metadata_path, 'wb') as f: - pq.write_metadata(output_table.schema, f) - else: - pq.write_metadata(output_table.schema, metadata_path) - - # ARROW-2891: Ensure the output_schema is preserved when writing a - # partitioned dataset - dataset = pq.ParquetDataset(base_path, - filesystem=filesystem, - validate_schema=True, - use_legacy_dataset=use_legacy_dataset) - # ARROW-2209: Ensure the dataset schema also includes the partition columns - if 
use_legacy_dataset: - dataset_cols = set(dataset.schema.to_arrow_schema().names) - else: - # NB schema property is an arrow and not parquet schema - dataset_cols = set(dataset.schema.names) - - assert dataset_cols == set(output_table.schema.names) - - input_table = dataset.read() - input_df = input_table.to_pandas() - - # Read data back in and compare with original DataFrame - # Partitioned columns added to the end of the DataFrame when read - input_df_cols = input_df.columns.tolist() - assert partition_by == input_df_cols[-1 * len(partition_by):] - - input_df = input_df[cols] - # Partitioned columns become 'categorical' dtypes - for col in partition_by: - output_df[col] = output_df[col].astype('category') - tm.assert_frame_equal(output_df, input_df) - - -def _test_write_to_dataset_no_partitions(base_path, - use_legacy_dataset=True, - filesystem=None): - # ARROW-1400 - output_df = pd.DataFrame({'group1': list('aaabbbbccc'), - 'group2': list('eefeffgeee'), - 'num': list(range(10)), - 'date': np.arange('2017-01-01', '2017-01-11', - dtype='datetime64[D]')}) - cols = output_df.columns.tolist() - output_table = pa.Table.from_pandas(output_df) - - if filesystem is None: - filesystem = LocalFileSystem._get_instance() - - # Without partitions, append files to root_path - n = 5 - for i in range(n): - pq.write_to_dataset(output_table, base_path, - filesystem=filesystem) - output_files = [file for file in filesystem.ls(str(base_path)) - if file.endswith(".parquet")] - assert len(output_files) == n - - # Deduplicated incoming DataFrame should match - # original outgoing Dataframe - input_table = pq.ParquetDataset( - base_path, filesystem=filesystem, - use_legacy_dataset=use_legacy_dataset - ).read() - input_df = input_table.to_pandas() - input_df = input_df.drop_duplicates() - input_df = input_df[cols] - assert output_df.equals(input_df) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_with_partitions(tempdir, use_legacy_dataset): - _test_write_to_dataset_with_partitions(str(tempdir), use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_with_partitions_and_schema( - tempdir, use_legacy_dataset -): - schema = pa.schema([pa.field('group1', type=pa.string()), - pa.field('group2', type=pa.string()), - pa.field('num', type=pa.int64()), - pa.field('nan', type=pa.int32()), - pa.field('date', type=pa.timestamp(unit='us'))]) - _test_write_to_dataset_with_partitions( - str(tempdir), use_legacy_dataset, schema=schema) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_with_partitions_and_index_name( - tempdir, use_legacy_dataset -): - _test_write_to_dataset_with_partitions( - str(tempdir), use_legacy_dataset, index_name='index_name') - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_no_partitions(tempdir, use_legacy_dataset): - _test_write_to_dataset_no_partitions(str(tempdir), use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_pathlib(tempdir, use_legacy_dataset): - _test_write_to_dataset_with_partitions( - tempdir / "test1", use_legacy_dataset) - _test_write_to_dataset_no_partitions( - tempdir / "test2", use_legacy_dataset) - - -# Those tests are failing - see ARROW-10370 -# @pytest.mark.pandas -# @pytest.mark.s3 -# @parametrize_legacy_dataset -# def test_write_to_dataset_pathlib_nonlocal( -# tempdir, s3_example_s3fs, use_legacy_dataset -# ): -# # pathlib paths are only accepted for local files -# fs, _ = s3_example_s3fs - 
-# with pytest.raises(TypeError, match="path-like objects are only allowed"): -# _test_write_to_dataset_with_partitions( -# tempdir / "test1", use_legacy_dataset, filesystem=fs) - -# with pytest.raises(TypeError, match="path-like objects are only allowed"): -# _test_write_to_dataset_no_partitions( -# tempdir / "test2", use_legacy_dataset, filesystem=fs) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_with_partitions_s3fs( - s3_example_s3fs, use_legacy_dataset -): - fs, path = s3_example_s3fs - - _test_write_to_dataset_with_partitions( - path, use_legacy_dataset, filesystem=fs) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_no_partitions_s3fs( - s3_example_s3fs, use_legacy_dataset -): - fs, path = s3_example_s3fs - - _test_write_to_dataset_no_partitions( - path, use_legacy_dataset, filesystem=fs) - - -@pytest.mark.pandas -@parametrize_legacy_dataset_not_supported -def test_write_to_dataset_with_partitions_and_custom_filenames( - tempdir, use_legacy_dataset -): - output_df = pd.DataFrame({'group1': list('aaabbbbccc'), - 'group2': list('eefeffgeee'), - 'num': list(range(10)), - 'nan': [np.nan] * 10, - 'date': np.arange('2017-01-01', '2017-01-11', - dtype='datetime64[D]')}) - partition_by = ['group1', 'group2'] - output_table = pa.Table.from_pandas(output_df) - path = str(tempdir) - - def partition_filename_callback(keys): - return "{}-{}.parquet".format(*keys) - - pq.write_to_dataset(output_table, path, - partition_by, partition_filename_callback, - use_legacy_dataset=use_legacy_dataset) - - dataset = pq.ParquetDataset(path) - - # ARROW-3538: Ensure partition filenames match the given pattern - # defined in the local function partition_filename_callback - expected_basenames = [ - 'a-e.parquet', 'a-f.parquet', - 'b-e.parquet', 'b-f.parquet', - 'b-g.parquet', 'c-e.parquet' - ] - output_basenames = [os.path.basename(p.path) for p in dataset.pieces] - - assert sorted(expected_basenames) == sorted(output_basenames) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_pandas_preserve_extensiondtypes( - tempdir, use_legacy_dataset -): - # ARROW-8251 - preserve pandas extension dtypes in roundtrip - if LooseVersion(pd.__version__) < "1.0.0": - pytest.skip("__arrow_array__ added to pandas in 1.0.0") - - df = pd.DataFrame({'part': 'a', "col": [1, 2, 3]}) - df['col'] = df['col'].astype("Int64") - table = pa.table(df) - - pq.write_to_dataset( - table, str(tempdir / "case1"), partition_cols=['part'], - use_legacy_dataset=use_legacy_dataset - ) - result = pq.read_table( - str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset - ).to_pandas() - tm.assert_frame_equal(result[["col"]], df[["col"]]) - - pq.write_to_dataset( - table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset - ) - result = pq.read_table( - str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset - ).to_pandas() - tm.assert_frame_equal(result[["col"]], df[["col"]]) - - pq.write_table(table, str(tempdir / "data.parquet")) - result = pq.read_table( - str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset - ).to_pandas() - tm.assert_frame_equal(result[["col"]], df[["col"]]) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_pandas_preserve_index(tempdir, use_legacy_dataset): - # ARROW-8251 - preserve pandas index in roundtrip - - df = pd.DataFrame({'part': ['a', 'a', 'b'], "col": [1, 2, 3]}) - df.index = pd.Index(['a', 'b', 'c'], name="idx") - table = pa.table(df) - df_cat = 
df[["col", "part"]].copy() - df_cat["part"] = df_cat["part"].astype("category") - - pq.write_to_dataset( - table, str(tempdir / "case1"), partition_cols=['part'], - use_legacy_dataset=use_legacy_dataset - ) - result = pq.read_table( - str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset - ).to_pandas() - tm.assert_frame_equal(result, df_cat) - - pq.write_to_dataset( - table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset - ) - result = pq.read_table( - str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset - ).to_pandas() - tm.assert_frame_equal(result, df) - - pq.write_table(table, str(tempdir / "data.parquet")) - result = pq.read_table( - str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset - ).to_pandas() - tm.assert_frame_equal(result, df) - - -@pytest.mark.large_memory -def test_large_table_int32_overflow(): - size = np.iinfo('int32').max + 1 - - arr = np.ones(size, dtype='uint8') - - parr = pa.array(arr, type=pa.uint8()) - - table = pa.Table.from_arrays([parr], names=['one']) - f = io.BytesIO() - _write_table(table, f) - - -def _simple_table_roundtrip(table, use_legacy_dataset=False, **write_kwargs): - stream = pa.BufferOutputStream() - _write_table(table, stream, **write_kwargs) - buf = stream.getvalue() - return _read_table(buf, use_legacy_dataset=use_legacy_dataset) - - -@pytest.mark.large_memory -@parametrize_legacy_dataset -def test_byte_array_exactly_2gb(use_legacy_dataset): - # Test edge case reported in ARROW-3762 - val = b'x' * (1 << 10) - - base = pa.array([val] * ((1 << 21) - 1)) - cases = [ - [b'x' * 1023], # 2^31 - 1 - [b'x' * 1024], # 2^31 - [b'x' * 1025] # 2^31 + 1 - ] - for case in cases: - values = pa.chunked_array([base, pa.array(case)]) - t = pa.table([values], names=['f0']) - result = _simple_table_roundtrip( - t, use_legacy_dataset=use_legacy_dataset, use_dictionary=False) - assert t.equals(result) - - -@pytest.mark.pandas -@pytest.mark.large_memory -@parametrize_legacy_dataset -def test_binary_array_overflow_to_chunked(use_legacy_dataset): - # ARROW-3762 - - # 2^31 + 1 bytes - values = [b'x'] + [ - b'x' * (1 << 20) - ] * 2 * (1 << 10) - df = pd.DataFrame({'byte_col': values}) - - tbl = pa.Table.from_pandas(df, preserve_index=False) - read_tbl = _simple_table_roundtrip( - tbl, use_legacy_dataset=use_legacy_dataset) - - col0_data = read_tbl[0] - assert isinstance(col0_data, pa.ChunkedArray) - - # Split up into 2GB chunks - assert col0_data.num_chunks == 2 - - assert tbl.equals(read_tbl) - - -@pytest.mark.pandas -@pytest.mark.large_memory -@parametrize_legacy_dataset -def test_list_of_binary_large_cell(use_legacy_dataset): - # ARROW-4688 - data = [] - - # TODO(wesm): handle chunked children - # 2^31 - 1 bytes in a single cell - # data.append([b'x' * (1 << 20)] * 2047 + [b'x' * ((1 << 20) - 1)]) - - # A little under 2GB in cell each containing approximately 10MB each - data.extend([[b'x' * 1000000] * 10] * 214) - - arr = pa.array(data) - table = pa.Table.from_arrays([arr], ['chunky_cells']) - read_table = _simple_table_roundtrip( - table, use_legacy_dataset=use_legacy_dataset) - assert table.equals(read_table) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_index_column_name_duplicate(tempdir, use_legacy_dataset): - data = { - 'close': { - pd.Timestamp('2017-06-30 01:31:00'): 154.99958999999998, - pd.Timestamp('2017-06-30 01:32:00'): 154.99958999999998, - }, - 'time': { - pd.Timestamp('2017-06-30 01:31:00'): pd.Timestamp( - '2017-06-30 01:31:00' - ), - pd.Timestamp('2017-06-30 01:32:00'): pd.Timestamp( - 
'2017-06-30 01:32:00' - ), - } - } - path = str(tempdir / 'data.parquet') - dfx = pd.DataFrame(data).set_index('time', drop=False) - tdfx = pa.Table.from_pandas(dfx) - _write_table(tdfx, path) - arrow_table = _read_table(path, use_legacy_dataset=use_legacy_dataset) - result_df = arrow_table.to_pandas() - tm.assert_frame_equal(result_df, dfx) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_parquet_nested_convenience(tempdir, use_legacy_dataset): - # ARROW-1684 - df = pd.DataFrame({ - 'a': [[1, 2, 3], None, [4, 5], []], - 'b': [[1.], None, None, [6., 7.]], - }) - - path = str(tempdir / 'nested_convenience.parquet') - - table = pa.Table.from_pandas(df, preserve_index=False) - _write_table(table, path) - - read = pq.read_table( - path, columns=['a'], use_legacy_dataset=use_legacy_dataset) - tm.assert_frame_equal(read.to_pandas(), df[['a']]) - - read = pq.read_table( - path, columns=['a', 'b'], use_legacy_dataset=use_legacy_dataset) - tm.assert_frame_equal(read.to_pandas(), df) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_backwards_compatible_index_naming(datadir, use_legacy_dataset): - expected_string = b"""\ -carat cut color clarity depth table price x y z - 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 - 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 - 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 - 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 - 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 - 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48 - 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 - 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 - 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 - 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" - expected = pd.read_csv(io.BytesIO(expected_string), sep=r'\s{2,}', - index_col=None, header=0, engine='python') - table = _read_table( - datadir / 'v0.7.1.parquet', use_legacy_dataset=use_legacy_dataset) - result = table.to_pandas() - tm.assert_frame_equal(result, expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_backwards_compatible_index_multi_level_named( - datadir, use_legacy_dataset -): - expected_string = b"""\ -carat cut color clarity depth table price x y z - 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 - 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 - 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 - 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 - 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 - 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48 - 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 - 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 - 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 - 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" - expected = pd.read_csv( - io.BytesIO(expected_string), sep=r'\s{2,}', - index_col=['cut', 'color', 'clarity'], - header=0, engine='python' - ).sort_index() - - table = _read_table(datadir / 'v0.7.1.all-named-index.parquet', - use_legacy_dataset=use_legacy_dataset) - result = table.to_pandas() - tm.assert_frame_equal(result, expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_backwards_compatible_index_multi_level_some_named( - datadir, use_legacy_dataset -): - expected_string = b"""\ -carat cut color clarity depth table price x y z - 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 - 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 - 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 - 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 - 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 - 0.24 Very Good J VVS2 62.8 57.0 336 3.94 
3.96 2.48 - 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 - 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 - 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 - 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" - expected = pd.read_csv( - io.BytesIO(expected_string), - sep=r'\s{2,}', index_col=['cut', 'color', 'clarity'], - header=0, engine='python' - ).sort_index() - expected.index = expected.index.set_names(['cut', None, 'clarity']) - - table = _read_table(datadir / 'v0.7.1.some-named-index.parquet', - use_legacy_dataset=use_legacy_dataset) - result = table.to_pandas() - tm.assert_frame_equal(result, expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_backwards_compatible_column_metadata_handling( - datadir, use_legacy_dataset -): - expected = pd.DataFrame( - {'a': [1, 2, 3], 'b': [.1, .2, .3], - 'c': pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')}) - expected.index = pd.MultiIndex.from_arrays( - [['a', 'b', 'c'], - pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')], - names=['index', None]) - - path = datadir / 'v0.7.1.column-metadata-handling.parquet' - table = _read_table(path, use_legacy_dataset=use_legacy_dataset) - result = table.to_pandas() - tm.assert_frame_equal(result, expected) - - table = _read_table( - path, columns=['a'], use_legacy_dataset=use_legacy_dataset) - result = table.to_pandas() - tm.assert_frame_equal(result, expected[['a']].reset_index(drop=True)) - - -# TODO(dataset) support pickling -def _make_dataset_for_pickling(tempdir, N=100): - path = tempdir / 'data.parquet' - fs = LocalFileSystem._get_instance() - - df = pd.DataFrame({ - 'index': np.arange(N), - 'values': np.random.randn(N) - }, columns=['index', 'values']) - table = pa.Table.from_pandas(df) - - num_groups = 3 - with pq.ParquetWriter(path, table.schema) as writer: - for i in range(num_groups): - writer.write_table(table) - - reader = pq.ParquetFile(path) - assert reader.metadata.num_row_groups == num_groups - - metadata_path = tempdir / '_metadata' - with fs.open(metadata_path, 'wb') as f: - pq.write_metadata(table.schema, f) - - dataset = pq.ParquetDataset(tempdir, filesystem=fs) - assert dataset.metadata_path == str(metadata_path) - - return dataset - - -def _assert_dataset_is_picklable(dataset, pickler): - def is_pickleable(obj): - return obj == pickler.loads(pickler.dumps(obj)) - - assert is_pickleable(dataset) - assert is_pickleable(dataset.metadata) - assert is_pickleable(dataset.metadata.schema) - assert len(dataset.metadata.schema) - for column in dataset.metadata.schema: - assert is_pickleable(column) - - for piece in dataset.pieces: - assert is_pickleable(piece) - metadata = piece.get_metadata() - assert metadata.num_row_groups - for i in range(metadata.num_row_groups): - assert is_pickleable(metadata.row_group(i)) - - -@pytest.mark.pandas -def test_builtin_pickle_dataset(tempdir, datadir): - import pickle - dataset = _make_dataset_for_pickling(tempdir) - _assert_dataset_is_picklable(dataset, pickler=pickle) - - -@pytest.mark.pandas -def test_cloudpickle_dataset(tempdir, datadir): - cp = pytest.importorskip('cloudpickle') - dataset = _make_dataset_for_pickling(tempdir) - _assert_dataset_is_picklable(dataset, pickler=cp) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_decimal_roundtrip(tempdir, use_legacy_dataset): - num_values = 10 - - columns = {} - for precision in range(1, 39): - for scale in range(0, precision + 1): - with util.random_seed(0): - random_decimal_values = [ - util.randdecimal(precision, scale) - for _ in 
range(num_values) - ] - column_name = ('dec_precision_{:d}_scale_{:d}' - .format(precision, scale)) - columns[column_name] = random_decimal_values - - expected = pd.DataFrame(columns) - filename = tempdir / 'decimals.parquet' - string_filename = str(filename) - table = pa.Table.from_pandas(expected) - _write_table(table, string_filename) - result_table = _read_table( - string_filename, use_legacy_dataset=use_legacy_dataset) - result = result_table.to_pandas() - tm.assert_frame_equal(result, expected) - - -@pytest.mark.pandas -@pytest.mark.xfail( - raises=pa.ArrowException, reason='Parquet does not support negative scale' -) -def test_decimal_roundtrip_negative_scale(tempdir): - expected = pd.DataFrame({'decimal_num': [decimal.Decimal('1.23E4')]}) - filename = tempdir / 'decimals.parquet' - string_filename = str(filename) - t = pa.Table.from_pandas(expected) - _write_table(t, string_filename) - result_table = _read_table(string_filename) - result = result_table.to_pandas() - tm.assert_frame_equal(result, expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_parquet_writer_context_obj(tempdir, use_legacy_dataset): - df = _test_dataframe(100) - df['unique_id'] = 0 - - arrow_table = pa.Table.from_pandas(df, preserve_index=False) - out = pa.BufferOutputStream() - - with pq.ParquetWriter(out, arrow_table.schema, version='2.0') as writer: - - frames = [] - for i in range(10): - df['unique_id'] = i - arrow_table = pa.Table.from_pandas(df, preserve_index=False) - writer.write_table(arrow_table) - - frames.append(df.copy()) - - buf = out.getvalue() - result = _read_table( - pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) - - expected = pd.concat(frames, ignore_index=True) - tm.assert_frame_equal(result.to_pandas(), expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_parquet_writer_context_obj_with_exception( - tempdir, use_legacy_dataset -): - df = _test_dataframe(100) - df['unique_id'] = 0 - - arrow_table = pa.Table.from_pandas(df, preserve_index=False) - out = pa.BufferOutputStream() - error_text = 'Artificial Error' - - try: - with pq.ParquetWriter(out, - arrow_table.schema, - version='2.0') as writer: - - frames = [] - for i in range(10): - df['unique_id'] = i - arrow_table = pa.Table.from_pandas(df, preserve_index=False) - writer.write_table(arrow_table) - frames.append(df.copy()) - if i == 5: - raise ValueError(error_text) - except Exception as e: - assert str(e) == error_text - - buf = out.getvalue() - result = _read_table( - pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) - - expected = pd.concat(frames, ignore_index=True) - tm.assert_frame_equal(result.to_pandas(), expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_zlib_compression_bug(use_legacy_dataset): - # ARROW-3514: "zlib deflate failed, output buffer too small" - table = pa.Table.from_arrays([pa.array(['abc', 'def'])], ['some_col']) - f = io.BytesIO() - pq.write_table(table, f, compression='gzip') - - f.seek(0) - roundtrip = pq.read_table(f, use_legacy_dataset=use_legacy_dataset) - tm.assert_frame_equal(roundtrip.to_pandas(), table.to_pandas()) - - -@pytest.mark.pandas -def test_merging_parquet_tables_with_different_pandas_metadata(tempdir): - # ARROW-3728: Merging Parquet Files - Pandas Meta in Schema Mismatch - schema = pa.schema([ - pa.field('int', pa.int16()), - pa.field('float', pa.float32()), - pa.field('string', pa.string()) - ]) - df1 = pd.DataFrame({ - 'int': np.arange(3, dtype=np.uint8), - 'float': np.arange(3, dtype=np.float32), - 
'string': ['ABBA', 'EDDA', 'ACDC'] - }) - df2 = pd.DataFrame({ - 'int': [4, 5], - 'float': [1.1, None], - 'string': [None, None] - }) - table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False) - table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False) - - assert not table1.schema.equals(table2.schema, check_metadata=True) - assert table1.schema.equals(table2.schema) - - writer = pq.ParquetWriter(tempdir / 'merged.parquet', schema=schema) - writer.write_table(table1) - writer.write_table(table2) - - -def test_empty_row_groups(tempdir): - # ARROW-3020 - table = pa.Table.from_arrays([pa.array([], type='int32')], ['f0']) - - path = tempdir / 'empty_row_groups.parquet' - - num_groups = 3 - with pq.ParquetWriter(path, table.schema) as writer: - for i in range(num_groups): - writer.write_table(table) - - reader = pq.ParquetFile(path) - assert reader.metadata.num_row_groups == num_groups - - for i in range(num_groups): - assert reader.read_row_group(i).equals(table) - - -def test_parquet_file_pass_directory_instead_of_file(tempdir): - # ARROW-7208 - path = tempdir / 'directory' - os.mkdir(str(path)) - - with pytest.raises(IOError, match="Expected file path"): - pq.ParquetFile(path) - - -@pytest.mark.pandas -@pytest.mark.parametrize("filesystem", [ - None, - LocalFileSystem._get_instance(), - fs.LocalFileSystem(), -]) -def test_parquet_writer_filesystem_local(tempdir, filesystem): - df = _test_dataframe(100) - table = pa.Table.from_pandas(df, preserve_index=False) - path = str(tempdir / 'data.parquet') - - with pq.ParquetWriter( - path, table.schema, filesystem=filesystem, version='2.0' - ) as writer: - writer.write_table(table) - - result = _read_table(path).to_pandas() - tm.assert_frame_equal(result, df) - - -@pytest.fixture -def s3_example_fs(s3_connection, s3_server): - from pyarrow.fs import FileSystem - - host, port, access_key, secret_key = s3_connection - uri = ( - "s3://{}:{}@mybucket/data.parquet?scheme=http&endpoint_override={}:{}" - .format(access_key, secret_key, host, port) - ) - fs, path = FileSystem.from_uri(uri) - - fs.create_dir("mybucket") - - yield fs, uri, path - - -@pytest.mark.pandas -@pytest.mark.s3 -def test_parquet_writer_filesystem_s3(s3_example_fs): - df = _test_dataframe(100) - table = pa.Table.from_pandas(df, preserve_index=False) - - fs, uri, path = s3_example_fs - - with pq.ParquetWriter( - path, table.schema, filesystem=fs, version='2.0' - ) as writer: - writer.write_table(table) - - result = _read_table(uri).to_pandas() - tm.assert_frame_equal(result, df) - - -@pytest.mark.pandas -@pytest.mark.s3 -def test_parquet_writer_filesystem_s3_uri(s3_example_fs): - df = _test_dataframe(100) - table = pa.Table.from_pandas(df, preserve_index=False) - - fs, uri, path = s3_example_fs - - with pq.ParquetWriter(uri, table.schema, version='2.0') as writer: - writer.write_table(table) - - result = _read_table(path, filesystem=fs).to_pandas() - tm.assert_frame_equal(result, df) - - -@pytest.mark.pandas -def test_parquet_writer_filesystem_s3fs(s3_example_s3fs): - df = _test_dataframe(100) - table = pa.Table.from_pandas(df, preserve_index=False) - - fs, directory = s3_example_s3fs - path = directory + "/test.parquet" - - with pq.ParquetWriter( - path, table.schema, filesystem=fs, version='2.0' - ) as writer: - writer.write_table(table) - - result = _read_table(path, filesystem=fs).to_pandas() - tm.assert_frame_equal(result, df) - - -@pytest.mark.pandas -def test_parquet_writer_filesystem_buffer_raises(): - df = _test_dataframe(100) - table = 
pa.Table.from_pandas(df, preserve_index=False)
-    filesystem = fs.LocalFileSystem()
-
-    # Should raise ValueError when filesystem is passed with file-like object
-    with pytest.raises(ValueError, match="specified path is file-like"):
-        pq.ParquetWriter(
-            pa.BufferOutputStream(), table.schema, filesystem=filesystem
-        )
-
-
-@pytest.mark.pandas
-@parametrize_legacy_dataset
-def test_parquet_writer_with_caller_provided_filesystem(use_legacy_dataset):
-    out = pa.BufferOutputStream()
-
-    class CustomFS(FileSystem):
-        def __init__(self):
-            self.path = None
-            self.mode = None
-
-        def open(self, path, mode='rb'):
-            self.path = path
-            self.mode = mode
-            return out
-
-    fs = CustomFS()
-    fname = 'expected_fname.parquet'
-    df = _test_dataframe(100)
-    table = pa.Table.from_pandas(df, preserve_index=False)
-
-    with pq.ParquetWriter(fname, table.schema, filesystem=fs, version='2.0') \
-            as writer:
-        writer.write_table(table)
-
-    assert fs.path == fname
-    assert fs.mode == 'wb'
-    assert out.closed
-
-    buf = out.getvalue()
-    table_read = _read_table(
-        pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
-    df_read = table_read.to_pandas()
-    tm.assert_frame_equal(df_read, df)
-
-    # Should raise ValueError when filesystem is passed with file-like object
-    with pytest.raises(ValueError) as err_info:
-        pq.ParquetWriter(pa.BufferOutputStream(), table.schema, filesystem=fs)
-    expected_msg = ("filesystem passed but where is file-like, so"
-                    " there is nothing to open with filesystem.")
-    assert str(err_info.value) == expected_msg
-
-
-def test_writing_empty_lists():
-    # ARROW-2591: [Python] Segmentation fault issue in pq.write_table
-    arr1 = pa.array([[], []], pa.list_(pa.int32()))
-    table = pa.Table.from_arrays([arr1], ['list(int32)'])
-    _check_roundtrip(table)
-
-
-@parametrize_legacy_dataset
-def test_write_nested_zero_length_array_chunk_failure(use_legacy_dataset):
-    # Bug report in ARROW-3792
-    cols = OrderedDict(
-        int32=pa.int32(),
-        list_string=pa.list_(pa.string())
-    )
-    data = [[], [OrderedDict(int32=1, list_string=('G',)), ]]
-
-    # This produces a table with a column like
-    # <Column name='list_string' type=ListType(list<item: string>)>
-    # [
-    #   [],
-    #   [
-    #     [
-    #       "G"
-    #     ]
-    #   ]
-    # ]
-    #
-    # Each column is a ChunkedArray with 2 elements
-    my_arrays = [pa.array(batch, type=pa.struct(cols)).flatten()
-                 for batch in data]
-    my_batches = [pa.RecordBatch.from_arrays(batch, schema=pa.schema(cols))
-                  for batch in my_arrays]
-    tbl = pa.Table.from_batches(my_batches, pa.schema(cols))
-    _check_roundtrip(tbl, use_legacy_dataset=use_legacy_dataset)
-
-
-@pytest.mark.pandas
-@parametrize_legacy_dataset
-def test_partitioned_dataset(tempdir, use_legacy_dataset):
-    # ARROW-3208: Segmentation fault when reading a Parquet partitioned
-    # dataset and then writing it back out to a Parquet file
-    path = tempdir / "ARROW-3208"
-    df = pd.DataFrame({
-        'one': [-1, 10, 2.5, 100, 1000, 1, 29.2],
-        'two': [-1, 10, 2, 100, 1000, 1, 11],
-        'three': [0, 0, 0, 0, 0, 0, 0]
-    })
-    table = pa.Table.from_pandas(df)
-    pq.write_to_dataset(table, root_path=str(path),
-                        partition_cols=['one', 'two'])
-    table = pq.ParquetDataset(
-        path, use_legacy_dataset=use_legacy_dataset).read()
-    pq.write_table(table, path / "output.parquet")
-
-
-def test_read_column_invalid_index():
-    table = pa.table([pa.array([4, 5]), pa.array(["foo", "bar"])],
-                     names=['ints', 'strs'])
-    bio = pa.BufferOutputStream()
-    pq.write_table(table, bio)
-    f = pq.ParquetFile(bio.getvalue())
-    assert f.reader.read_column(0).to_pylist() == [4, 5]
-    assert f.reader.read_column(1).to_pylist() == ["foo", "bar"]
-    for index in (-1, 2):
-        with
pytest.raises((ValueError, IndexError)): - f.reader.read_column(index) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_direct_read_dictionary(use_legacy_dataset): - # ARROW-3325 - repeats = 10 - nunique = 5 - - data = [ - [util.rands(10) for i in range(nunique)] * repeats, - - ] - table = pa.table(data, names=['f0']) - - bio = pa.BufferOutputStream() - pq.write_table(table, bio) - contents = bio.getvalue() - - result = pq.read_table(pa.BufferReader(contents), - read_dictionary=['f0'], - use_legacy_dataset=use_legacy_dataset) - - # Compute dictionary-encoded subfield - expected = pa.table([table[0].dictionary_encode()], names=['f0']) - assert result.equals(expected) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_dataset_read_dictionary(tempdir, use_legacy_dataset): - path = tempdir / "ARROW-3325-dataset" - t1 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0']) - t2 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0']) - # TODO pass use_legacy_dataset (need to fix unique names) - pq.write_to_dataset(t1, root_path=str(path)) - pq.write_to_dataset(t2, root_path=str(path)) - - result = pq.ParquetDataset( - path, read_dictionary=['f0'], - use_legacy_dataset=use_legacy_dataset).read() - - # The order of the chunks is non-deterministic - ex_chunks = [t1[0].chunk(0).dictionary_encode(), - t2[0].chunk(0).dictionary_encode()] - - assert result[0].num_chunks == 2 - c0, c1 = result[0].chunk(0), result[0].chunk(1) - if c0.equals(ex_chunks[0]): - assert c1.equals(ex_chunks[1]) - else: - assert c0.equals(ex_chunks[1]) - assert c1.equals(ex_chunks[0]) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_direct_read_dictionary_subfield(use_legacy_dataset): - repeats = 10 - nunique = 5 - - data = [ - [[util.rands(10)] for i in range(nunique)] * repeats, - ] - table = pa.table(data, names=['f0']) - - bio = pa.BufferOutputStream() - pq.write_table(table, bio) - contents = bio.getvalue() - result = pq.read_table(pa.BufferReader(contents), - read_dictionary=['f0.list.item'], - use_legacy_dataset=use_legacy_dataset) - - arr = pa.array(data[0]) - values_as_dict = arr.values.dictionary_encode() - - inner_indices = values_as_dict.indices.cast('int32') - new_values = pa.DictionaryArray.from_arrays(inner_indices, - values_as_dict.dictionary) - - offsets = pa.array(range(51), type='int32') - expected_arr = pa.ListArray.from_arrays(offsets, new_values) - expected = pa.table([expected_arr], names=['f0']) - - assert result.equals(expected) - assert result[0].num_chunks == 1 - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_write_to_dataset_metadata(tempdir, use_legacy_dataset): - path = tempdir / "ARROW-1983-dataset" - - # create and write a test dataset - df = pd.DataFrame({ - 'one': [1, 2, 3], - 'two': [-1, -2, -3], - 'three': [[1, 2], [2, 3], [3, 4]], - }) - table = pa.Table.from_pandas(df) - - metadata_list = [] - if not use_legacy_dataset: - # New dataset implementation does not yet support metadata_collector - with pytest.raises(ValueError): - pq.write_to_dataset(table, root_path=str(path), - partition_cols=['one', 'two'], - metadata_collector=metadata_list, - use_legacy_dataset=use_legacy_dataset) - return - pq.write_to_dataset(table, root_path=str(path), - partition_cols=['one', 'two'], - metadata_collector=metadata_list, - use_legacy_dataset=use_legacy_dataset) - - # open the dataset and collect metadata from pieces: - dataset = pq.ParquetDataset(path) - metadata_list2 = [p.get_metadata() for p in dataset.pieces] - - 
collected_paths = [] - # compare metadata list content: - assert len(metadata_list) == len(metadata_list2) - for md, md2 in zip(metadata_list, metadata_list2): - d = md.to_dict() - d2 = md2.to_dict() - # serialized_size is initialized in the reader: - assert d.pop('serialized_size') == 0 - assert d2.pop('serialized_size') > 0 - # file_path is different (not set for in-file metadata) - assert d["row_groups"][0]["columns"][0]["file_path"] != "" - assert d2["row_groups"][0]["columns"][0]["file_path"] == "" - # collect file paths to check afterwards, ignore here - collected_paths.append(d["row_groups"][0]["columns"][0]["file_path"]) - d["row_groups"][0]["columns"][0]["file_path"] = "" - assert d == d2 - - # ARROW-8244 - check the file paths in the collected metadata - n_root = len(path.parts) - file_paths = ["/".join(p.parts[n_root:]) for p in path.rglob("*.parquet")] - assert sorted(collected_paths) == sorted(file_paths) - - # writing to single file (not partitioned) - metadata_list = [] - pq.write_to_dataset(pa.table({'a': [1, 2, 3]}), root_path=str(path), - metadata_collector=metadata_list) - - # compare metadata content - file_paths = list(path.glob("*.parquet")) - assert len(file_paths) == 1 - file_path = file_paths[0] - file_metadata = pq.read_metadata(file_path) - d1 = metadata_list[0].to_dict() - d2 = file_metadata.to_dict() - # serialized_size is initialized in the reader: - assert d1.pop('serialized_size') == 0 - assert d2.pop('serialized_size') > 0 - # file_path is different (not set for in-file metadata) - assert d1["row_groups"][0]["columns"][0]["file_path"] == file_path.name - assert d2["row_groups"][0]["columns"][0]["file_path"] == "" - d1["row_groups"][0]["columns"][0]["file_path"] = "" - assert d1 == d2 - - -@parametrize_legacy_dataset -def test_parquet_file_too_small(tempdir, use_legacy_dataset): - path = str(tempdir / "test.parquet") - # TODO(dataset) with datasets API it raises OSError instead - with pytest.raises((pa.ArrowInvalid, OSError), - match='size is 0 bytes'): - with open(path, 'wb') as f: - pass - pq.read_table(path, use_legacy_dataset=use_legacy_dataset) - - with pytest.raises((pa.ArrowInvalid, OSError), - match='size is 4 bytes'): - with open(path, 'wb') as f: - f.write(b'ffff') - pq.read_table(path, use_legacy_dataset=use_legacy_dataset) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_categorical_index_survives_roundtrip(use_legacy_dataset): - # ARROW-3652, addressed by ARROW-3246 - df = pd.DataFrame([['a', 'b'], ['c', 'd']], columns=['c1', 'c2']) - df['c1'] = df['c1'].astype('category') - df = df.set_index(['c1']) - - table = pa.Table.from_pandas(df) - bos = pa.BufferOutputStream() - pq.write_table(table, bos) - ref_df = pq.read_pandas( - bos.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas() - assert isinstance(ref_df.index, pd.CategoricalIndex) - assert ref_df.index.equals(df.index) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_categorical_order_survives_roundtrip(use_legacy_dataset): - # ARROW-6302 - df = pd.DataFrame({"a": pd.Categorical( - ["a", "b", "c", "a"], categories=["b", "c", "d"], ordered=True)}) - - table = pa.Table.from_pandas(df) - bos = pa.BufferOutputStream() - pq.write_table(table, bos) - - contents = bos.getvalue() - result = pq.read_pandas( - contents, use_legacy_dataset=use_legacy_dataset).to_pandas() - - tm.assert_frame_equal(result, df) - - -def _simple_table_write_read(table, use_legacy_dataset): - bio = pa.BufferOutputStream() - pq.write_table(table, bio) - contents = bio.getvalue() - 
return pq.read_table( - pa.BufferReader(contents), use_legacy_dataset=use_legacy_dataset - ) - - -@parametrize_legacy_dataset -def test_dictionary_array_automatically_read(use_legacy_dataset): - # ARROW-3246 - - # Make a large dictionary, a little over 4MB of data - dict_length = 4000 - dict_values = pa.array([('x' * 1000 + '_{}'.format(i)) - for i in range(dict_length)]) - - num_chunks = 10 - chunk_size = 100 - chunks = [] - for i in range(num_chunks): - indices = np.random.randint(0, dict_length, - size=chunk_size).astype(np.int32) - chunks.append(pa.DictionaryArray.from_arrays(pa.array(indices), - dict_values)) - - table = pa.table([pa.chunked_array(chunks)], names=['f0']) - result = _simple_table_write_read(table, use_legacy_dataset) - - assert result.equals(table) - - # The only key in the metadata was the Arrow schema key - assert result.schema.metadata is None - - -def test_field_id_metadata(): - # ARROW-7080 - table = pa.table([pa.array([1], type='int32'), - pa.array([[]], type=pa.list_(pa.int32())), - pa.array([b'boo'], type='binary')], - ['f0', 'f1', 'f2']) - - bio = pa.BufferOutputStream() - pq.write_table(table, bio) - contents = bio.getvalue() - - pf = pq.ParquetFile(pa.BufferReader(contents)) - schema = pf.schema_arrow - - # Expected Parquet schema for reference - # - # required group field_id=0 schema { - # optional int32 field_id=1 f0; - # optional group field_id=2 f1 (List) { - # repeated group field_id=3 list { - # optional int32 field_id=4 item; - # } - # } - # optional binary field_id=5 f2; - # } - - field_name = b'PARQUET:field_id' - assert schema[0].metadata[field_name] == b'1' - - list_field = schema[1] - assert list_field.metadata[field_name] == b'2' - - list_item_field = list_field.type.value_field - assert list_item_field.metadata[field_name] == b'4' - - assert schema[2].metadata[field_name] == b'5' - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_categorical_na_type_row_groups(use_legacy_dataset): - # ARROW-5085 - df = pd.DataFrame({"col": [None] * 100, "int": [1.0] * 100}) - df_category = df.astype({"col": "category", "int": "category"}) - table = pa.Table.from_pandas(df) - table_cat = pa.Table.from_pandas(df_category) - buf = pa.BufferOutputStream() - - # it works - pq.write_table(table_cat, buf, version="2.0", chunk_size=10) - result = pq.read_table( - buf.getvalue(), use_legacy_dataset=use_legacy_dataset) - - # Result is non-categorical - assert result[0].equals(table[0]) - assert result[1].equals(table[1]) - - -@pytest.mark.pandas -@parametrize_legacy_dataset -def test_pandas_categorical_roundtrip(use_legacy_dataset): - # ARROW-5480, this was enabled by ARROW-3246 - - # Have one of the categories unobserved and include a null (-1) - codes = np.array([2, 0, 0, 2, 0, -1, 2], dtype='int32') - categories = ['foo', 'bar', 'baz'] - df = pd.DataFrame({'x': pd.Categorical.from_codes( - codes, categories=categories)}) - - buf = pa.BufferOutputStream() - pq.write_table(pa.table(df), buf) - - result = pq.read_table( - buf.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas() - assert result.x.dtype == 'category' - assert (result.x.cat.categories == categories).all() - tm.assert_frame_equal(result, df) - - -@pytest.mark.pandas -def test_multi_dataset_metadata(tempdir): - filenames = ["ARROW-1983-dataset.0", "ARROW-1983-dataset.1"] - metapath = str(tempdir / "_metadata") - - # create a test dataset - df = pd.DataFrame({ - 'one': [1, 2, 3], - 'two': [-1, -2, -3], - 'three': [[1, 2], [2, 3], [3, 4]], - }) - table = pa.Table.from_pandas(df) - 
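-    # The loop below is the usual recipe for building a dataset-level
-    # "_metadata" sidecar file: collect each file's FileMetaData through
-    # `metadata_collector`, make the embedded paths relative to the dataset
-    # root with `set_file_path`, concatenate row groups with
-    # `append_row_groups`, and persist the result with `write_metadata_file`.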
-    # write dataset twice and collect/merge metadata
-    _meta = None
-    for filename in filenames:
-        meta = []
-        pq.write_table(table, str(tempdir / filename),
-                       metadata_collector=meta)
-        meta[0].set_file_path(filename)
-        if _meta is None:
-            _meta = meta[0]
-        else:
-            _meta.append_row_groups(meta[0])
-
-    # Write merged metadata-only file
-    with open(metapath, "wb") as f:
-        _meta.write_metadata_file(f)
-
-    # Read back the metadata
-    meta = pq.read_metadata(metapath)
-    md = meta.to_dict()
-    _md = _meta.to_dict()
-    for key in _md:
-        if key != 'serialized_size':
-            assert _md[key] == md[key]
-    assert _md['num_columns'] == 3
-    assert _md['num_rows'] == 6
-    assert _md['num_row_groups'] == 2
-    assert _md['serialized_size'] == 0
-    assert md['serialized_size'] > 0
-
-
-def test_write_metadata(tempdir):
-    path = str(tempdir / "metadata")
-    schema = pa.schema([("a", "int64"), ("b", "float64")])
-
-    # write a pyarrow schema
-    pq.write_metadata(schema, path)
-    parquet_meta = pq.read_metadata(path)
-    schema_as_arrow = parquet_meta.schema.to_arrow_schema()
-    assert schema_as_arrow.equals(schema)
-
-    # ARROW-8980: Check that the ARROW:schema metadata key was removed
-    if schema_as_arrow.metadata:
-        assert b'ARROW:schema' not in schema_as_arrow.metadata
-
-    # pass through writer keyword arguments
-    for version in ["1.0", "2.0"]:
-        pq.write_metadata(schema, path, version=version)
-        parquet_meta = pq.read_metadata(path)
-        assert parquet_meta.format_version == version
-
-    # metadata_collector: list of FileMetaData objects
-    table = pa.table({'a': [1, 2], 'b': [.1, .2]}, schema=schema)
-    pq.write_table(table, tempdir / "data.parquet")
-    parquet_meta = pq.read_metadata(str(tempdir / "data.parquet"))
-    pq.write_metadata(
-        schema, path, metadata_collector=[parquet_meta, parquet_meta]
-    )
-    parquet_meta_mult = pq.read_metadata(path)
-    assert parquet_meta_mult.num_row_groups == 2
-
-    # appending metadata with a different schema raises an error
-    with pytest.raises(RuntimeError, match="requires equal schemas"):
-        pq.write_metadata(
-            pa.schema([("a", "int32"), ("b", "null")]),
-            path, metadata_collector=[parquet_meta, parquet_meta]
-        )
-
-
-@parametrize_legacy_dataset
-@pytest.mark.pandas
-def test_filter_before_validate_schema(tempdir, use_legacy_dataset):
-    # ARROW-4076: apply the filter before schema validation
-    # to avoid checking unneeded schemas
-
-    # create a partitioned dataset with mismatching schemas, which would
-    # otherwise raise if all schemas were validated up front
-    dir1 = tempdir / 'A=0'
-    dir1.mkdir()
-    table1 = pa.Table.from_pandas(pd.DataFrame({'B': [1, 2, 3]}))
-    pq.write_table(table1, dir1 / 'data.parquet')
-
-    dir2 = tempdir / 'A=1'
-    dir2.mkdir()
-    table2 = pa.Table.from_pandas(pd.DataFrame({'B': ['a', 'b', 'c']}))
-    pq.write_table(table2, dir2 / 'data.parquet')
-
-    # read single file using filter
-    table = pq.read_table(tempdir, filters=[[('A', '==', 0)]],
-                          use_legacy_dataset=use_legacy_dataset)
-    assert table.column('B').equals(pa.chunked_array([[1, 2, 3]]))
-
-
-@pytest.mark.pandas
-@pytest.mark.fastparquet
-@pytest.mark.filterwarnings("ignore:RangeIndex:FutureWarning")
-@pytest.mark.filterwarnings("ignore:tostring:DeprecationWarning:fastparquet")
-def test_fastparquet_cross_compatibility(tempdir):
-    fp = pytest.importorskip('fastparquet')
-
-    df = pd.DataFrame(
-        {
-            "a": list("abc"),
-            "b": list(range(1, 4)),
-            "c": np.arange(4.0, 7.0, dtype="float64"),
-            "d": [True, False, True],
-            "e": pd.date_range("20130101", periods=3),
-            "f": pd.Categorical(["a", "b", "a"]),
-            # fastparquet writes list as BYTE_ARRAY JSON, so no roundtrip
-            # "g": [[1, 2], None, [1, 2, 3]],
-        }
-    )
-    table = pa.table(df)
-
-    # Arrow -> fastparquet
-    file_arrow = str(tempdir / "cross_compat_arrow.parquet")
-    pq.write_table(table, file_arrow, compression=None)
-
-    fp_file = fp.ParquetFile(file_arrow)
-    df_fp = fp_file.to_pandas()
-    tm.assert_frame_equal(df, df_fp)
-
-    # fastparquet -> Arrow
-    file_fastparquet = str(tempdir / "cross_compat_fastparquet.parquet")
-    fp.write(file_fastparquet, df)
-
-    table_fp = pq.read_pandas(file_fastparquet)
-    # for a fastparquet-written file, categoricals come back as strings
-    # (no arrow schema in the parquet metadata)
-    df['f'] = df['f'].astype(object)
-    tm.assert_frame_equal(table_fp.to_pandas(), df)
-
-
-def test_table_large_metadata():
-    # ARROW-8694
-    my_schema = pa.schema([pa.field('f0', 'double')],
-                          metadata={'large': 'x' * 10000000})
-
-    table = pa.table([np.arange(10)], schema=my_schema)
-    _check_roundtrip(table)
-
-
-@parametrize_legacy_dataset
-@pytest.mark.parametrize('array_factory', [
-    lambda: pa.array([0, None] * 10),
-    lambda: pa.array([0, None] * 10).dictionary_encode(),
-    lambda: pa.array(["", None] * 10),
-    lambda: pa.array(["", None] * 10).dictionary_encode(),
-])
-@pytest.mark.parametrize('use_dictionary', [False, True])
-@pytest.mark.parametrize('read_dictionary', [False, True])
-def test_buffer_contents(
-        array_factory, use_dictionary, read_dictionary, use_legacy_dataset
-):
-    # Test that null values are deterministically initialized to zero
-    # after a roundtrip through Parquet.
-    # See ARROW-8006 and ARROW-8011.
-    orig_table = pa.Table.from_pydict({"col": array_factory()})
-    bio = io.BytesIO()
-    pq.write_table(orig_table, bio, use_dictionary=use_dictionary)
-    bio.seek(0)
-    read_dictionary = ['col'] if read_dictionary else None
-    table = pq.read_table(bio, use_threads=False,
-                          read_dictionary=read_dictionary,
-                          use_legacy_dataset=use_legacy_dataset)
-
-    for col in table.columns:
-        [chunk] = col.chunks
-        buf = chunk.buffers()[1]
-        assert buf.to_pybytes() == buf.size * b"\0"
-
-
-@pytest.mark.dataset
-def test_dataset_unsupported_keywords():
-
-    with pytest.raises(ValueError, match="not yet supported with the new"):
-        pq.ParquetDataset("", use_legacy_dataset=False, schema=pa.schema([]))
-
-    with pytest.raises(ValueError, match="not yet supported with the new"):
-        pq.ParquetDataset("", use_legacy_dataset=False, metadata=pa.schema([]))
-
-    with pytest.raises(ValueError, match="not yet supported with the new"):
-        pq.ParquetDataset("", use_legacy_dataset=False, validate_schema=False)
-
-    with pytest.raises(ValueError, match="not yet supported with the new"):
-        pq.ParquetDataset("", use_legacy_dataset=False, split_row_groups=True)
-
-    with pytest.raises(ValueError, match="not yet supported with the new"):
-        pq.ParquetDataset("", use_legacy_dataset=False, metadata_nthreads=4)
-
-    with pytest.raises(ValueError, match="no longer supported"):
-        pq.read_table("", use_legacy_dataset=False, metadata=pa.schema([]))
-
-
-@pytest.mark.dataset
-def test_dataset_partitioning(tempdir):
-    import pyarrow.dataset as ds
-
-    # create a small dataset with directory partitioning
-    root_path = tempdir / "test_partitioning"
-    (root_path / "2012" / "10" / "01").mkdir(parents=True)
-
-    table = pa.table({'a': [1, 2, 3]})
-    pq.write_table(
-        table, str(root_path / "2012" / "10" / "01" / "data.parquet"))
-
-    # This works with the new dataset API
-
-    # read_table
-    part = ds.partitioning(field_names=["year", "month", "day"])
-    result = pq.read_table(
-        str(root_path), partitioning=part, use_legacy_dataset=False)
-
assert result.column_names == ["a", "year", "month", "day"] - - result = pq.ParquetDataset( - str(root_path), partitioning=part, use_legacy_dataset=False).read() - assert result.column_names == ["a", "year", "month", "day"] - - # This raises an error for legacy dataset - with pytest.raises(ValueError): - pq.read_table( - str(root_path), partitioning=part, use_legacy_dataset=True) - - with pytest.raises(ValueError): - pq.ParquetDataset( - str(root_path), partitioning=part, use_legacy_dataset=True) - - -@pytest.mark.dataset -def test_parquet_dataset_new_filesystem(tempdir): - # Ensure we can pass new FileSystem object to ParquetDataset - # (use new implementation automatically without specifying - # use_legacy_dataset=False) - table = pa.table({'a': [1, 2, 3]}) - pq.write_table(table, tempdir / 'data.parquet') - # don't use simple LocalFileSystem (as that gets mapped to legacy one) - filesystem = fs.SubTreeFileSystem(str(tempdir), fs.LocalFileSystem()) - dataset = pq.ParquetDataset('.', filesystem=filesystem) - result = dataset.read() - assert result.equals(table) - - -def test_parquet_dataset_partitions_piece_path_with_fsspec(tempdir): - # ARROW-10462 ensure that on Windows we properly use posix-style paths - # as used by fsspec - fsspec = pytest.importorskip("fsspec") - filesystem = fsspec.filesystem('file') - table = pa.table({'a': [1, 2, 3]}) - pq.write_table(table, tempdir / 'data.parquet') - - # pass a posix-style path (using "/" also on Windows) - path = str(tempdir).replace("\\", "/") - dataset = pq.ParquetDataset(path, filesystem=filesystem) - # ensure the piece path is also posix-style - expected = path + "/data.parquet" - assert dataset.pieces[0].path == expected - - -def test_parquet_compression_roundtrip(tempdir): - # ARROW-10480: ensure even with nonstandard Parquet file naming - # conventions, writing and then reading a file works. 
In - # particular, ensure that we don't automatically double-compress - # the stream due to auto-detecting the extension in the filename - table = pa.table([pa.array(range(4))], names=["ints"]) - path = tempdir / "arrow-10480.pyarrow.gz" - pq.write_table(table, path, compression="GZIP") - result = pq.read_table(path) - assert result.equals(table) From fbe89d1332007d2fcd6292f4d54ed9caacbdc3f9 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 4 Jan 2021 16:54:52 -0500 Subject: [PATCH 26/38] rename struct->project, move out of cast's test --- cpp/src/arrow/array/array_struct_test.cc | 11 -- cpp/src/arrow/compute/cast.cc | 36 +++--- cpp/src/arrow/compute/cast.h | 4 +- cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../arrow/compute/kernels/scalar_cast_test.cc | 43 ------ .../compute/kernels/scalar_project_test.cc | 122 ++++++++++++++++++ cpp/src/arrow/dataset/expression.cc | 10 +- cpp/src/arrow/dataset/expression_internal.h | 6 +- cpp/src/arrow/type.h | 8 +- 9 files changed, 153 insertions(+), 88 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/scalar_project_test.cc diff --git a/cpp/src/arrow/array/array_struct_test.cc b/cpp/src/arrow/array/array_struct_test.cc index aef0076d0d3..98fc69355d9 100644 --- a/cpp/src/arrow/array/array_struct_test.cc +++ b/cpp/src/arrow/array/array_struct_test.cc @@ -594,16 +594,6 @@ TEST(TestFieldRef, GetChildren) { auto expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); AssertArraysEqual(*a, *expected_a); - auto ToChunked = [struct_array](int64_t midpoint) { - return ChunkedArray( - ArrayVector{ - struct_array->Slice(0, midpoint), - struct_array->Slice(midpoint), - }, - struct_array->type()); - }; - AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); - // more nested: struct_array = ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ @@ -615,7 +605,6 @@ TEST(TestFieldRef, GetChildren) { ASSERT_OK_AND_ASSIGN(a, FieldRef("a", "a").GetOne(*struct_array)); expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); AssertArraysEqual(*a, *expected_a); - AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); } } // namespace arrow diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 5c332aedf73..ecdbdfd9d8f 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -118,17 +118,17 @@ class CastMetaFunction : public MetaFunction { } // namespace -const FunctionDoc struct_doc{"Wrap Arrays into a StructArray", - ("Names of the StructArray's fields are\n" - "specified through StructOptions."), - {"*args"}, - "StructOptions"}; - -Result StructResolve(KernelContext* ctx, - const std::vector& descrs) { - const auto& names = OptionsWrapper::Get(ctx).field_names; +const FunctionDoc project_doc{"Wrap Arrays into a StructArray", + ("Names of the StructArray's fields are\n" + "specified through ProjectOptions."), + {"*args"}, + "ProjectOptions"}; + +Result ProjectResolve(KernelContext* ctx, + const std::vector& descrs) { + const auto& names = OptionsWrapper::Get(ctx).field_names; if (names.size() != descrs.size()) { - return Status::Invalid("Struct() was passed ", names.size(), " field ", "names but ", + return Status::Invalid("project() was passed ", names.size(), " field ", "names but ", descrs.size(), " arguments"); } @@ -157,8 +157,8 @@ Result StructResolve(KernelContext* ctx, return ValueDescr{struct_(std::move(fields)), shape}; } -void StructExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, StructResolve(ctx, 
batch.GetDescriptors())); +void ProjectExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, ProjectResolve(ctx, batch.GetDescriptors())); if (descr.shape == ValueDescr::SCALAR) { ScalarVector scalars(batch.num_values()); @@ -189,15 +189,15 @@ void StructExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { void RegisterScalarCast(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::make_shared())); - auto struct_function = - std::make_shared("struct", Arity::VarArgs(), &struct_doc); - ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{StructResolve}, + auto project_function = + std::make_shared("project", Arity::VarArgs(), &project_doc); + ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{ProjectResolve}, /*is_varargs=*/true), - StructExec, OptionsWrapper::Init}; + ProjectExec, OptionsWrapper::Init}; kernel.null_handling = NullHandling::OUTPUT_NOT_NULL; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - DCHECK_OK(struct_function->AddKernel(std::move(kernel))); - DCHECK_OK(registry->AddFunction(std::move(struct_function))); + DCHECK_OK(project_function->AddKernel(std::move(kernel))); + DCHECK_OK(registry->AddFunction(std::move(project_function))); } } // namespace internal diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 759b7c7665b..2454491e9e2 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -160,8 +160,8 @@ Result Cast(const Datum& value, std::shared_ptr to_type, /// \addtogroup compute-concrete-options /// @{ -struct ARROW_EXPORT StructOptions : public FunctionOptions { - explicit StructOptions(std::vector n) : field_names(std::move(n)) {} +struct ARROW_EXPORT ProjectOptions : public FunctionOptions { + explicit ProjectOptions(std::vector n) : field_names(std::move(n)) {} /// Names for wrapped columns std::vector field_names; diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 577b250da87..6043ae62f4d 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -25,6 +25,7 @@ add_arrow_compute_test(scalar_test scalar_cast_test.cc scalar_compare_test.cc scalar_nested_test.cc + scalar_project_test.cc scalar_set_lookup_test.cc scalar_string_test.cc scalar_validity_test.cc diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index fda8073beb9..67727f18dc7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1896,48 +1896,5 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { ASSERT_OK(UnregisterExtensionType("smallint")); } -class TestStruct : public TestBase { - public: - Result Struct(std::vector args) { - StructOptions opts{field_names}; - return CallFunction("struct", args, &opts); - } - - std::vector field_names; -}; - -TEST_F(TestStruct, Scalar) { - std::shared_ptr expected(new StructScalar{{}, struct_({})}); - ASSERT_OK_AND_EQ(Datum(expected), Struct({})); - - auto i32 = MakeScalar(1); - auto f64 = MakeScalar(2.5); - auto str = MakeScalar("yo"); - - expected.reset(new StructScalar{ - {i32, f64, str}, - struct_({field("i", i32->type), field("f", f64->type), field("s", str->type)})}); - field_names = {"i", "f", "s"}; - ASSERT_OK_AND_EQ(Datum(expected), Struct({i32, f64, str})); - - // Three field names but one input value - ASSERT_RAISES(Invalid, Struct({str})); -} - 
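Under the rename, these cases reappear in scalar_project_test.cc with the "struct" function spelled "project". A minimal sketch of the renamed invocation, mirroring the new test's helper:

    ProjectOptions opts{{"i", "f", "s"}};
    ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("project", {i32, f64, str}, &opts));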
-TEST_F(TestStruct, Array) { - field_names = {"i", "s"}; - auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]"); - auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])"); - ASSERT_OK_AND_ASSIGN(Datum expected, StructArray::Make({i32, str}, field_names)); - - ASSERT_OK_AND_EQ(expected, Struct({i32, str})); - - // Scalars are broadcast to the length of the arrays - ASSERT_OK_AND_EQ(expected, Struct({i32, MakeScalar("aa")})); - - // Array length mismatch - ASSERT_RAISES(Invalid, Struct({i32->Slice(1), str})); -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_project_test.cc b/cpp/src/arrow/compute/kernels/scalar_project_test.cc new file mode 100644 index 00000000000..1cccfb9c540 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_project_test.cc @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/chunked_array.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/type_traits.h" +#include "arrow/util/decimal.h" + +#include "arrow/compute/api_vector.h" +#include "arrow/compute/cast.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/test_util.h" + +namespace arrow { +namespace compute { + +struct { + public: + Result operator()(std::vector args) { + ProjectOptions opts{field_names}; + return CallFunction("project", args, &opts); + } + + std::vector field_names; +} Project; + +TEST(Project, Scalar) { + std::shared_ptr expected(new StructScalar{{}, struct_({})}); + ASSERT_OK_AND_EQ(Datum(expected), Project({})); + + auto i32 = MakeScalar(1); + auto f64 = MakeScalar(2.5); + auto str = MakeScalar("yo"); + + expected.reset(new StructScalar{ + {i32, f64, str}, + struct_({field("i", i32->type), field("f", f64->type), field("s", str->type)})}); + Project.field_names = {"i", "f", "s"}; + ASSERT_OK_AND_EQ(Datum(expected), Project({i32, f64, str})); + + // Three field names but one input value + ASSERT_RAISES(Invalid, Project({str})); +} + +TEST(Project, Array) { + Project.field_names = {"i", "s"}; + auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]"); + auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])"); + ASSERT_OK_AND_ASSIGN(Datum expected, + StructArray::Make({i32, str}, Project.field_names)); + + ASSERT_OK_AND_EQ(expected, Project({i32, str})); + + // Scalars are broadcast to the length of the arrays + ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")})); + + // Array length mismatch + ASSERT_RAISES(Invalid, Project({i32->Slice(1), str})); +} + +TEST(Project, ChunkedArray) { + Project.field_names = {"i", "s"}; + + auto i32_0 = ArrayFromJSON(int32(), "[42, 
13, 7]"); + auto i32_1 = ArrayFromJSON(int32(), "[]"); + auto i32_2 = ArrayFromJSON(int32(), "[32, 0]"); + + auto str_0 = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])"); + auto str_1 = ArrayFromJSON(utf8(), "[]"); + auto str_2 = ArrayFromJSON(utf8(), R"(["aa", "aa"])"); + + ASSERT_OK_AND_ASSIGN(auto i32, ChunkedArray::Make({i32_0, i32_1, i32_2})); + ASSERT_OK_AND_ASSIGN(auto str, ChunkedArray::Make({str_0, str_1, str_2})); + + ASSERT_OK_AND_ASSIGN(auto expected_0, + StructArray::Make({i32_0, str_0}, Project.field_names)); + ASSERT_OK_AND_ASSIGN(auto expected_1, + StructArray::Make({i32_1, str_1}, Project.field_names)); + ASSERT_OK_AND_ASSIGN(auto expected_2, + StructArray::Make({i32_2, str_2}, Project.field_names)); + ASSERT_OK_AND_ASSIGN(Datum expected, + ChunkedArray::Make({expected_0, expected_1, expected_2})); + + ASSERT_OK_AND_EQ(expected, Project({i32, str})); + + // Scalars are broadcast to the length of the arrays + ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")})); + + // Array length mismatch + ASSERT_RAISES(Invalid, Project({i32->Slice(1), str})); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 7f08788de29..e01ccad21a2 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -139,7 +139,7 @@ std::string Expression::ToString() const { return binary(std::move(op)); } - if (auto options = GetStructOptions(*call)) { + if (auto options = GetProjectOptions(*call)) { std::string out = "{"; auto argument = call->arguments.begin(); for (const auto& field_name : options->field_names) { @@ -251,8 +251,8 @@ bool Expression::Equals(const Expression& other) const { return options->to_type->Equals(other_options->to_type); } - if (auto options = GetStructOptions(*call)) { - auto other_options = GetStructOptions(*other_call); + if (auto options = GetProjectOptions(*call)) { + auto other_options = GetProjectOptions(*other_call); return options->field_names == other_options->field_names; } @@ -380,7 +380,7 @@ inline bool KernelStateIsImmutable(const std::string& function) { // known functions with non-null but nevertheless immutable KernelState static std::unordered_set names = { - "is_in", "index_in", "cast", "struct", "strptime", + "is_in", "index_in", "cast", "project", "strptime", }; return names.find(function) != names.end(); @@ -1121,7 +1121,7 @@ Result Deserialize(const Buffer& buffer) { } Expression project(std::vector values, std::vector names) { - return call("struct", std::move(values), compute::StructOptions{std::move(names)}); + return call("project", std::move(values), compute::ProjectOptions{std::move(names)}); } Expression equal(Expression lhs, Expression rhs) { diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index c3fd49dc347..90cbbd1c561 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -251,9 +251,9 @@ inline const compute::SetLookupOptions* GetSetLookupOptions( return checked_cast(call.options.get()); } -inline const compute::StructOptions* GetStructOptions(const Expression::Call& call) { - if (call.function_name != "struct") return nullptr; - return checked_cast(call.options.get()); +inline const compute::ProjectOptions* GetProjectOptions(const Expression::Call& call) { + if (call.function_name != "project") return nullptr; + return checked_cast(call.options.get()); } inline const compute::StrptimeOptions* 
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index f0fa04f40be..3c52c19f191 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1386,9 +1386,7 @@ class ARROW_EXPORT DictionaryUnifier {
 /// FieldPaths provide a number of accessors for drilling down to potentially nested
 /// children. They are overloaded for convenience to support Schema (returns a field),
 /// DataType (returns a child field), Field (returns a child field of this field's type)
-/// Array (returns a child array), RecordBatch (returns a column), ChunkedArray (returns a
-/// ChunkedArray where each chunk is a child array of the corresponding original chunk)
-/// and Table (returns a column).
+/// Array (returns a child array), RecordBatch (returns a column).
 class ARROW_EXPORT FieldPath {
  public:
   FieldPath() = default;
@@ -1425,7 +1423,7 @@ class ARROW_EXPORT FieldPath {
   /// \brief Retrieve the referenced column from a RecordBatch or Table
   Result<std::shared_ptr<Array>> Get(const RecordBatch& batch) const;
 
-  /// \brief Retrieve the referenced child from an Array, ArrayData, or ChunkedArray
+  /// \brief Retrieve the referenced child from an Array or ArrayData
   Result<std::shared_ptr<Array>> Get(const Array& array) const;
   Result<std::shared_ptr<ArrayData>> Get(const ArrayData& data) const;
 
@@ -1550,9 +1548,7 @@ class ARROW_EXPORT FieldRef {
   /// \brief Convenience function which applies FindAll to arg's type or schema.
   std::vector<FieldPath> FindAll(const ArrayData& array) const;
   std::vector<FieldPath> FindAll(const Array& array) const;
-  std::vector<FieldPath> FindAll(const ChunkedArray& array) const;
   std::vector<FieldPath> FindAll(const RecordBatch& batch) const;
-  std::vector<FieldPath> FindAll(const Table& table) const;
 
   /// \brief Convenience function: raise an error if matches is empty.
   template <typename T>

From f03953b340a8966d48af6ae3c8c0d76500b518be Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Mon, 4 Jan 2021 17:50:58 -0500
Subject: [PATCH 27/38] add some docstrings

---
 cpp/src/arrow/dataset/expression.h          |  7 +++++
 cpp/src/arrow/dataset/expression_internal.h | 30 ++++++++++++---------
 cpp/src/arrow/scalar.cc                     | 14 ++++++++++
 cpp/src/arrow/scalar.h                      |  3 +++
 4 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h
index e597ffb7fcd..31c582cb6fb 100644
--- a/cpp/src/arrow/dataset/expression.h
+++ b/cpp/src/arrow/dataset/expression.h
@@ -95,10 +95,14 @@ class ARROW_DS_EXPORT Expression {
   // XXX someday
   // Result GetPipelines();
 
+  /// Access a Call or return nullptr if this expression is not a call
   const Call* call() const;
+  /// Access a Datum or return nullptr if this expression is not a literal
   const Datum* literal() const;
+  /// Access a FieldRef or return nullptr if this expression is not a field_ref
   const FieldRef* field_ref() const;
 
+  /// The type and shape to which this expression will evaluate
   ValueDescr descr() const;
   // XXX someday
   // NullGeneralization::type nullable() const;
@@ -150,9 +154,11 @@ Expression call(std::string function, std::vector<Expression> arguments,
               std::make_shared<Options>(std::move(options)));
 }
 
+/// Assemble a list of all fields referenced by an Expression at any depth.
 ARROW_DS_EXPORT
 std::vector<FieldRef> FieldsInExpression(const Expression&);
 
+/// Assemble a mapping from field references to known values.
 ARROW_DS_EXPORT
 Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldValues(
     const Expression& guaranteed_true_predicate);

@@ -175,6 +181,7 @@ Result<Expression> Canonicalize(Expression, compute::ExecContext* = NULLPTR);
 ARROW_DS_EXPORT
 Result<Expression> FoldConstants(Expression);
 
+/// Simplify an Expression by replacing fields it references with their known values.
 ARROW_DS_EXPORT
 Result<Expression> ReplaceFieldsWithKnownValues(
     const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values, Expression);

diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h
index 90cbbd1c561..f2cf2c34709 100644
--- a/cpp/src/arrow/dataset/expression_internal.h
+++ b/cpp/src/arrow/dataset/expression_internal.h
@@ -299,19 +299,11 @@ inline Result<std::shared_ptr<StructScalar>> FunctionOptionsToStructScalar(
     return nullptr;
   }
 
-  auto Finish = [](ScalarVector values, std::vector<std::string> names) {
-    FieldVector fields(names.size());
-    for (size_t i = 0; i < fields.size(); ++i) {
-      fields[i] = field(std::move(names[i]), values[i]->type);
-    }
-    return std::make_shared<StructScalar>(std::move(values), struct_(std::move(fields)));
-  };
-
   if (auto options = GetSetLookupOptions(call)) {
     if (!options->value_set.is_array()) {
       return Status::NotImplemented("chunked value_set");
     }
-    return Finish(
+    return StructScalar::Make(
         {
             std::make_shared<ListScalar>(options->value_set.make_array()),
             MakeScalar(options->skip_nulls),
         },
         {"value_set", "skip_nulls"});
   }
 
-  if (call.function_name == "cast") {
-    auto options = checked_cast<const compute::CastOptions*>(call.options.get());
-    return Finish(
+  if (auto options = GetCastOptions(call)) {
+    return StructScalar::Make(
         {
             MakeNullScalar(options->to_type),
             MakeScalar(options->allow_int_overflow),
@@ -385,9 +376,22 @@ inline Status FunctionOptionsFromStructScalar(const StructScalar* repr,
   return Status::NotImplemented("conversion of options for ", call->function_name);
 }
 
+/// A helper for unboxing an Expression composed of associative function calls.
+/// Such expressions can frequently be rearranged to a semantically equivalent
+/// expression for more optimal execution or more straightforward manipulation.
+/// For example, (a + ((b + 3) + 4)) is equivalent to (((4 + 3) + a) + b) and the latter
+/// can be trivially constant-folded to ((7 + a) + b).
 struct FlattenedAssociativeChain {
+  /// True if the chain was already a left fold.
   bool was_left_folded = true;
-  std::vector<Expression> exprs, fringe;
+
+  /// All "branch" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
+  /// exprs would be [(a + ((b + 3) + 4)), ((b + 3) + 4), (b + 3)]
+  std::vector<Expression> exprs;
+
+  /// All "leaf" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
+  /// the fringe would be [a, b, 3, 4]
+  std::vector<Expression> fringe;
 
   explicit FlattenedAssociativeChain(Expression expr) : exprs{std::move(expr)} {
     auto call = CallNotNull(exprs.back());
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index 4c5b2e1bcee..eca711d7c4f 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -184,6 +184,20 @@ FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value)
     : BaseListScalar(
           value, fixed_size_list(value->type(), static_cast<int32_t>(value->length()))) {}
 
+Result<std::shared_ptr<StructScalar>> StructScalar::Make(
+    ScalarVector values, std::vector<std::string> field_names) {
+  if (values.size() != field_names.size()) {
+    return Status::Invalid("Mismatching number of field names and child scalars");
+  }
+
+  FieldVector fields(field_names.size());
+  for (size_t i = 0; i < fields.size(); ++i) {
+    fields[i] = arrow::field(std::move(field_names[i]), values[i]->type);
+  }
+
+  return std::make_shared<StructScalar>(std::move(values), struct_(std::move(fields)));
+}
+
 Result<std::shared_ptr<Scalar>> StructScalar::field(FieldRef ref) const {
   ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOne(*type));
   if (path.indices().size() != 1) {
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index 1fa866c8623..2888874d292 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -411,6 +411,9 @@ struct ARROW_EXPORT StructScalar : public Scalar {
   StructScalar(ValueType value, std::shared_ptr<DataType> type)
       : Scalar(std::move(type), true), value(std::move(value)) {}
 
+  static Result<std::shared_ptr<StructScalar>> Make(ValueType value,
+                                                    std::vector<std::string> field_names);
+
   explicit StructScalar(std::shared_ptr<DataType> type) : Scalar(std::move(type)) {}
 };
 

From 63f1464c1437495e75a0da50ebce7d3506fc1009 Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Tue, 5 Jan 2021 11:59:09 -0500
Subject: [PATCH 28/38] extract BindNonRecursive, filter2->filter, comments

---
 .../compute/kernels/scalar_cast_string.cc     |   3 +-
 cpp/src/arrow/dataset/expression.cc           | 124 ++++++++++++------
 cpp/src/arrow/dataset/expression.h            |   8 +-
 cpp/src/arrow/dataset/expression_internal.h   |   2 -
 cpp/src/arrow/dataset/expression_test.cc      |  45 +++++++
 cpp/src/arrow/dataset/file_csv.cc             |   2 +-
 cpp/src/arrow/dataset/file_ipc_test.cc        |   4 +-
 cpp/src/arrow/dataset/file_parquet.cc         |   4 +-
 cpp/src/arrow/dataset/file_parquet_test.cc    |   2 +-
 cpp/src/arrow/dataset/scanner.cc              |   8 +-
 cpp/src/arrow/dataset/scanner.h               |   2 +-
 cpp/src/arrow/dataset/scanner_internal.h      |   2 +-
 cpp/src/arrow/dataset/scanner_test.cc         |   6 +-
 cpp/src/arrow/dataset/test_util.h             |   4 +-
 14 files changed, 154 insertions(+), 62 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
index ca9c33a8f1b..7d502f046fc 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
@@ -203,7 +203,8 @@ void AddNumberToStringCasts(CastFunction* func) {
   for (const std::shared_ptr<DataType>& in_ty : NumericTypes()) {
     DCHECK_OK(
         func->AddKernel(in_ty->id(), {in_ty}, out_ty,
-                        GenerateNumeric(*in_ty),
+                        TrivialScalarUnaryAsArraysExec(
+                            GenerateNumeric(*in_ty)),
                         NullHandling::COMPUTED_NO_PREALLOCATE));
   }
 }
diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc
index e01ccad21a2..43711ee09dc 100644
--- a/cpp/src/arrow/dataset/expression.cc
+++ b/cpp/src/arrow/dataset/expression.cc
@@ -375,7 +375,9 @@ bool Expression::IsSatisfiable() const {
   return true;
 }
 
-inline bool KernelStateIsImmutable(const std::string&
function) { // XXX maybe just add Kernel::state_is_immutable or so? // known functions with non-null but nevertheless immutable KernelState @@ -399,6 +401,40 @@ Result> InitKernelState( return std::move(kernel_state); } +Status InsertImplicitCasts(Expression::Call* call); + +// Produce a bound Expression from unbound Call and bound arguments. +Result BindNonRecursive(const Expression::Call& call, + std::vector arguments, + bool insert_implicit_casts, + compute::ExecContext* exec_context) { + DCHECK(std::all_of(arguments.begin(), arguments.end(), + [](const Expression& argument) { return argument.IsBound(); })); + + auto bound_call = call; + bound_call.arguments = std::move(arguments); + + if (insert_implicit_casts) { + RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); + } + + ARROW_ASSIGN_OR_RAISE(bound_call.function, GetFunction(bound_call, exec_context)); + + auto descrs = GetDescriptors(bound_call.arguments); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel, bound_call.function->DispatchExact(descrs)); + + compute::KernelContext kernel_context(exec_context); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, + InitKernelState(bound_call, exec_context)); + kernel_context.SetState(bound_call.kernel_state.get()); + + ARROW_ASSIGN_OR_RAISE( + bound_call.descr, + bound_call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); + + return Expression(std::move(bound_call)); +} + Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { if (expr->descr().type->Equals(to_type)) { return Status::OK(); @@ -410,18 +446,15 @@ Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { return Status::OK(); } - // FIXME the resulting cast Call must be bound but this is a hack - auto with_cast = call("cast", {literal(MakeNullScalar(expr->descr().type))}, - compute::CastOptions::Safe(to_type)); + Expression::Call with_cast; + with_cast.function_name = "cast"; + with_cast.options = std::make_shared( + compute::CastOptions::Safe(std::move(to_type))); - static ValueDescr ignored_descr; - ARROW_ASSIGN_OR_RAISE(with_cast, with_cast.Bind(ignored_descr)); - - auto call_with_cast = *CallNotNull(with_cast); - call_with_cast.arguments[0] = std::move(*expr); - call_with_cast.descr = ValueDescr{std::move(to_type), expr->descr().shape}; - - *expr = Expression(std::move(call_with_cast)); + compute::ExecContext exec_context; + ARROW_ASSIGN_OR_RAISE(*expr, + BindNonRecursive(with_cast, {std::move(*expr)}, + /*insert_implicit_casts=*/false, &exec_context)); return Status::OK(); } @@ -475,6 +508,8 @@ Status InsertImplicitCasts(Expression::Call* call) { return Status::OK(); } +} // namespace + Result Expression::Bind(ValueDescr in, compute::ExecContext* exec_context) const { if (exec_context == nullptr) { @@ -490,28 +525,15 @@ Result Expression::Bind(ValueDescr in, return Expression{Parameter{*ref, std::move(descr)}}; } - auto bound_call = *CallNotNull(*this); - - ARROW_ASSIGN_OR_RAISE(bound_call.function, GetFunction(bound_call, exec_context)); + auto call = CallNotNull(*this); - for (auto&& argument : bound_call.arguments) { - ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); + std::vector bound_arguments(call->arguments.size()); + for (size_t i = 0; i < bound_arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(bound_arguments[i], call->arguments[i].Bind(in, exec_context)); } - RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); - - auto descrs = GetDescriptors(bound_call.arguments); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, bound_call.function->DispatchExact(descrs)); - - 
compute::KernelContext kernel_context(exec_context); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, - InitKernelState(bound_call, exec_context)); - kernel_context.SetState(bound_call.kernel_state.get()); - ARROW_ASSIGN_OR_RAISE( - bound_call.descr, - bound_call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); - - return Expression(std::move(bound_call)); + return BindNonRecursive(*call, std::move(bound_arguments), + /*insert_implicit_casts=*/true, exec_context); } Result Expression::Bind(const Schema& in_schema, @@ -574,6 +596,8 @@ Result ExecuteScalarExpression(const Expression& expr, const Datum& input return executor->WrapResults(arguments, listener->values()); } +namespace { + std::array, 2> ArgumentsAndFlippedArguments(const Expression::Call& call) { DCHECK_EQ(call.arguments.size(), 2); @@ -626,6 +650,8 @@ bool DefinitelyNotNull(const Expression& expr) { return false; } +} // namespace + std::vector FieldsInExpression(const Expression& expr) { if (auto lit = expr.literal()) return {}; @@ -701,7 +727,9 @@ Result FoldConstants(Expression expr) { }); } -inline std::vector GuaranteeConjunctionMembers( +namespace { + +std::vector GuaranteeConjunctionMembers( const Expression& guaranteed_true_predicate) { auto guarantee = guaranteed_true_predicate.call(); if (!guarantee || guarantee->function_name != "and_kleene") { @@ -754,6 +782,8 @@ Status ExtractKnownFieldValuesImpl( return Status::OK(); } +} // namespace + Result> ExtractKnownFieldValues( const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); @@ -786,7 +816,9 @@ Result ReplaceFieldsWithKnownValues( [](Expression expr, ...) { return expr; }); } -inline bool IsBinaryAssociativeCommutative(const Expression::Call& call) { +namespace { + +bool IsBinaryAssociativeCommutative(const Expression::Call& call) { static std::unordered_set binary_associative_commutative{ "and", "or", "and_kleene", "or_kleene", "xor", "multiply", "add", "multiply_checked", "add_checked"}; @@ -795,6 +827,8 @@ inline bool IsBinaryAssociativeCommutative(const Expression::Call& call) { return it != binary_associative_commutative.end(); } +} // namespace + Result Canonicalize(Expression expr, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; @@ -885,6 +919,8 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont [](Expression expr, ...) 
{ return expr; }); } +namespace { + Result DirectComparisonSimplification(Expression expr, const Expression::Call& guarantee) { return Modify( @@ -896,17 +932,19 @@ Result DirectComparisonSimplification(Expression expr, // Ensure both calls are comparisons with equal LHS and scalar RHS auto cmp = Comparison::Get(expr); auto cmp_guarantee = Comparison::Get(guarantee.function_name); - if (!cmp || !cmp_guarantee) return expr; + if (!cmp) return expr; + if (!cmp_guarantee) return expr; if (call->arguments[0] != guarantee.arguments[0]) return expr; auto rhs = call->arguments[1].literal(); auto guarantee_rhs = guarantee.arguments[1].literal(); - if (!rhs || !guarantee_rhs) return expr; - if (!rhs->is_scalar() || !guarantee_rhs->is_scalar()) { - return expr; - } + if (!rhs) return expr; + if (!rhs->is_scalar()) return expr; + + if (!guarantee_rhs) return expr; + if (!guarantee_rhs->is_scalar()) return expr; ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs, Comparison::Execute(*rhs, *guarantee_rhs)); @@ -915,13 +953,15 @@ Result DirectComparisonSimplification(Expression expr, if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) { // RHS of filter is equal to RHS of guarantee - if ((*cmp_guarantee & *cmp) == *cmp_guarantee) { + if ((*cmp & *cmp_guarantee) == *cmp_guarantee) { // guarantee is a subset of filter, so all data will be included + // x > 1, x >= 1, x != 1 guaranteed by x > 1 return literal(true); } - if ((*cmp_guarantee & *cmp) == 0) { + if ((*cmp & *cmp_guarantee) == 0) { // guarantee disjoint with filter, so all data will be excluded + // x > 1, x >= 1, x != 1 unsatisfiable if x == 1 return literal(false); } @@ -929,7 +969,7 @@ Result DirectComparisonSimplification(Expression expr, } if (*cmp_guarantee & cmp_rhs_guarantee_rhs) { - // unusable guarantee + // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3 return expr; } @@ -943,6 +983,8 @@ Result DirectComparisonSimplification(Expression expr, }); } +} // namespace + Result SimplifyWithGuarantee(Expression expr, const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 31c582cb6fb..dff4c571155 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -167,7 +167,13 @@ Result> ExtractKnownFieldVal /// /// @{ /// -/// These operate on bound expressions. +/// These transform bound expressions. Some transforms utilize a guarantee, which is +/// provided as an Expression which is guaranteed to evaluate to true. The +/// guaranteed_true_predicate need not be bound, but canonicalization is currently +/// deferred to producers of guarantees. For example in order to be recognized as a +/// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS +/// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or +/// other semantically identical Expressions will not be recognized. /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. 
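
A minimal usage sketch of the simplification passes documented above (illustrative only,
not part of the patch; it assumes the comparison factories `greater` and `equal` declared
alongside these passes, and a schema with a single int32 field "x"):

    // A filter x > 0, bound against the dataset schema:
    Expression filter = greater(field_ref("x"), literal(0));
    ASSERT_OK_AND_ASSIGN(filter, filter.Bind(Schema({field("x", int32())})));

    // A partition expression such as x == 3 serves as the guarantee. Note the
    // recognized form: "equal" with field_ref LHS and literal RHS.
    Expression guarantee = equal(field_ref("x"), literal(3));

    ASSERT_OK_AND_ASSIGN(filter, Canonicalize(filter));
    ASSERT_OK_AND_ASSIGN(filter, SimplifyWithGuarantee(filter, guarantee));
    EXPECT_EQ(filter, literal(true));  // x == 3 implies x > 0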
diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index f2cf2c34709..b217c8fd379 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -186,7 +186,6 @@ struct Comparison { static std::string GetName(type op) { switch (op) { case NA: - DCHECK(false) << "unreachable"; break; case EQUAL: return "equal"; @@ -201,7 +200,6 @@ struct Comparison { case GREATER_EQUAL: return "greater_equal"; } - DCHECK(false); return "na"; } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index e245b0d7093..a8536279ac9 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -47,6 +47,51 @@ Expression cast(Expression argument, std::shared_ptr to_type) { compute::CastOptions::Safe(std::move(to_type))); } +template +void ExpectResultsEqual(Actual&& actual, Expected&& expected) { + using MaybeActual = typename EnsureResult::type>::type; + using MaybeExpected = typename EnsureResult::type>::type; + + MaybeActual maybe_actual(std::forward(actual)); + MaybeExpected maybe_expected(std::forward(expected)); + + if (maybe_expected.ok()) { + ASSERT_OK_AND_ASSIGN(auto actual, maybe_actual); + EXPECT_EQ(actual, *maybe_expected); + } else { + EXPECT_EQ(maybe_actual.status().code(), expected.status().code()); + EXPECT_NE(maybe_actual.status().message().find(expected.status().message()), + std::string::npos) + << " actual: " << maybe_actual.status() << "\n" + << " expected: " << maybe_expected.status(); + } +} + +TEST(ExpressionUtils, Comparison) { + auto Expect = [](Result expected, Datum l, Datum r) { + ExpectResultsEqual(Comparison::Execute(l, r).Map(Comparison::GetName), expected); + }; + + Datum zero(0), one(1), two(2), null(std::make_shared()), str("hello"); + + Status parse_failure = Status::Invalid("Failed to parse"); + + Expect("equal", one, one); + Expect("less", one, two); + Expect("greater", one, zero); + + // cast RHS to LHS type; "hello" > "1" + Expect("greater", str, one); + // cast RHS to LHS type; "hello" is not convertible to int + Expect(parse_failure, one, str); + + Expect("na", one, null); + Expect("na", str, null); + Expect("na", null, one); + // cast RHS to LHS type; "hello" is not convertible to int + Expect(parse_failure, null, str); +} + TEST(Expression, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "alpha"); diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index bc1a69066f7..87c26f1cc79 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -92,7 +92,7 @@ static inline Result GetConvertOptions( // FIXME(bkietz) also acquire types of fields materialized but not projected. // This requires that scan_options include the full dataset schema (not just // the projected schema). 
- for (const FieldRef& ref : FieldsInExpression(scan_options->filter2)) { + for (const FieldRef& ref : FieldsInExpression(scan_options->filter)) { DCHECK(ref.name()); ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*scan_options->schema())); if (!match) { diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index f49337b362e..42421134790 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -243,7 +243,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter2 = equal(field_ref("i32"), literal(0)); + opts_->filter = equal(field_ref("i32"), literal(0)); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -279,7 +279,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter2 = equal(field_ref("i32"), literal(0)); + opts_->filter = equal(field_ref("i32"), literal(0)); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index 94d2115358f..eee175228f2 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -298,7 +298,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrmetadata() != nullptr) { ARROW_ASSIGN_OR_RAISE(row_groups, - parquet_fragment->FilterRowGroups(options->filter2)); + parquet_fragment->FilterRowGroups(options->filter)); pre_filtered = true; if (row_groups.empty()) MakeEmpty(); @@ -314,7 +314,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrFilterRowGroups(options->filter2)); + parquet_fragment->FilterRowGroups(options->filter)); if (row_groups.empty()) MakeEmpty(); } diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 67d1fb17120..fa8e5c6441a 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -160,7 +160,7 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { } void SetFilter(Expression filter) { - ASSERT_OK_AND_ASSIGN(opts_->filter2, filter.Bind(*schema_)); + ASSERT_OK_AND_ASSIGN(opts_->filter, filter.Bind(*schema_)); } std::shared_ptr SingleBatch(Fragment* fragment) { diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 0c501c9f5b3..0537d370125 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -39,7 +39,7 @@ ScanOptions::ScanOptions(std::shared_ptr schema) std::shared_ptr ScanOptions::ReplaceSchema( std::shared_ptr schema) const { auto copy = ScanOptions::Make(std::move(schema)); - copy->filter2 = filter2; + copy->filter = filter; copy->batch_size = batch_size; return copy; } @@ -51,7 +51,7 @@ std::vector ScanOptions::MaterializedFields() const { fields.push_back(f->name()); } - for (const FieldRef& ref : FieldsInExpression(filter2)) { + for (const FieldRef& ref : FieldsInExpression(filter)) { DCHECK(ref.name()); fields.push_back(*ref.name()); } @@ -71,7 +71,7 @@ Result Scanner::GetFragments() { // Transform Datasets in a flat Iterator. 
This // iterator is lazily constructed, i.e. Dataset::GetFragments is // not invoked until a Fragment is requested. - return GetFragmentsFromDatasets({dataset_}, scan_options_->filter2); + return GetFragmentsFromDatasets({dataset_}, scan_options_->filter); } Result Scanner::Scan() { @@ -125,7 +125,7 @@ Status ScannerBuilder::Filter(const Expression& filter) { for (const auto& ref : FieldsInExpression(filter)) { RETURN_NOT_OK(ref.FindOne(*schema())); } - ARROW_ASSIGN_OR_RAISE(scan_options_->filter2, filter.Bind(*schema())); + ARROW_ASSIGN_OR_RAISE(scan_options_->filter, filter.Bind(*schema())); return Status::OK(); } diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 5902f759ec3..e65ac7fa524 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -63,7 +63,7 @@ class ARROW_DS_EXPORT ScanOptions { std::shared_ptr ReplaceSchema(std::shared_ptr schema) const; // Filter - Expression filter2 = literal(true); + Expression filter = literal(true); // Schema to which record batches will be reconciled const std::shared_ptr& schema() const { return projector.schema(); } diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index cd8fffd2a71..729bd1f3a48 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -74,7 +74,7 @@ class FilterAndProjectScanTask : public ScanTask { : ScanTask(task->options(), task->context()), task_(std::move(task)), partition_(std::move(partition)), - filter_(options()->filter2), + filter_(options()->filter), projector_(options()->projector) {} Result Execute() override { diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index f8f959e3a28..c0e37bef6cc 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -207,7 +207,7 @@ TEST(ScanOptions, TestMaterializedFields) { auto opts = ScanOptions::Make(schema({})); EXPECT_THAT(opts->MaterializedFields(), IsEmpty()); - opts->filter2 = equal(field_ref("i32"), literal(10)); + opts->filter = equal(field_ref("i32"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); opts = ScanOptions::Make(schema({i32, i64})); @@ -216,10 +216,10 @@ TEST(ScanOptions, TestMaterializedFields) { opts = opts->ReplaceSchema(schema({i32})); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); - opts->filter2 = equal(field_ref("i32"), literal(10)); + opts->filter = equal(field_ref("i32"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32")); - opts->filter2 = equal(field_ref("i64"), literal(10)); + opts->filter = equal(field_ref("i64"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64")); } diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 1c7c471d3ca..f0d44cfe3d6 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -152,7 +152,7 @@ class DatasetFixtureMixin : public ::testing::Test { /// record batches yielded by the data fragments of a dataset. 
void AssertDatasetFragmentsEqual(RecordBatchReader* expected, Dataset* dataset, bool ensure_drained = true) { - ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter2.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter.Bind(*dataset->schema())); ASSERT_OK_AND_ASSIGN(auto it, dataset->GetFragments(predicate)); ARROW_EXPECT_OK(it.Visit([&](std::shared_ptr fragment) -> Status { @@ -202,7 +202,7 @@ class DatasetFixtureMixin : public ::testing::Test { } void SetFilter(Expression filter) { - ASSERT_OK_AND_ASSIGN(options_->filter2, filter.Bind(*schema_)); + ASSERT_OK_AND_ASSIGN(options_->filter, filter.Bind(*schema_)); } std::shared_ptr schema_; From 607ed963cd51f8a47c5ea024ceeaaec95baa7362 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 5 Jan 2021 12:24:35 -0500 Subject: [PATCH 29/38] move more things to namespace{}, docstring for Modify --- cpp/src/arrow/dataset/expression.cc | 140 +++++++++++++++- cpp/src/arrow/dataset/expression_internal.h | 170 +++----------------- 2 files changed, 157 insertions(+), 153 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 43711ee09dc..d7a54b26490 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -300,17 +300,15 @@ size_t Expression::hash() const { bool Expression::IsBound() const { if (descr().type == nullptr) return false; - if (auto lit = literal()) return true; + if (auto call = this->call()) { + if (call->kernel == nullptr) return false; - if (auto ref = field_ref()) return true; - - auto call = CallNotNull(*this); - - for (const Expression& arg : call->arguments) { - if (!arg.IsBound()) return false; + for (const Expression& arg : call->arguments) { + if (!arg.IsBound()) return false; + } } - return call->kernel != nullptr; + return true; } bool Expression::IsScalarExpression() const { @@ -508,6 +506,45 @@ Status InsertImplicitCasts(Expression::Call* call) { return Status::OK(); } +struct FieldPathGetDatumImpl { + template ()))> + Result operator()(const std::shared_ptr& ptr) { + return path_.Get(*ptr).template As(); + } + + template + Result operator()(const T&) { + return Status::NotImplemented("FieldPath::Get() into Datum ", datum_.ToString()); + } + + const Datum& datum_; + const FieldPath& path_; +}; + +inline Result GetDatumField(const FieldRef& ref, const Datum& input) { + Datum field; + + FieldPath path; + if (auto type = input.type()) { + ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*type)); + } else if (auto schema = input.schema()) { + ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*schema)); + } else { + return Status::NotImplemented("retrieving fields from datum ", input.ToString()); + } + + if (path) { + ARROW_ASSIGN_OR_RAISE(field, + util::visit(FieldPathGetDatumImpl{input, path}, input.value)); + } + + if (field == Datum{}) { + field = Datum(std::make_shared()); + } + + return field; +} + } // namespace Result Expression::Bind(ValueDescr in, @@ -1017,6 +1054,93 @@ Result SimplifyWithGuarantee(Expression expr, return expr; } +namespace { + +inline Result> FunctionOptionsToStructScalar( + const Expression::Call& call) { + if (call.options == nullptr) { + return nullptr; + } + + if (auto options = GetSetLookupOptions(call)) { + if (!options->value_set.is_array()) { + return Status::NotImplemented("chunked value_set"); + } + return StructScalar::Make( + { + std::make_shared(options->value_set.make_array()), + MakeScalar(options->skip_nulls), + }, + {"value_set", "skip_nulls"}); + } + + if (auto options 
= GetCastOptions(call)) { + return StructScalar::Make( + { + MakeNullScalar(options->to_type), + MakeScalar(options->allow_int_overflow), + MakeScalar(options->allow_time_truncate), + MakeScalar(options->allow_time_overflow), + MakeScalar(options->allow_decimal_truncate), + MakeScalar(options->allow_float_truncate), + MakeScalar(options->allow_invalid_utf8), + }, + { + "to_type_holder", + "allow_int_overflow", + "allow_time_truncate", + "allow_time_overflow", + "allow_decimal_truncate", + "allow_float_truncate", + "allow_invalid_utf8", + }); + } + + return Status::NotImplemented("conversion of options for ", call.function_name); +} + +inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, + Expression::Call* call) { + if (repr == nullptr) { + call->options = nullptr; + return Status::OK(); + } + + if (IsSetLookup(call->function_name)) { + ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set")); + ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls")); + call->options = std::make_shared( + checked_cast(*value_set).value, + checked_cast(*skip_nulls).value); + return Status::OK(); + } + + if (call->function_name == "cast") { + auto options = std::make_shared(); + ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder")); + options->to_type = to_type_holder->type; + + int i = 1; + for (bool* opt : { + &options->allow_int_overflow, + &options->allow_time_truncate, + &options->allow_time_overflow, + &options->allow_decimal_truncate, + &options->allow_float_truncate, + &options->allow_invalid_utf8, + }) { + *opt = checked_cast(*repr->value[i++]).value; + } + + call->options = std::move(options); + return Status::OK(); + } + + return Status::NotImplemented("conversion of options for ", call->function_name); +} + +} // namespace + // Serialization is accomplished by converting expressions to KeyValueMetadata and storing // this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its // columns. Finally, the RecordBatch is written to an IPC file. 
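
For reference, a minimal sketch of the round trip the comment above describes, mirroring
the SerializationRoundTrips test later in this series (at this point in the series
Deserialize still takes a `const Buffer&`; a later commit changes it to accept a
`std::shared_ptr<Buffer>` so deserialized arrays keep their storage alive):

    Expression expr = and_(equal(field_ref("a"), literal(1)),
                           less(field_ref("b"), literal(2.5)));
    ASSERT_OK_AND_ASSIGN(std::shared_ptr<Buffer> serialized, Serialize(expr));
    ASSERT_OK_AND_ASSIGN(Expression roundtripped, Deserialize(*serialized));
    EXPECT_EQ(expr, roundtripped);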
diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index b217c8fd379..0aa57f70fc1 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -40,20 +40,6 @@ const Expression::Call* CallNotNull(const Expression& expr) { return call; } -inline void GetAllFieldRefs(const Expression& expr, - std::unordered_set* refs) { - if (auto lit = expr.literal()) return; - - if (auto ref = expr.field_ref()) { - refs->emplace(*ref); - return; - } - - for (const Expression& arg : CallNotNull(expr)->arguments) { - GetAllFieldRefs(arg, refs); - } -} - inline std::vector GetDescriptors(const std::vector& exprs) { std::vector descrs(exprs.size()); for (size_t i = 0; i < exprs.size(); ++i) { @@ -71,45 +57,6 @@ inline std::vector GetDescriptors(const std::vector& values) return descrs; } -struct FieldPathGetDatumImpl { - template ()))> - Result operator()(const std::shared_ptr& ptr) { - return path_.Get(*ptr).template As(); - } - - template - Result operator()(const T&) { - return Status::NotImplemented("FieldPath::Get() into Datum ", datum_.ToString()); - } - - const Datum& datum_; - const FieldPath& path_; -}; - -inline Result GetDatumField(const FieldRef& ref, const Datum& input) { - Datum field; - - FieldPath path; - if (auto type = input.type()) { - ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*type)); - } else if (auto schema = input.schema()) { - ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*schema)); - } else { - return Status::NotImplemented("retrieving fields from datum ", input.ToString()); - } - - if (path) { - ARROW_ASSIGN_OR_RAISE(field, - util::visit(FieldPathGetDatumImpl{input, path}, input.value)); - } - - if (field == Datum{}) { - field = Datum(std::make_shared()); - } - - return field; -} - struct Comparison { enum type { NA = 0, @@ -291,89 +238,6 @@ inline Status EnsureNotDictionary(Expression::Call* call) { return Status::OK(); } -inline Result> FunctionOptionsToStructScalar( - const Expression::Call& call) { - if (call.options == nullptr) { - return nullptr; - } - - if (auto options = GetSetLookupOptions(call)) { - if (!options->value_set.is_array()) { - return Status::NotImplemented("chunked value_set"); - } - return StructScalar::Make( - { - std::make_shared(options->value_set.make_array()), - MakeScalar(options->skip_nulls), - }, - {"value_set", "skip_nulls"}); - } - - if (auto options = GetCastOptions(call)) { - return StructScalar::Make( - { - MakeNullScalar(options->to_type), - MakeScalar(options->allow_int_overflow), - MakeScalar(options->allow_time_truncate), - MakeScalar(options->allow_time_overflow), - MakeScalar(options->allow_decimal_truncate), - MakeScalar(options->allow_float_truncate), - MakeScalar(options->allow_invalid_utf8), - }, - { - "to_type_holder", - "allow_int_overflow", - "allow_time_truncate", - "allow_time_overflow", - "allow_decimal_truncate", - "allow_float_truncate", - "allow_invalid_utf8", - }); - } - - return Status::NotImplemented("conversion of options for ", call.function_name); -} - -inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, - Expression::Call* call) { - if (repr == nullptr) { - call->options = nullptr; - return Status::OK(); - } - - if (IsSetLookup(call->function_name)) { - ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set")); - ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls")); - call->options = std::make_shared( - checked_cast(*value_set).value, - checked_cast(*skip_nulls).value); - return 
Status::OK(); - } - - if (call->function_name == "cast") { - auto options = std::make_shared(); - ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder")); - options->to_type = to_type_holder->type; - - int i = 1; - for (bool* opt : { - &options->allow_int_overflow, - &options->allow_time_truncate, - &options->allow_time_overflow, - &options->allow_decimal_truncate, - &options->allow_float_truncate, - &options->allow_invalid_utf8, - }) { - *opt = checked_cast(*repr->value[i++]).value; - } - - call->options = std::move(options); - return Status::OK(); - } - - return Status::NotImplemented("conversion of options for ", call->function_name); -} - /// A helper for unboxing an Expression composed of associative function calls. /// Such expressions can frequently be rearranged to a semantically equivalent /// expression for more optimal execution or more straightforward manipulation. @@ -433,6 +297,16 @@ inline Result> GetFunction( return compute::GetCastFunction(to_type); } +/// Modify an Expression with pre-order and post-order visitation. +/// `pre` will be invoked on each Expression. `pre` will visit Calls before their +/// arguments, `post_call` will visit Calls (and no other Expressions) after their +/// arguments. Visitors should return the Identical expression to indicate no change; this +/// will prevent unnecessary construction in the common case where a modification is not +/// possible/necessary/... +/// +/// If an argument was modified, `post_call` visits a reconstructed Call with the modified +/// arguments but also receives a pointer to the unmodified Expression as a second +/// argument. If no arguments were modified the unmodified Expression* will be nullptr. template Result Modify(Expression expr, const PreVisit& pre, const PostVisitCall& post_call) { @@ -442,23 +316,29 @@ Result Modify(Expression expr, const PreVisit& pre, if (!call) return expr; bool at_least_one_modified = false; - auto modified_call = *call; - auto modified_argument = modified_call.arguments.begin(); + std::vector modified_arguments; + + for (size_t i = 0; i < call->arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto modified_argument, + Modify(call->arguments[i], pre, post_call)); - for (const auto& argument : call->arguments) { - ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call)); + if (Identical(modified_argument, call->arguments[i])) { + continue; + } - if (!Identical(*modified_argument, argument)) { + if (!at_least_one_modified) { + modified_arguments = call->arguments; at_least_one_modified = true; } - ++modified_argument; + + modified_arguments[i] = std::move(modified_argument); } if (at_least_one_modified) { // reconstruct the call expression with the modified arguments - auto modified_expr = Expression(std::move(modified_call)); - - return post_call(std::move(modified_expr), &expr); + auto modified_call = *call; + modified_call.arguments = std::move(modified_arguments); + return post_call(Expression(std::move(modified_call)), &expr); } return post_call(std::move(expr), nullptr); From f9524004c37c57e96fbdc731671638106fd325f5 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 5 Jan 2021 16:48:04 -0500 Subject: [PATCH 30/38] ensure field_ref into list safely errors --- cpp/src/arrow/dataset/expression_test.cc | 9 +++++++++ cpp/src/arrow/type.cc | 14 +++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index a8536279ac9..209b3561a03 100644 --- 
a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -31,6 +31,7 @@ #include "arrow/dataset/test_util.h" #include "arrow/testing/gtest_util.h" +using testing::HasSubstr; using testing::UnorderedElementsAreArray; namespace arrow { @@ -426,6 +427,14 @@ TEST(Expression, ExecuteFieldRef) { {"a": -1} ])"), MakeNullScalar(null())); + + // XXX this *should* fail in Bind but for now it will just error in + // ExecuteScalarExpression + ASSERT_OK_AND_ASSIGN(auto list_item, field_ref("item").Bind(list(int32()))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, HasSubstr("non-struct array"), + ExecuteScalarExpression(list_item, + ArrayFromJSON(list(int32()), "[[1,2], [], null, [5]]"))); } Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index ad889b3eb24..12d3951865f 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -953,6 +953,10 @@ struct FieldPathGetImpl { int depth = 0; const T* out; for (int index : path->indices()) { + if (children == nullptr) { + return Status::NotImplemented("Get child data of non-struct array"); + } + if (index < 0 || static_cast(index) >= children->size()) { *out_of_range_depth = depth; return nullptr; @@ -990,7 +994,12 @@ struct FieldPathGetImpl { const ArrayDataVector& child_data) { return FieldPathGetImpl::Get( path, &child_data, - [](const std::shared_ptr& data) { return &data->child_data; }); + [](const std::shared_ptr& data) -> const ArrayDataVector* { + if (data->type->id() != Type::STRUCT) { + return nullptr; + } + return &data->child_data; + }); } }; @@ -1021,6 +1030,9 @@ Result> FieldPath::Get(const Array& array) const { } Result> FieldPath::Get(const ArrayData& data) const { + if (data.type->id() != Type::STRUCT) { + return Status::NotImplemented("Get child data of non-struct array"); + } return FieldPathGetImpl::Get(this, data.child_data); } From 5f29f52ade043d19cfced6bcb5a3e228aa7d314b Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 5 Jan 2021 17:16:01 -0500 Subject: [PATCH 31/38] debug prints --- cpp/src/arrow/dataset/expression.cc | 16 +++++++++++++++- cpp/src/arrow/dataset/expression_test.cc | 3 +++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index d7a54b26490..961489ef878 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -17,6 +17,7 @@ #include "arrow/dataset/expression.h" +#include #include #include @@ -1203,6 +1204,12 @@ Result> Serialize(const Expression& expr) { } ToRecordBatch; ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr)); + std::cout << "Serialization of:\nmetadata:\n" + << batch->schema()->metadata()->ToString() + << "\nschema:" << batch->schema()->ToString() + << "\nstorage:" << batch->ToString() << "\n" + << expr.ToString(); + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); @@ -1223,6 +1230,11 @@ Result Deserialize(const Buffer& buffer) { batch->num_rows()); } + std::cout << "Deserialization of:\nmetadata:\n" + << batch->schema()->metadata()->ToString() + << "\nschema:" << batch->schema()->ToString() + << "\nstorage:" << batch->ToString(); + struct FromRecordBatch { const RecordBatch& batch_; int index_; @@ -1283,7 +1295,9 @@ Result Deserialize(const Buffer& buffer) { } }; - return 
FromRecordBatch{*batch, 0}.GetOne(); + ARROW_ASSIGN_OR_RAISE(auto expr, (FromRecordBatch{*batch, 0}.GetOne())); + std::cout << expr.ToString() << std::endl; + return expr; } Expression project(std::vector values, std::vector names) { diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 209b3561a03..a304ce2834f 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -1063,6 +1063,9 @@ TEST(Expression, SerializationRoundTrips) { ExpectRoundTrips(call("is_in", {literal(1)}, compute::SetLookupOptions{ArrayFromJSON(int32(), "[1, 2, 3]")})); + ExpectRoundTrips(call("is_in", {literal(int64_t(1))}, + compute::SetLookupOptions{ArrayFromJSON(int64(), "[1, 2, 3]")})); + ExpectRoundTrips( call("is_in", {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))}, From 0dffb563b95e731f9aa45f3b857433872a171334 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 5 Jan 2021 19:31:12 -0500 Subject: [PATCH 32/38] remove unused functions --- cpp/src/arrow/dataset/expression.cc | 34 ----------------------------- 1 file changed, 34 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 961489ef878..54f7ddc02b5 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -376,17 +376,6 @@ bool Expression::IsSatisfiable() const { namespace { -bool KernelStateIsImmutable(const std::string& function) { - // XXX maybe just add Kernel::state_is_immutable or so? - - // known functions with non-null but nevertheless immutable KernelState - static std::unordered_set names = { - "is_in", "index_in", "cast", "project", "strptime", - }; - - return names.find(function) != names.end(); -} - Result> InitKernelState( const Expression::Call& call, compute::ExecContext* exec_context) { if (!call.kernel->init) return nullptr; @@ -665,29 +654,6 @@ util::optional GetNullHandling( return util::nullopt; } -bool DefinitelyNotNull(const Expression& expr) { - DCHECK(expr.IsBound()); - - if (expr.literal()) { - return !expr.IsNullLiteral(); - } - - if (expr.field_ref()) return false; - - auto call = CallNotNull(expr); - if (auto null_handling = GetNullHandling(*call)) { - if (null_handling == compute::NullHandling::OUTPUT_NOT_NULL) { - return true; - } - if (null_handling == compute::NullHandling::INTERSECTION) { - return std::all_of(call->arguments.begin(), call->arguments.end(), - DefinitelyNotNull); - } - } - - return false; -} - } // namespace std::vector FieldsInExpression(const Expression& expr) { From 24c9277a5e0b50dafa9a361a8a5c4b032cf19e13 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 6 Jan 2021 10:09:33 -0500 Subject: [PATCH 33/38] take ownership of buffer to preserve deserialized arrays' storage --- cpp/src/arrow/dataset/expression.cc | 25 ++++---------------- cpp/src/arrow/dataset/expression.h | 2 +- cpp/src/arrow/dataset/expression_test.cc | 5 +--- cpp/src/arrow/io/memory.cc | 8 +++---- cpp/src/arrow/io/memory.h | 2 +- python/pyarrow/_dataset.pyx | 2 +- python/pyarrow/includes/libarrow_dataset.pxd | 2 +- 7 files changed, 14 insertions(+), 32 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 54f7ddc02b5..11d6a0e7c8f 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -17,7 +17,6 @@ #include "arrow/dataset/expression.h" -#include #include #include @@ -1023,7 +1022,7 @@ Result SimplifyWithGuarantee(Expression 
expr, namespace { -inline Result> FunctionOptionsToStructScalar( +Result> FunctionOptionsToStructScalar( const Expression::Call& call) { if (call.options == nullptr) { return nullptr; @@ -1066,8 +1065,7 @@ inline Result> FunctionOptionsToStructScalar( return Status::NotImplemented("conversion of options for ", call.function_name); } -inline Status FunctionOptionsFromStructScalar(const StructScalar* repr, - Expression::Call* call) { +Status FunctionOptionsFromStructScalar(const StructScalar* repr, Expression::Call* call) { if (repr == nullptr) { call->options = nullptr; return Status::OK(); @@ -1170,12 +1168,6 @@ Result> Serialize(const Expression& expr) { } ToRecordBatch; ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr)); - std::cout << "Serialization of:\nmetadata:\n" - << batch->schema()->metadata()->ToString() - << "\nschema:" << batch->schema()->ToString() - << "\nstorage:" << batch->ToString() << "\n" - << expr.ToString(); - ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); @@ -1183,8 +1175,8 @@ Result> Serialize(const Expression& expr) { return stream->Finish(); } -Result Deserialize(const Buffer& buffer) { - io::BufferReader stream(buffer); +Result Deserialize(std::shared_ptr buffer) { + io::BufferReader stream(std::move(buffer)); ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); if (batch->schema()->metadata() == nullptr) { @@ -1196,11 +1188,6 @@ Result Deserialize(const Buffer& buffer) { batch->num_rows()); } - std::cout << "Deserialization of:\nmetadata:\n" - << batch->schema()->metadata()->ToString() - << "\nschema:" << batch->schema()->ToString() - << "\nstorage:" << batch->ToString(); - struct FromRecordBatch { const RecordBatch& batch_; int index_; @@ -1261,9 +1248,7 @@ Result Deserialize(const Buffer& buffer) { } }; - ARROW_ASSIGN_OR_RAISE(auto expr, (FromRecordBatch{*batch, 0}.GetOne())); - std::cout << expr.ToString() << std::endl; - return expr; + return FromRecordBatch{*batch, 0}.GetOne(); } Expression project(std::vector values, std::vector names) { diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index dff4c571155..984c846210f 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -216,7 +216,7 @@ ARROW_DS_EXPORT Result> Serialize(const Expression&); ARROW_DS_EXPORT -Result Deserialize(const Buffer&); +Result Deserialize(std::shared_ptr); // Convenience aliases for factories diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index a304ce2834f..da5c82425b3 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -1029,7 +1029,7 @@ TEST(Expression, Filter) { TEST(Expression, SerializationRoundTrips) { auto ExpectRoundTrips = [](const Expression& expr) { ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(expr)); - ASSERT_OK_AND_ASSIGN(Expression roundtripped, Deserialize(*serialized)); + ASSERT_OK_AND_ASSIGN(Expression roundtripped, Deserialize(serialized)); EXPECT_EQ(expr, roundtripped); }; @@ -1063,9 +1063,6 @@ TEST(Expression, SerializationRoundTrips) { ExpectRoundTrips(call("is_in", {literal(1)}, compute::SetLookupOptions{ArrayFromJSON(int32(), "[1, 2, 3]")})); - ExpectRoundTrips(call("is_in", {literal(int64_t(1))}, - 
compute::SetLookupOptions{ArrayFromJSON(int64(), "[1, 2, 3]")})); - ExpectRoundTrips( call("is_in", {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))}, diff --git a/cpp/src/arrow/io/memory.cc b/cpp/src/arrow/io/memory.cc index 361895494e1..1ac435ab642 100644 --- a/cpp/src/arrow/io/memory.cc +++ b/cpp/src/arrow/io/memory.cc @@ -261,10 +261,10 @@ void FixedSizeBufferWriter::set_memcopy_threshold(int64_t threshold) { // ---------------------------------------------------------------------- // In-memory buffer reader -BufferReader::BufferReader(const std::shared_ptr& buffer) - : buffer_(buffer), - data_(buffer->data()), - size_(buffer->size()), +BufferReader::BufferReader(std::shared_ptr buffer) + : buffer_(std::move(buffer)), + data_(buffer_->data()), + size_(buffer_->size()), position_(0), is_open_(true) {} diff --git a/cpp/src/arrow/io/memory.h b/cpp/src/arrow/io/memory.h index 0eeabbaca78..075398a180b 100644 --- a/cpp/src/arrow/io/memory.h +++ b/cpp/src/arrow/io/memory.h @@ -145,7 +145,7 @@ class ARROW_EXPORT FixedSizeBufferWriter : public WritableFile { class ARROW_EXPORT BufferReader : public internal::RandomAccessFileConcurrencyWrapper { public: - explicit BufferReader(const std::shared_ptr& buffer); + explicit BufferReader(std::shared_ptr buffer); explicit BufferReader(const Buffer& buffer); BufferReader(const uint8_t* data, int64_t size); diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 710e3c9c2c4..410ca12c66b 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -117,7 +117,7 @@ cdef class Expression(_Weakrefable): @staticmethod def _deserialize(Buffer buffer not None): return Expression.wrap(GetResultValue(CDeserializeExpression( - deref(pyarrow_unwrap_buffer(buffer))))) + pyarrow_unwrap_buffer(buffer)))) def __reduce__(self): buffer = pyarrow_wrap_buffer(GetResultValue( diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 98e6c20bf23..73803c0ad36 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -53,7 +53,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef CResult[shared_ptr[CBuffer]] CSerializeExpression \ "arrow::dataset::Serialize"(const CExpression&) cdef CResult[CExpression] CDeserializeExpression \ - "arrow::dataset::Deserialize"(const CBuffer&) + "arrow::dataset::Deserialize"(shared_ptr[CBuffer]) cdef cppclass CRecordBatchProjector "arrow::dataset::RecordBatchProjector": pass From 12920761d06d3a61bb2b9ded29482351914bcd1a Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 6 Jan 2021 12:51:24 -0500 Subject: [PATCH 34/38] cleanup, FieldPath::operator bool -> empty --- .../arrow/compute/kernels/scalar_cast_test.cc | 1 - cpp/src/arrow/compute/kernels/util_internal.h | 9 +++++---- cpp/src/arrow/dataset/expression.cc | 12 ++++++------ cpp/src/arrow/dataset/file_csv.cc | 6 +++--- cpp/src/arrow/dataset/file_parquet.cc | 16 +++++++--------- cpp/src/arrow/dataset/partition.cc | 15 +++++++-------- cpp/src/arrow/type.h | 9 ++++----- 7 files changed, 32 insertions(+), 36 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 67727f18dc7..350728793e6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1798,7 +1798,6 @@ TYPED_TEST(TestDictionaryCast, Basic) { auto dict_arr = 
*DictionaryArray::FromArrays(dict_ty, indices, dict); std::shared_ptr expected = *Take(*dict, *indices); - // TODO: Should casting dictionary scalars work? this->CheckPass(*dict_arr, *expected, expected->type(), CastOptions::Safe(), /*check_scalar=*/false); diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h index 4aad3804366..aece5a97599 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.h +++ b/cpp/src/arrow/compute/kernels/util_internal.h @@ -18,7 +18,6 @@ #pragma once #include -#include #include "arrow/array/util.h" #include "arrow/buffer.h" @@ -55,9 +54,11 @@ int GetBitWidth(const DataType& type); PrimitiveArg GetPrimitiveArg(const ArrayData& arr); // Augment a unary ArrayKernelExec which supports only array-like inputs with support for -// scalar inputs. Scalars will be transformed to 1-long arrays which are passed to the -// original exec. This could be far more efficient, but instead of optimizing this it'd be -// better to support scalar inputs "upstream" in original exec. +// scalar inputs. Scalars will be transformed to 1-long arrays with the scalar's value (or +// null if the scalar is null) as its only element. This 1-long array will be passed to +// the original exec, then the only element of the resulting array will be extracted as +// the output scalar. This could be far more efficient, but instead of optimizing this +// it'd be better to support scalar inputs "upstream" in original exec. ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec); } // namespace internal diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 11d6a0e7c8f..21136b93ef4 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -513,22 +513,22 @@ struct FieldPathGetDatumImpl { inline Result GetDatumField(const FieldRef& ref, const Datum& input) { Datum field; - FieldPath path; + FieldPath match; if (auto type = input.type()) { - ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*type)); + ARROW_ASSIGN_OR_RAISE(match, ref.FindOneOrNone(*type)); } else if (auto schema = input.schema()) { - ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*schema)); + ARROW_ASSIGN_OR_RAISE(match, ref.FindOneOrNone(*schema)); } else { return Status::NotImplemented("retrieving fields from datum ", input.ToString()); } - if (path) { + if (!match.empty()) { ARROW_ASSIGN_OR_RAISE(field, - util::visit(FieldPathGetDatumImpl{input, path}, input.value)); + util::visit(FieldPathGetDatumImpl{input, match}, input.value)); } if (field == Datum{}) { - field = Datum(std::make_shared()); + return Datum(std::make_shared()); } return field; diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index 87c26f1cc79..47ed472f235 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -95,9 +95,9 @@ static inline Result GetConvertOptions( for (const FieldRef& ref : FieldsInExpression(scan_options->filter)) { DCHECK(ref.name()); ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*scan_options->schema())); - if (!match) { - convert_options.include_columns.push_back(*ref.name()); - } + if (match.empty()) continue; + + convert_options.include_columns.push_back(*ref.name()); } return convert_options; diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index eee175228f2..0d49cd72135 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -297,8 +297,7 @@ Result 
   if (parquet_fragment->metadata() != nullptr) {
-    ARROW_ASSIGN_OR_RAISE(row_groups,
-                          parquet_fragment->FilterRowGroups(options->filter));
+    ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups(options->filter));
     pre_filtered = true;
     if (row_groups.empty()) MakeEmpty();
@@ -313,8 +312,7 @@ Result<ScanTaskIterator> ParquetFileFormat::ScanFile(std::shared_ptr
-    ARROW_ASSIGN_OR_RAISE(row_groups,
-                          parquet_fragment->FilterRowGroups(options->filter));
+    ARROW_ASSIGN_OR_RAISE(row_groups, parquet_fragment->FilterRowGroups(options->filter));
     if (row_groups.empty()) MakeEmpty();
   }
@@ -511,13 +509,13 @@ Result<std::vector<int>> ParquetFileFragment::FilterRowGroups(Expression predica
   }

   for (const FieldRef& ref : FieldsInExpression(predicate)) {
-    ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOneOrNone(*physical_schema_));
+    ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*physical_schema_));

-    if (!path) continue;
-    if (statistics_expressions_complete_[path[0]]) continue;
-    statistics_expressions_complete_[path[0]] = true;
+    if (match.empty()) continue;
+    if (statistics_expressions_complete_[match[0]]) continue;
+    statistics_expressions_complete_[match[0]] = true;

-    const SchemaField& schema_field = manifest_->schema_fields[path[0]];
+    const SchemaField& schema_field = manifest_->schema_fields[match[0]];
     int i = 0;
     for (int row_group : *row_groups_) {
       auto row_group_metadata = metadata_->RowGroup(row_group);
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index 2822c4a15b6..3a164d8d795 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -80,7 +80,7 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr,
     ARROW_ASSIGN_OR_RAISE(auto match,
                           ref_value.first.FindOneOrNone(*projector->schema()));
-    if (!match) continue;
+    if (match.empty()) continue;
     RETURN_NOT_OK(projector->SetDefaultValue(match, ref_value.second.scalar()));
   }
   return Status::OK();
@@ -105,12 +105,11 @@ Result KeyValuePartitioning::Partition(
   for (const auto& partition_field : schema_->fields()) {
     ARROW_ASSIGN_OR_RAISE(
         auto match, FieldRef(partition_field->name()).FindOneOrNone(*rest->schema()))
+    if (match.empty()) continue;

-    if (match) {
-      by_fields.push_back(partition_field);
-      by_columns.push_back(rest->column(match[0]));
-      ARROW_ASSIGN_OR_RAISE(rest, rest->RemoveColumn(match[0]));
-    }
+    by_fields.push_back(partition_field);
+    by_columns.push_back(rest->column(match[0]));
+    ARROW_ASSIGN_OR_RAISE(rest, rest->RemoveColumn(match[0]));
   }

   if (by_fields.empty()) {
@@ -138,7 +137,7 @@ Result<Expression> KeyValuePartitioning::ConvertKey(const Key& key) const {
   ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_));
-  if (!match) {
+  if (match.empty()) {
     return literal(true);
   }
@@ -201,7 +200,7 @@ Result<std::string> KeyValuePartitioning::Format(const Expression& expr) const {
     }

     ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_));
-    if (!match) continue;
+    if (match.empty()) continue;

     const auto& value = ref_value.second.scalar();
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 3c52c19f191..a1071c4c86e 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1404,8 +1404,7 @@ class ARROW_EXPORT FieldPath {
     size_t operator()(const FieldPath& path) const { return path.hash(); }
   };

-  explicit operator bool() const { return !indices_.empty(); }
-  bool operator!() const { return indices_.empty(); }
+  bool empty() const { return indices_.empty(); }

   bool operator==(const FieldPath& other) const { return indices() == other.indices(); }
   bool operator!=(const FieldPath& other) const { return indices() != other.indices(); }

@@ -1618,10 +1617,10 @@ class ARROW_EXPORT FieldRef {
   template <typename T>
   Result<GetType<T>> GetOneOrNone(const T& root) const {
     ARROW_ASSIGN_OR_RAISE(auto match, FindOneOrNone(root));
-    if (match) {
-      return match.Get(root).ValueOrDie();
+    if (match.empty()) {
+      return static_cast<GetType<T>>(NULLPTR);
     }
-    return GetType<T>(NULLPTR);
+    return match.Get(root).ValueOrDie();
   }

  private:
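Note on the API change above: FieldRef::FindOneOrNone now reports "no match" through an empty FieldPath rather than through the deleted operator bool, so callers write match.empty(). A minimal sketch of the resulting idiom, assuming a RecordBatch input (the helper name is illustrative and not part of the patch):

    #include "arrow/record_batch.h"
    #include "arrow/result.h"
    #include "arrow/type.h"

    namespace arrow {

    // Resolve `ref` against `batch`: an empty FieldPath means no field matched,
    // while multiple matches already surface as an error from FindOneOrNone.
    Result<std::shared_ptr<Array>> GetColumnOrNull(const FieldRef& ref,
                                                   const RecordBatch& batch) {
      ARROW_ASSIGN_OR_RAISE(FieldPath match, ref.FindOneOrNone(*batch.schema()));
      if (match.empty()) return std::shared_ptr<Array>{};
      return match.Get(batch);
    }

    }  // namespace arrow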
From a072952b4260b82a8f6ad2d9654edf4a9eb521de Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 6 Jan 2021 13:10:20 -0500
Subject: [PATCH 35/38] move project to scalar_nested.cc, add test for
 different chunking

---
 cpp/src/arrow/compute/api_scalar.h            |   7 +
 cpp/src/arrow/compute/cast.cc                 |  78 -----------
 cpp/src/arrow/compute/cast.h                  |  12 --
 cpp/src/arrow/compute/kernels/CMakeLists.txt  |   1 -
 .../arrow/compute/kernels/scalar_nested.cc    |  79 ++++++++++++
 .../compute/kernels/scalar_nested_test.cc     | 114 ++++++++++++++++
 .../compute/kernels/scalar_project_test.cc    | 122 ------------------
 cpp/src/arrow/dataset/filter.cc               |  20 ---
 cpp/src/arrow/dataset/filter.h                |  18 ---
 cpp/src/arrow/dataset/filter_test.cc          |  18 ---
 10 files changed, 200 insertions(+), 269 deletions(-)
 delete mode 100644 cpp/src/arrow/compute/kernels/scalar_project_test.cc
 delete mode 100644 cpp/src/arrow/dataset/filter.cc
 delete mode 100644 cpp/src/arrow/dataset/filter.h
 delete mode 100644 cpp/src/arrow/dataset/filter_test.cc

diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index 85c1587327f..9d3d0cb745d 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -107,6 +107,13 @@ struct CompareOptions : public FunctionOptions {
   enum CompareOperator op;
 };

+struct ARROW_EXPORT ProjectOptions : public FunctionOptions {
+  explicit ProjectOptions(std::vector<std::string> n) : field_names(std::move(n)) {}
+
+  /// Names for wrapped columns
+  std::vector<std::string> field_names;
+};
+
 /// @}

 /// \brief Add two values together. Array values must be the same length. If
diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
index ecdbdfd9d8f..fd705ff973b 100644
--- a/cpp/src/arrow/compute/cast.cc
+++ b/cpp/src/arrow/compute/cast.cc
@@ -118,86 +118,8 @@ class CastMetaFunction : public MetaFunction {

 }  // namespace

-const FunctionDoc project_doc{"Wrap Arrays into a StructArray",
-                              ("Names of the StructArray's fields are\n"
-                               "specified through ProjectOptions."),
-                              {"*args"},
-                              "ProjectOptions"};
-
-Result<ValueDescr> ProjectResolve(KernelContext* ctx,
-                                  const std::vector<ValueDescr>& descrs) {
-  const auto& names = OptionsWrapper<ProjectOptions>::Get(ctx).field_names;
-  if (names.size() != descrs.size()) {
-    return Status::Invalid("project() was passed ", names.size(), " field ", "names but ",
-                           descrs.size(), " arguments");
-  }
-
-  size_t i = 0;
-  FieldVector fields(descrs.size());
-
-  ValueDescr::Shape shape = ValueDescr::SCALAR;
-  for (const ValueDescr& descr : descrs) {
-    if (descr.shape != ValueDescr::SCALAR) {
-      shape = ValueDescr::ARRAY;
-    } else {
-      switch (descr.type->id()) {
-        case Type::EXTENSION:
-        case Type::DENSE_UNION:
-        case Type::SPARSE_UNION:
-          return Status::NotImplemented("Broadcasting scalars of type ", *descr.type);
-        default:
-          break;
-      }
-    }
-
-    fields[i] = field(names[i], descr.type);
-    ++i;
-  }
-
-  return ValueDescr{struct_(std::move(fields)), shape};
-}
-
-void ProjectExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
-  KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, ProjectResolve(ctx, batch.GetDescriptors()));
-
-  if (descr.shape == ValueDescr::SCALAR) {
-    ScalarVector scalars(batch.num_values());
-    for (int i = 0; i < batch.num_values(); ++i) {
-      scalars[i] = batch[i].scalar();
-    }
-
-    *out =
-        Datum(std::make_shared<StructScalar>(std::move(scalars), std::move(descr.type)));
-    return;
-  }
-
-  ArrayVector arrays(batch.num_values());
-  for (int i = 0; i < batch.num_values(); ++i) {
-    if (batch[i].is_array()) {
-      arrays[i] = batch[i].make_array();
-      continue;
-    }
-
-    KERNEL_ASSIGN_OR_RAISE(
-        arrays[i], ctx,
-        MakeArrayFromScalar(*batch[i].scalar(), batch.length, ctx->memory_pool()));
-  }
-
-  *out = std::make_shared<StructArray>(descr.type, batch.length, std::move(arrays));
-}
-
 void RegisterScalarCast(FunctionRegistry* registry) {
   DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>()));
-
-  auto project_function =
-      std::make_shared<ScalarFunction>("project", Arity::VarArgs(), &project_doc);
-  ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{ProjectResolve},
-                                            /*is_varargs=*/true),
-                      ProjectExec, OptionsWrapper<ProjectOptions>::Init};
-  kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
-  kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
-  DCHECK_OK(project_function->AddKernel(std::move(kernel)));
-  DCHECK_OK(registry->AddFunction(std::move(project_function)));
 }

 }  // namespace internal
diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h
index 2454491e9e2..0b9d9caf882 100644
--- a/cpp/src/arrow/compute/cast.h
+++ b/cpp/src/arrow/compute/cast.h
@@ -157,17 +157,5 @@ Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type,
                    const CastOptions& options = CastOptions::Safe(),
                    ExecContext* ctx = NULLPTR);

-/// \addtogroup compute-concrete-options
-/// @{
-
-struct ARROW_EXPORT ProjectOptions : public FunctionOptions {
-  explicit ProjectOptions(std::vector<std::string> n) : field_names(std::move(n)) {}
-
-  /// Names for wrapped columns
-  std::vector<std::string> field_names;
-};
-
-/// @}
-
 }  // namespace compute
 }  // namespace arrow
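With the hunks above, ProjectOptions moves next to the other scalar function options in api_scalar.h and the kernel itself is re-registered from scalar_nested.cc (next diff); call sites are unchanged. A sketch of invoking the function, with illustrative field names (the wrapper below is not part of the patch):

    #include "arrow/compute/api.h"

    // Wrap two columns into a single struct column {"id", "score"}.
    arrow::Result<arrow::Datum> MakeStructColumn(const arrow::Datum& id,
                                                 const arrow::Datum& score) {
      arrow::compute::ProjectOptions options({"id", "score"});
      return arrow::compute::CallFunction("project", {id, score}, &options);
    }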
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 6043ae62f4d..577b250da87 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -25,7 +25,6 @@ add_arrow_compute_test(scalar_test
                        scalar_cast_test.cc
                        scalar_compare_test.cc
                        scalar_nested_test.cc
-                       scalar_project_test.cc
                        scalar_set_lookup_test.cc
                        scalar_string_test.cc
                        scalar_validity_test.cc
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc
index a0b35738c70..6e9803caf9f 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc
@@ -18,6 +18,7 @@
 // Vector kernels involving nested types

 #include "arrow/array/array_base.h"
+#include "arrow/compute/api_scalar.h"
 #include "arrow/compute/kernels/common.h"
 #include "arrow/result.h"
 #include "arrow/util/bit_block_counter.h"
@@ -59,6 +60,74 @@ const FunctionDoc list_value_length_doc{
     ("Null values emit a null in the output."),
     {"lists"}};

+Result<ValueDescr> ProjectResolve(KernelContext* ctx,
+                                  const std::vector<ValueDescr>& descrs) {
+  const auto& names = OptionsWrapper<ProjectOptions>::Get(ctx).field_names;
+  if (names.size() != descrs.size()) {
+    return Status::Invalid("project() was passed ", names.size(), " field ", "names but ",
+                           descrs.size(), " arguments");
+  }
+
+  size_t i = 0;
+  FieldVector fields(descrs.size());
+
+  ValueDescr::Shape shape = ValueDescr::SCALAR;
+  for (const ValueDescr& descr : descrs) {
+    if (descr.shape != ValueDescr::SCALAR) {
+      shape = ValueDescr::ARRAY;
+    } else {
+      switch (descr.type->id()) {
+        case Type::EXTENSION:
+        case Type::DENSE_UNION:
+        case Type::SPARSE_UNION:
+          return Status::NotImplemented("Broadcasting scalars of type ", *descr.type);
+        default:
+          break;
+      }
+    }
+
+    fields[i] = field(names[i], descr.type);
+    ++i;
+  }
+
+  return ValueDescr{struct_(std::move(fields)), shape};
+}
+
+void ProjectExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, ProjectResolve(ctx, batch.GetDescriptors()));
+
+  if (descr.shape == ValueDescr::SCALAR) {
+    ScalarVector scalars(batch.num_values());
+    for (int i = 0; i < batch.num_values(); ++i) {
+      scalars[i] = batch[i].scalar();
+    }
+
+    *out =
+        Datum(std::make_shared<StructScalar>(std::move(scalars), std::move(descr.type)));
+    return;
+  }
+
+  ArrayVector arrays(batch.num_values());
+  for (int i = 0; i < batch.num_values(); ++i) {
+    if (batch[i].is_array()) {
+      arrays[i] = batch[i].make_array();
+      continue;
+    }
+
+    KERNEL_ASSIGN_OR_RAISE(
+        arrays[i], ctx,
+        MakeArrayFromScalar(*batch[i].scalar(), batch.length, ctx->memory_pool()));
+  }
+
+  *out = std::make_shared<StructArray>(descr.type, batch.length, std::move(arrays));
+}
+
+const FunctionDoc project_doc{"Wrap Arrays into a StructArray",
+                              ("Names of the StructArray's fields are\n"
+                               "specified through ProjectOptions."),
+                              {"*args"},
+                              "ProjectOptions"};
+
 }  // namespace

 void RegisterScalarNested(FunctionRegistry* registry) {
@@ -69,6 +138,16 @@ void RegisterScalarNested(FunctionRegistry* registry) {
   DCHECK_OK(list_value_length->AddKernel({InputType(Type::LARGE_LIST)}, int64(),
                                          ListValueLength));
   DCHECK_OK(registry->AddFunction(std::move(list_value_length)));
+
+  auto project_function =
+      std::make_shared<ScalarFunction>("project", Arity::VarArgs(), &project_doc);
+  ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{ProjectResolve},
+                                            /*is_varargs=*/true),
+                      ProjectExec, OptionsWrapper<ProjectOptions>::Init};
+  kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
+  kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+  DCHECK_OK(project_function->AddKernel(std::move(kernel)));
+  DCHECK_OK(registry->AddFunction(std::move(project_function)));
 }

 }  // namespace internal
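ProjectExec above encodes the broadcast rule: the output is scalar only when every argument is scalar; otherwise scalar arguments are expanded to the batch length with MakeArrayFromScalar. A usage sketch under that rule (the helper and names are illustrative):

    #include "arrow/api.h"
    #include "arrow/compute/api.h"

    // Attach a constant "tag" column to an array; the scalar is repeated to
    // the array's length by the project kernel's broadcast rule.
    arrow::Result<arrow::Datum> TagRows(const std::shared_ptr<arrow::Array>& values) {
      arrow::compute::ProjectOptions options({"value", "tag"});
      return arrow::compute::CallFunction(
          "project", {values, arrow::MakeScalar("tagged")}, &options);
    }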
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
index 24776913ee0..14363f5d0d1 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
@@ -17,6 +17,7 @@

 #include

+#include "arrow/chunked_array.h"
 #include "arrow/compute/api.h"
 #include "arrow/compute/kernels/test_util.h"
 #include "arrow/result.h"
@@ -36,5 +37,118 @@ TEST(TestScalarNested, ListValueLength) {
   }
 }

+struct {
+ public:
+  Result<Datum> operator()(std::vector<Datum> args) {
+    ProjectOptions opts{field_names};
+    return CallFunction("project", args, &opts);
+  }
+
+  std::vector<std::string> field_names;
+} Project;
+
+TEST(Project, Scalar) {
+  std::shared_ptr<Scalar> expected(new StructScalar{{}, struct_({})});
+  ASSERT_OK_AND_EQ(Datum(expected), Project({}));
+
+  auto i32 = MakeScalar(1);
+  auto f64 = MakeScalar(2.5);
+  auto str = MakeScalar("yo");
+
+  expected.reset(new StructScalar{
+      {i32, f64, str},
+      struct_({field("i", i32->type), field("f", f64->type), field("s", str->type)})});
+  Project.field_names = {"i", "f", "s"};
+  ASSERT_OK_AND_EQ(Datum(expected), Project({i32, f64, str}));
+
+  // Three field names but one input value
+  ASSERT_RAISES(Invalid, Project({str}));
+}
+
+TEST(Project, Array) {
+  Project.field_names = {"i", "s"};
+  auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]");
+  auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
+  ASSERT_OK_AND_ASSIGN(Datum expected,
+                       StructArray::Make({i32, str}, Project.field_names));
+
+  ASSERT_OK_AND_EQ(expected, Project({i32, str}));
+
+  // Scalars are broadcast to the length of the arrays
+  ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}));
+
+  // Array length mismatch
+  ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}));
+}
+
+TEST(Project, ChunkedArray) {
+  Project.field_names = {"i", "s"};
+
+  auto i32_0 = ArrayFromJSON(int32(), "[42, 13, 7]");
+  auto i32_1 = ArrayFromJSON(int32(), "[]");
+  auto i32_2 = ArrayFromJSON(int32(), "[32, 0]");
+
+  auto str_0 = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
+  auto str_1 = ArrayFromJSON(utf8(), "[]");
+  auto str_2 = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
+
+  ASSERT_OK_AND_ASSIGN(auto i32, ChunkedArray::Make({i32_0, i32_1, i32_2}));
+  ASSERT_OK_AND_ASSIGN(auto str, ChunkedArray::Make({str_0, str_1, str_2}));
+
+  ASSERT_OK_AND_ASSIGN(auto expected_0,
+                       StructArray::Make({i32_0, str_0}, Project.field_names));
+  ASSERT_OK_AND_ASSIGN(auto expected_1,
+                       StructArray::Make({i32_1, str_1}, Project.field_names));
+  ASSERT_OK_AND_ASSIGN(auto expected_2,
+                       StructArray::Make({i32_2, str_2}, Project.field_names));
+  ASSERT_OK_AND_ASSIGN(Datum expected,
+                       ChunkedArray::Make({expected_0, expected_1, expected_2}));
+
+  ASSERT_OK_AND_EQ(expected, Project({i32, str}));
+
+  // Scalars are broadcast to the length of the arrays
+  ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}));
+
+  // Array length mismatch
+  ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}));
+}
+
+TEST(Project, ChunkedArrayDifferentChunking) {
+  Project.field_names = {"i", "s"};
+
+  auto i32_0 = ArrayFromJSON(int32(), "[42, 13, 7]");
+  auto i32_1 = ArrayFromJSON(int32(), "[]");
+  auto i32_2 = ArrayFromJSON(int32(), "[32, 0]");
+
+  auto str_0 = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
+  auto str_1 = ArrayFromJSON(utf8(), R"(["aa"])");
+  auto str_2 = ArrayFromJSON(utf8(), R"([])");
+  auto str_3 = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
+
+  ASSERT_OK_AND_ASSIGN(auto i32, ChunkedArray::Make({i32_0, i32_1, i32_2}));
+  ASSERT_OK_AND_ASSIGN(auto str, ChunkedArray::Make({str_0, str_1, str_2, str_3}));
+
+  std::vector<ArrayVector> expected_rechunked =
+      ::arrow::internal::RechunkArraysConsistently({i32->chunks(), str->chunks()});
+  ASSERT_EQ(expected_rechunked[0].size(), expected_rechunked[1].size());
+
+  ArrayVector expected_chunks(expected_rechunked[0].size());
+  for (size_t i = 0; i < expected_chunks.size(); ++i) {
+    ASSERT_OK_AND_ASSIGN(expected_chunks[i], StructArray::Make({expected_rechunked[0][i],
+                                                                expected_rechunked[1][i]},
+                                                               Project.field_names));
+  }
+
+  ASSERT_OK_AND_ASSIGN(Datum expected, ChunkedArray::Make(expected_chunks));
+
+  ASSERT_OK_AND_EQ(expected, Project({i32, str}));
+
+  // Scalars are broadcast to the length of the arrays
+  ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}));
+
+  // Array length mismatch
+  ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}));
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_project_test.cc b/cpp/src/arrow/compute/kernels/scalar_project_test.cc
deleted file mode 100644
index 1cccfb9c540..00000000000
--- a/cpp/src/arrow/compute/kernels/scalar_project_test.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
-
-#include "arrow/array.h"
-#include "arrow/buffer.h"
-#include "arrow/chunked_array.h"
-#include "arrow/status.h"
-#include "arrow/type.h"
-#include "arrow/type_fwd.h"
-#include "arrow/type_traits.h"
-#include "arrow/util/decimal.h"
-
-#include "arrow/compute/api_vector.h"
-#include "arrow/compute/cast.h"
-#include "arrow/compute/kernel.h"
-#include "arrow/compute/kernels/test_util.h"
-
-namespace arrow {
-namespace compute {
-
-struct {
- public:
-  Result<Datum> operator()(std::vector<Datum> args) {
-    ProjectOptions opts{field_names};
-    return CallFunction("project", args, &opts);
-  }
-
-  std::vector<std::string> field_names;
-} Project;
-
-TEST(Project, Scalar) {
-  std::shared_ptr<Scalar> expected(new StructScalar{{}, struct_({})});
-  ASSERT_OK_AND_EQ(Datum(expected), Project({}));
-
-  auto i32 = MakeScalar(1);
-  auto f64 = MakeScalar(2.5);
-  auto str = MakeScalar("yo");
-
-  expected.reset(new StructScalar{
-      {i32, f64, str},
-      struct_({field("i", i32->type), field("f", f64->type), field("s", str->type)})});
-  Project.field_names = {"i", "f", "s"};
-  ASSERT_OK_AND_EQ(Datum(expected), Project({i32, f64, str}));
-
-  // Three field names but one input value
-  ASSERT_RAISES(Invalid, Project({str}));
-}
-
-TEST(Project, Array) {
-  Project.field_names = {"i", "s"};
-  auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]");
-  auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
-  ASSERT_OK_AND_ASSIGN(Datum expected,
-                       StructArray::Make({i32, str}, Project.field_names));
-
-  ASSERT_OK_AND_EQ(expected, Project({i32, str}));
-
-  // Scalars are broadcast to the length of the arrays
-  ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}));
-
-  // Array length mismatch
-  ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}));
-}
-
-TEST(Project, ChunkedArray) {
-  Project.field_names = {"i", "s"};
-
-  auto i32_0 = ArrayFromJSON(int32(), "[42, 13, 7]");
-  auto i32_1 = ArrayFromJSON(int32(), "[]");
-  auto i32_2 = ArrayFromJSON(int32(), "[32, 0]");
-
-  auto str_0 = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])");
-  auto str_1 = ArrayFromJSON(utf8(), "[]");
-  auto str_2 = ArrayFromJSON(utf8(), R"(["aa", "aa"])");
-
-  ASSERT_OK_AND_ASSIGN(auto i32, ChunkedArray::Make({i32_0, i32_1, i32_2}));
-  ASSERT_OK_AND_ASSIGN(auto str, ChunkedArray::Make({str_0, str_1, str_2}));
-
-  ASSERT_OK_AND_ASSIGN(auto expected_0,
-                       StructArray::Make({i32_0, str_0}, Project.field_names));
-  ASSERT_OK_AND_ASSIGN(auto expected_1,
-                       StructArray::Make({i32_1, str_1}, Project.field_names));
-  ASSERT_OK_AND_ASSIGN(auto expected_2,
-                       StructArray::Make({i32_2, str_2}, Project.field_names));
-  ASSERT_OK_AND_ASSIGN(Datum expected,
-                       ChunkedArray::Make({expected_0, expected_1, expected_2}));
-
-  ASSERT_OK_AND_EQ(expected, Project({i32, str}));
-
-  // Scalars are broadcast to the length of the arrays
-  ASSERT_OK_AND_EQ(expected, Project({i32, MakeScalar("aa")}));
-
-  // Array length mismatch
-  ASSERT_RAISES(Invalid, Project({i32->Slice(1), str}));
-}
-
-}  // namespace compute
-}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
deleted file mode 100644
index 2357896dd7a..00000000000
--- a/cpp/src/arrow/dataset/filter.cc
+++ /dev/null
@@ -1,20 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include "arrow/dataset/filter.h"
-
-// FIXME remove
diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h
deleted file mode 100644
index 9852f8c0808..00000000000
--- a/cpp/src/arrow/dataset/filter.h
+++ /dev/null
@@ -1,18 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-// FIXME remove
diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc
deleted file mode 100644
index ef32b8a7bb6..00000000000
--- a/cpp/src/arrow/dataset/filter_test.cc
+++ /dev/null
@@ -1,18 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-// FIXME delete this

From 9ce0adb6551e97551b8861ffa5a6e20f6d93f9db Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 6 Jan 2021 13:15:22 -0500
Subject: [PATCH 36/38] remove AddSimpleArrayOnlyCast

---
 .../compute/kernels/scalar_cast_internal.h    | 11 +++-------
 .../compute/kernels/scalar_cast_temporal.cc   | 20 +++++++++----------
 2 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h
index 15769ce5f8f..dabf0c2b061 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h
@@ -55,16 +55,11 @@ void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out);

 void CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out);

-// Adds a cast function where the functor is defined and the input and output
-// types have a type_singleton
+// Adds a cast function where CastFunctor is specialized and the input and output
+// types are parameter free (have a type_singleton). Scalar inputs are handled by
+// wrapping with TrivialScalarUnaryAsArraysExec.
 template <typename InType, typename OutType>
 void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) {
-  DCHECK_OK(func->AddKernel(InType::type_id, {in_ty}, out_ty,
-                            CastFunctor<OutType, InType>::Exec));
-}
-
-template <typename InType, typename OutType>
-void AddSimpleArrayOnlyCast(InputType in_ty, OutputType out_ty, CastFunction* func) {
   DCHECK_OK(func->AddKernel(
       InType::type_id, {in_ty}, out_ty,
       TrivialScalarUnaryAsArraysExec(CastFunctor<OutType, InType>::Exec)));
 }
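AddSimpleCast now always wraps the functor in TrivialScalarUnaryAsArraysExec, so CastFunctor implementations only ever see arrays. Roughly, the wrapper behaves like the following simplified sketch (an assumption based on the comment in util_internal.h; the real implementation in util_internal.cc may differ in details):

    // Simplified sketch: wrap a scalar input as a length-1 array, run the
    // array-only exec, then extract element 0 as the output scalar.
    ArrayKernelExec TrivialScalarUnaryAsArraysExecSketch(ArrayKernelExec exec) {
      return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
        if (!batch[0].is_scalar()) {
          exec(ctx, batch, out);  // array input: pass through unchanged
          return;
        }
        KERNEL_ASSIGN_OR_RAISE(
            std::shared_ptr<Array> arg, ctx,
            MakeArrayFromScalar(*batch[0].scalar(), 1, ctx->memory_pool()));
        Datum array_out;
        exec(ctx, ExecBatch({Datum(arg)}, 1), &array_out);
        if (ctx->HasError()) return;
        KERNEL_ASSIGN_OR_RAISE(*out, ctx, array_out.make_array()->GetScalar(0));
      };
    }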
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
index 99c1401d1b8..4450504241f 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
@@ -336,10 +336,10 @@ std::shared_ptr<CastFunction> GetDate32Cast() {
   AddZeroCopyCast(Type::INT32, int32(), date32(), func.get());

   // date64 -> date32
-  AddSimpleArrayOnlyCast<Date64Type, Date32Type>(date64(), date32(), func.get());
+  AddSimpleCast<Date64Type, Date32Type>(date64(), date32(), func.get());

   // timestamp -> date32
-  AddSimpleArrayOnlyCast<TimestampType, Date32Type>(InputType(Type::TIMESTAMP), date32(),
+  AddSimpleCast<TimestampType, Date32Type>(InputType(Type::TIMESTAMP), date32(),
                                             func.get());
   return func;
 }
@@ -353,10 +353,10 @@ std::shared_ptr<CastFunction> GetDate64Cast() {
   AddZeroCopyCast(Type::INT64, int64(), date64(), func.get());

   // date32 -> date64
-  AddSimpleArrayOnlyCast<Date32Type, Date64Type>(date32(), date64(), func.get());
+  AddSimpleCast<Date32Type, Date64Type>(date32(), date64(), func.get());

   // timestamp -> date64
-  AddSimpleArrayOnlyCast<TimestampType, Date64Type>(InputType(Type::TIMESTAMP), date64(),
+  AddSimpleCast<TimestampType, Date64Type>(InputType(Type::TIMESTAMP), date64(),
                                             func.get());
   return func;
 }
@@ -387,7 +387,7 @@ std::shared_ptr<CastFunction> GetTime32Cast() {
   AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get());

   // time64 -> time32
-  AddSimpleArrayOnlyCast<Time64Type, Time32Type>(InputType(Type::TIME64),
+  AddSimpleCast<Time64Type, Time32Type>(InputType(Type::TIME64),
                                          kOutputTargetType, func.get());

   // time32 -> time32
   AddCrossUnitCast<Time32Type>(func.get());
@@ -404,7 +404,7 @@ std::shared_ptr<CastFunction> GetTime64Cast() {
   AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get());

   // time32 -> time64
-  AddSimpleArrayOnlyCast<Time32Type, Time64Type>(InputType(Type::TIME32),
+  AddSimpleCast<Time32Type, Time64Type>(InputType(Type::TIME32),
                                          kOutputTargetType, func.get());

   // Between durations
   AddCrossUnitCast<Time64Type>(func.get());
@@ -422,16 +422,16 @@ std::shared_ptr<CastFunction> GetTimestampCast() {

   // From date types
   // TODO: ARROW-8876, these casts are not directly tested
-  AddSimpleArrayOnlyCast<Date32Type, TimestampType>(InputType(Type::DATE32),
+  AddSimpleCast<Date32Type, TimestampType>(InputType(Type::DATE32),
                                              kOutputTargetType, func.get());
-  AddSimpleArrayOnlyCast<Date64Type, TimestampType>(InputType(Type::DATE64),
+  AddSimpleCast<Date64Type, TimestampType>(InputType(Type::DATE64),
                                              kOutputTargetType, func.get());

   // string -> timestamp
-  AddSimpleArrayOnlyCast<StringType, TimestampType>(utf8(), kOutputTargetType,
+  AddSimpleCast<StringType, TimestampType>(utf8(), kOutputTargetType,
                                              func.get());

   // large_string -> timestamp
-  AddSimpleArrayOnlyCast<LargeStringType, TimestampType>(large_utf8(), kOutputTargetType,
+  AddSimpleCast<LargeStringType, TimestampType>(large_utf8(), kOutputTargetType,
                                              func.get());

   // From one timestamp to another
   AddCrossUnitCast<TimestampType>(func.get());

From b14a4560d9c8028af0e6d551777b44ea8d32390b Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 6 Jan 2021 13:33:30 -0500
Subject: [PATCH 37/38] clang-format

---
 .../compute/kernels/scalar_cast_temporal.cc   | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
index 4450504241f..e470f9f90de 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
@@ -340,7 +340,7 @@ std::shared_ptr<CastFunction> GetDate32Cast() {

   // timestamp -> date32
   AddSimpleCast<TimestampType, Date32Type>(InputType(Type::TIMESTAMP), date32(),
-                                            func.get());
+                                           func.get());
   return func;
 }
@@ -357,7 +357,7 @@ std::shared_ptr<CastFunction> GetDate64Cast() {

   // timestamp -> date64
   AddSimpleCast<TimestampType, Date64Type>(InputType(Type::TIMESTAMP), date64(),
-                                            func.get());
+                                           func.get());
   return func;
 }
@@ -387,8 +387,8 @@ std::shared_ptr<CastFunction> GetTime32Cast() {
   AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get());

   // time64 -> time32
-  AddSimpleCast<Time64Type, Time32Type>(InputType(Type::TIME64),
-                                         kOutputTargetType, func.get());
+  AddSimpleCast<Time64Type, Time32Type>(InputType(Type::TIME64), kOutputTargetType,
+                                        func.get());

   // time32 -> time32
   AddCrossUnitCast<Time32Type>(func.get());
@@ -404,8 +404,8 @@ std::shared_ptr<CastFunction> GetTime64Cast() {
   AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get());

   // time32 -> time64
-  AddSimpleCast<Time32Type, Time64Type>(InputType(Type::TIME32),
-                                         kOutputTargetType, func.get());
+  AddSimpleCast<Time32Type, Time64Type>(InputType(Type::TIME32), kOutputTargetType,
+                                        func.get());

   // Between durations
   AddCrossUnitCast<Time64Type>(func.get());
@@ -422,17 +422,16 @@ std::shared_ptr<CastFunction> GetTimestampCast() {

   // From date types
   // TODO: ARROW-8876, these casts are not directly tested
-  AddSimpleCast<Date32Type, TimestampType>(InputType(Type::DATE32),
-                                            kOutputTargetType, func.get());
-  AddSimpleCast<Date64Type, TimestampType>(InputType(Type::DATE64),
-                                            kOutputTargetType, func.get());
+  AddSimpleCast<Date32Type, TimestampType>(InputType(Type::DATE32), kOutputTargetType,
+                                           func.get());
+  AddSimpleCast<Date64Type, TimestampType>(InputType(Type::DATE64), kOutputTargetType,
+                                           func.get());

   // string -> timestamp
-  AddSimpleCast<StringType, TimestampType>(utf8(), kOutputTargetType,
-                                            func.get());
+  AddSimpleCast<StringType, TimestampType>(utf8(), kOutputTargetType, func.get());

   // large_string -> timestamp
   AddSimpleCast<LargeStringType, TimestampType>(large_utf8(), kOutputTargetType,
-                                                 func.get());
+                                                func.get());

   // From one timestamp to another
   AddCrossUnitCast<TimestampType>(func.get());
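The final patch below repairs a regression introduced by the match.empty() conversion in patch 34: the condition in GetConvertOptions was inverted, so a column referenced only by the filter (and therefore absent from the projected schema) was no longer requested from the CSV reader. The intended behavior, with hypothetical column names:

    // Projection selects only "name", but the filter references "year".
    csv::ConvertOptions convert_options = csv::ConvertOptions::Defaults();
    convert_options.include_columns = {"name"};  // from the projected schema
    // "year" has no match in the projected schema (match.empty()), so it must
    // also be included or the filter cannot be evaluated against the batch:
    convert_options.include_columns.push_back("year");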
From e8080da8ea075c196f669b9a3b3e1d41c49f40cc Mon Sep 17 00:00:00 2001
From: Benjamin Kietzman
Date: Wed, 6 Jan 2021 16:55:17 -0500
Subject: [PATCH 38/38] incorrect projection in CsvFileFormat

---
 cpp/src/arrow/dataset/file_csv.cc | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc
index 47ed472f235..534c4704cb9 100644
--- a/cpp/src/arrow/dataset/file_csv.cc
+++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -90,14 +90,16 @@ static inline Result<csv::ConvertOptions> GetConvertOptions(
   }

   // FIXME(bkietz) also acquire types of fields materialized but not projected.
-  // This requires that scan_options include the full dataset schema (not just
+  // (This will require that scan_options include the full dataset schema, not just
   // the projected schema).
   for (const FieldRef& ref : FieldsInExpression(scan_options->filter)) {
     DCHECK(ref.name());
     ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*scan_options->schema()));
-    if (match.empty()) continue;

-    convert_options.include_columns.push_back(*ref.name());
+    if (match.empty()) {
+      // a field was filtered but not in the projected schema; be sure it is included
+      convert_options.include_columns.push_back(*ref.name());
+    }
   }

   return convert_options;
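Taken together with the tests added in patch 35, the project kernel's contract is: one field name per argument, arrays must agree on length, scalars broadcast, and chunked inputs are rechunked consistently. A closing usage sketch in the same gtest style (illustrative only, not part of the patch series):

    TEST(Project, UsageSketch) {
      ProjectOptions options({"i", "s"});
      auto i32 = ArrayFromJSON(int32(), "[1, 2, 3]");
      ASSERT_OK_AND_ASSIGN(
          Datum out, CallFunction("project", {i32, MakeScalar("x")}, &options));
      // `out` is a length-3 StructArray {i: int32, s: utf8} with "x" broadcast
      // into every slot of field "s".
      ASSERT_TRUE(out.is_array());
    }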