diff --git a/cpp/examples/arrow/dataset-parquet-scan-example.cc b/cpp/examples/arrow/dataset-parquet-scan-example.cc index 3cdd298b8fd..197ca5aa4c6 100644 --- a/cpp/examples/arrow/dataset-parquet-scan-example.cc +++ b/cpp/examples/arrow/dataset-parquet-scan-example.cc @@ -18,9 +18,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -37,8 +37,6 @@ namespace fs = arrow::fs; namespace ds = arrow::dataset; -using ds::string_literals::operator"" _; - #define ABORT_ON_FAILURE(expr) \ do { \ arrow::Status status_ = (expr); \ @@ -62,7 +60,8 @@ struct Configuration { // Indicates the filter by which rows will be filtered. This optimization can // make use of partition information and/or file metadata if possible. - std::shared_ptr filter = ("total_amount"_ > 1000.0f).Copy(); + ds::Expression filter = + ds::greater(ds::field_ref("total_amount"), ds::literal(1000.0f)); ds::InspectOptions inspect_options{}; ds::FinishOptions finish_options{}; @@ -147,7 +146,7 @@ std::shared_ptr GetDatasetFromPath( std::shared_ptr GetScannerFromDataset(std::shared_ptr dataset, std::vector columns, - std::shared_ptr filter, + ds::Expression filter, bool use_threads) { auto scanner_builder = dataset->NewScan().ValueOrDie(); @@ -155,9 +154,7 @@ std::shared_ptr GetScannerFromDataset(std::shared_ptr ABORT_ON_FAILURE(scanner_builder->Project(columns)); } - if (filter != nullptr) { - ABORT_ON_FAILURE(scanner_builder->Filter(filter)); - } + ABORT_ON_FAILURE(scanner_builder->Filter(filter)); ABORT_ON_FAILURE(scanner_builder->UseThreads(use_threads)); diff --git a/cpp/src/arrow/array/array_struct_test.cc b/cpp/src/arrow/array/array_struct_test.cc index f54b43465e9..aef0076d0d3 100644 --- a/cpp/src/arrow/array/array_struct_test.cc +++ b/cpp/src/arrow/array/array_struct_test.cc @@ -24,6 +24,7 @@ #include "arrow/array.h" #include "arrow/array/builder_nested.h" +#include "arrow/chunked_array.h" #include "arrow/status.h" #include "arrow/testing/gtest_common.h" 
#include "arrow/testing/gtest_util.h" @@ -582,4 +583,39 @@ TEST_F(TestStructBuilder, TestSlice) { ASSERT_EQ(list_field->null_count(), 1); } +TEST(TestFieldRef, GetChildren) { + auto struct_array = ArrayFromJSON(struct_({field("a", float64())}), R"([ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ])"); + + ASSERT_OK_AND_ASSIGN(auto a, FieldRef("a").GetOne(*struct_array)); + auto expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); + AssertArraysEqual(*a, *expected_a); + + auto ToChunked = [struct_array](int64_t midpoint) { + return ChunkedArray( + ArrayVector{ + struct_array->Slice(0, midpoint), + struct_array->Slice(midpoint), + }, + struct_array->type()); + }; + AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); + + // more nested: + struct_array = + ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ + {"a": {"a": 6.125}}, + {"a": {"a": 0.0}}, + {"a": {"a": -1}} + ])"); + + ASSERT_OK_AND_ASSIGN(a, FieldRef("a", "a").GetOne(*struct_array)); + expected_a = ArrayFromJSON(float64(), "[6.125, 0.0, -1]"); + AssertArraysEqual(*a, *expected_a); + AssertChunkedEquivalent(ToChunked(1), ToChunked(2)); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index 70167954e5a..0d498931d42 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -297,8 +297,9 @@ class RepeatedArrayFactory { return out_; } - Status Visit(const DataType& type) { - return Status::NotImplemented("construction from scalar of type ", *scalar_.type); + Status Visit(const NullType& type) { + DCHECK(false); // already forwarded to MakeArrayOfNull + return Status::OK(); } Status Visit(const BooleanType&) { @@ -403,6 +404,18 @@ class RepeatedArrayFactory { return Status::OK(); } + Status Visit(const ExtensionType& type) { + return Status::NotImplemented("construction from scalar of type ", *scalar_.type); + } + + Status Visit(const DenseUnionType& type) { + return Status::NotImplemented("construction from scalar of 
type ", *scalar_.type); + } + + Status Visit(const SparseUnionType& type) { + return Status::NotImplemented("construction from scalar of type ", *scalar_.type); + } + template Status CreateOffsetsBuffer(OffsetType value_length, std::shared_ptr* out) { TypedBufferBuilder builder(pool_); diff --git a/cpp/src/arrow/chunked_array.h b/cpp/src/arrow/chunked_array.h index b1f8d6cd6f5..5c0dda91850 100644 --- a/cpp/src/arrow/chunked_array.h +++ b/cpp/src/arrow/chunked_array.h @@ -72,6 +72,9 @@ class ARROW_EXPORT ChunkedArray { /// data type. explicit ChunkedArray(ArrayVector chunks); + ChunkedArray(ChunkedArray&&) = default; + ChunkedArray& operator=(ChunkedArray&&) = default; + /// \brief Construct a chunked array from a single Array explicit ChunkedArray(std::shared_ptr chunk) : ChunkedArray(ArrayVector{std::move(chunk)}) {} diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 789ac909ccf..1c6c2ea95b0 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -70,7 +70,7 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { - explicit SetLookupOptions(Datum value_set, bool skip_nulls) + explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) : value_set(std::move(value_set)), skip_nulls(skip_nulls) {} /// The set of values to look up input values into. 
@@ -86,7 +86,7 @@ struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { struct ARROW_EXPORT StrptimeOptions : public FunctionOptions { explicit StrptimeOptions(std::string format, TimeUnit::type unit) - : format(format), unit(unit) {} + : format(std::move(format)), unit(unit) {} std::string format; TimeUnit::type unit; diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 29a80f73241..5c332aedf73 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -118,8 +118,86 @@ class CastMetaFunction : public MetaFunction { } // namespace +const FunctionDoc struct_doc{"Wrap Arrays into a StructArray", + ("Names of the StructArray's fields are\n" + "specified through StructOptions."), + {"*args"}, + "StructOptions"}; + +Result StructResolve(KernelContext* ctx, + const std::vector& descrs) { + const auto& names = OptionsWrapper::Get(ctx).field_names; + if (names.size() != descrs.size()) { + return Status::Invalid("Struct() was passed ", names.size(), " field ", "names but ", + descrs.size(), " arguments"); + } + + size_t i = 0; + FieldVector fields(descrs.size()); + + ValueDescr::Shape shape = ValueDescr::SCALAR; + for (const ValueDescr& descr : descrs) { + if (descr.shape != ValueDescr::SCALAR) { + shape = ValueDescr::ARRAY; + } else { + switch (descr.type->id()) { + case Type::EXTENSION: + case Type::DENSE_UNION: + case Type::SPARSE_UNION: + return Status::NotImplemented("Broadcasting scalars of type ", *descr.type); + default: + break; + } + } + + fields[i] = field(names[i], descr.type); + ++i; + } + + return ValueDescr{struct_(std::move(fields)), shape}; +} + +void StructExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, StructResolve(ctx, batch.GetDescriptors())); + + if (descr.shape == ValueDescr::SCALAR) { + ScalarVector scalars(batch.num_values()); + for (int i = 0; i < batch.num_values(); ++i) { + scalars[i] = batch[i].scalar(); + } + + *out = + 
Datum(std::make_shared(std::move(scalars), std::move(descr.type))); + return; + } + + ArrayVector arrays(batch.num_values()); + for (int i = 0; i < batch.num_values(); ++i) { + if (batch[i].is_array()) { + arrays[i] = batch[i].make_array(); + continue; + } + + KERNEL_ASSIGN_OR_RAISE( + arrays[i], ctx, + MakeArrayFromScalar(*batch[i].scalar(), batch.length, ctx->memory_pool())); + } + + *out = std::make_shared(descr.type, batch.length, std::move(arrays)); +} + void RegisterScalarCast(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::make_shared())); + + auto struct_function = + std::make_shared("struct", Arity::VarArgs(), &struct_doc); + ScalarKernel kernel{KernelSignature::Make({InputType{}}, OutputType{StructResolve}, + /*is_varargs=*/true), + StructExec, OptionsWrapper::Init}; + kernel.null_handling = NullHandling::OUTPUT_NOT_NULL; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(struct_function->AddKernel(std::move(kernel))); + DCHECK_OK(registry->AddFunction(std::move(struct_function))); } } // namespace internal @@ -135,7 +213,7 @@ CastFunction::CastFunction(std::string name, Type::type out_type) impl_->out_type = out_type; } -CastFunction::~CastFunction() {} +CastFunction::~CastFunction() = default; Type::type CastFunction::out_type_id() const { return impl_->out_type; } diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 43392ce99bf..759b7c7665b 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -42,15 +42,7 @@ class ExecContext; /// @{ struct ARROW_EXPORT CastOptions : public FunctionOptions { - CastOptions() - : allow_int_overflow(false), - allow_time_truncate(false), - allow_time_overflow(false), - allow_decimal_truncate(false), - allow_float_truncate(false), - allow_invalid_utf8(false) {} - - explicit CastOptions(bool safe) + explicit CastOptions(bool safe = true) : allow_int_overflow(!safe), allow_time_truncate(!safe), allow_time_overflow(!safe), @@ -58,9 +50,17 
@@ struct ARROW_EXPORT CastOptions : public FunctionOptions { allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} - static CastOptions Safe() { return CastOptions(true); } + static CastOptions Safe(std::shared_ptr to_type = NULLPTR) { + CastOptions safe(true); + safe.to_type = std::move(to_type); + return safe; + } - static CastOptions Unsafe() { return CastOptions(false); } + static CastOptions Unsafe(std::shared_ptr to_type = NULLPTR) { + CastOptions unsafe(false); + unsafe.to_type = std::move(to_type); + return unsafe; + } // Type being casted to. May be passed separate to eager function // compute::Cast @@ -83,7 +83,7 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions { class CastFunction : public ScalarFunction { public: CastFunction(std::string name, Type::type out_type); - ~CastFunction(); + ~CastFunction() override; Type::type out_type_id() const; @@ -157,5 +157,17 @@ Result Cast(const Datum& value, std::shared_ptr to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); +/// \addtogroup compute-concrete-options +/// @{ + +struct ARROW_EXPORT StructOptions : public FunctionOptions { + explicit StructOptions(std::vector n) : field_names(std::move(n)) {} + + /// Names for wrapped columns + std::vector field_names; +}; + +/// @} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 790ae9b3997..fbd8229f0c8 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -648,7 +648,8 @@ class ScalarExecutor : public KernelExecutorImpl { // Decide if we need to preallocate memory for this kernel validity_preallocated_ = (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && - kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); + kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL && + output_descr_.type->id() != Type::NA); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { 
ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); } diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index 142e149bccc..f491489ed8a 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -165,7 +165,7 @@ class ARROW_EXPORT SelectionVector { /// than is desirable for this class. Microbenchmarks would help determine for /// sure. See ARROW-8928. struct ExecBatch { - ExecBatch() {} + ExecBatch() = default; ExecBatch(std::vector values, int64_t length) : values(std::move(values)), length(length) {} diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h index 8bad135e40d..a74e5c8d8fa 100644 --- a/cpp/src/arrow/compute/exec_internal.h +++ b/cpp/src/arrow/compute/exec_internal.h @@ -84,14 +84,14 @@ class ARROW_EXPORT ExecListener { class DatumAccumulator : public ExecListener { public: - DatumAccumulator() {} + DatumAccumulator() = default; Status OnResult(Datum value) override { values_.emplace_back(value); return Status::OK(); } - std::vector values() const { return values_; } + std::vector values() { return std::move(values_); } private: std::vector values_; diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 2d3e06e2fb2..28951b7dae1 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -150,10 +150,15 @@ Result Function::Execute(const std::vector& args, Status Function::Validate() const { if (!doc_->summary.empty()) { // Documentation given, check its contents - if (static_cast(doc_->arg_names.size()) != arity_.num_args) { - return Status::Invalid("In function '", name_, - "': ", "number of argument names != function arity"); + int arg_count = static_cast(doc_->arg_names.size()); + if (arg_count == arity_.num_args) { + return Status::OK(); } + if (arity_.is_varargs && arg_count == arity_.num_args + 1) { + return Status::OK(); + } + return Status::Invalid("In function '", name_, + "': ", "number of 
argument names != function arity"); } return Status::OK(); } diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index a71dbe40292..e8e732027c9 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -41,7 +41,9 @@ namespace compute { /// \brief Base class for specifying options configuring a function's behavior, /// such as error handling. -struct ARROW_EXPORT FunctionOptions {}; +struct ARROW_EXPORT FunctionOptions { + virtual ~FunctionOptions() = default; +}; /// \brief Contains the number of required arguments for the function. /// @@ -96,7 +98,7 @@ struct ARROW_EXPORT FunctionDoc { /// \brief Name of the options class, if any. std::string options_class; - FunctionDoc() {} + FunctionDoc() = default; FunctionDoc(std::string summary, std::string description, std::vector arg_names, std::string options_class = "") diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index cd33de672a6..f8dde20e3aa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -149,6 +149,12 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Dat // ---------------------------------------------------------------------- void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (out->is_scalar()) { + KERNEL_ASSIGN_OR_RAISE(*out, ctx, + batch[0].scalar_as().GetEncodedValue()); + return; + } + DictionaryArray dict_arr(batch[0].array()); const CastOptions& options = checked_cast(*ctx->state()).options; @@ -160,16 +166,16 @@ void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { return; } - Result result = Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), - /*options=*/TakeOptions::Defaults(), ctx->exec_context()); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - *out = *result; + 
KERNEL_ASSIGN_OR_RAISE(*out, ctx, + Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), + TakeOptions::Defaults(), ctx->exec_context())); } void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (out->is_scalar()) { + out->scalar()->is_valid = false; + return; + } ArrayData* output = out->mutable_array(); output->buffers = {nullptr}; output->null_count = batch.length; @@ -191,6 +197,8 @@ void CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out) { } void CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].is_scalar()) return; + ArrayData* output = out->mutable_array(); std::shared_ptr nulls; Status s = MakeArrayOfNull(output->type, batch.length).Value(&nulls); @@ -251,7 +259,7 @@ static bool CanCastFromDictionary(Type::type type_id) { void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* func) { // From null to this type - DCHECK_OK(func->AddKernel(Type::NA, {InputType::Array(null())}, out_ty, CastFromNull)); + DCHECK_OK(func->AddKernel(Type::NA, {null()}, out_ty, CastFromNull)); // From dictionary to this type if (CanCastFromDictionary(out_type_id)) { @@ -259,9 +267,9 @@ void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* fun // // XXX: Uses Take and does its own memory allocation for the moment. We can // fix this later. 
- DCHECK_OK(func->AddKernel( - Type::DICTIONARY, {InputType::Array(Type::DICTIONARY)}, out_ty, UnpackDictionary, - NullHandling::COMPUTED_NO_PREALLOCATE, MemAllocation::NO_PREALLOCATE)); + DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, out_ty, + UnpackDictionary, NullHandling::COMPUTED_NO_PREALLOCATE, + MemAllocation::NO_PREALLOCATE)); } // From extension type to this type diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h index 333601cf216..15769ce5f8f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.h @@ -21,6 +21,7 @@ #include "arrow/compute/cast.h" // IWYU pragma: export #include "arrow/compute/cast_internal.h" // IWYU pragma: export #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" namespace arrow { @@ -62,6 +63,13 @@ void AddSimpleCast(InputType in_ty, OutputType out_ty, CastFunction* func) { CastFunctor::Exec)); } +template +void AddSimpleArrayOnlyCast(InputType in_ty, OutputType out_ty, CastFunction* func) { + DCHECK_OK(func->AddKernel( + InType::type_id, {in_ty}, out_ty, + TrivialScalarUnaryAsArraysExec(CastFunctor::Exec))); +} + void ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out); void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index f46cf7e7a75..6e550fb12c0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -282,7 +282,9 @@ struct ParseString { OutValue Call(KernelContext* ctx, Arg0Value val) const { OutValue result = OutValue(0); if (ARROW_PREDICT_FALSE(!ParseValue(val.data(), val.size(), &result))) { - ctx->SetStatus(Status::Invalid("Failed to parse string: ", val)); + 
ctx->SetStatus(Status::Invalid("Failed to parse string: '", val, + "' as a scalar of type ", + TypeTraits::type_singleton()->ToString())); } return result; } @@ -630,8 +632,8 @@ std::vector> GetNumericCasts() { // Make a cast to null that does not do much. Not sure why we need to be able // to cast from dict -> null but there are unit tests for it auto cast_null = std::make_shared("cast_null", Type::NA); - DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType::Array(Type::DICTIONARY)}, - null(), OutputAllNull)); + DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(), + OutputAllNull)); functions.push_back(cast_null); functions.push_back(GetCastToInteger("cast_int8")); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index 9d3812e5967..ca9c33a8f1b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -49,6 +49,7 @@ struct NumericToStringCastFunctor { using FormatterType = StringFormatter; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK(out->is_array()); const ArrayData& input = *batch[0].array(); ArrayData* output = out->mutable_array(); ctx->SetStatus(Convert(ctx, input, output)); @@ -160,6 +161,7 @@ struct BinaryToBinaryCastFunctor { using output_offset_type = typename O::offset_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK(out->is_array()); const CastOptions& options = checked_cast(*ctx->state()).options; const ArrayData& input = *batch[0].array(); @@ -194,7 +196,8 @@ void AddNumberToStringCasts(CastFunction* func) { auto out_ty = TypeTraits::type_singleton(); DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, - NumericToStringCastFunctor::Exec, + TrivialScalarUnaryAsArraysExec( + NumericToStringCastFunctor::Exec), NullHandling::COMPUTED_NO_PREALLOCATE)); for (const std::shared_ptr& in_ty : 
NumericTypes()) { @@ -210,9 +213,10 @@ void AddBinaryToBinaryCast(CastFunction* func) { auto in_ty = TypeTraits::type_singleton(); auto out_ty = TypeTraits::type_singleton(); - DCHECK_OK(func->AddKernel(OutType::type_id, {in_ty}, out_ty, - BinaryToBinaryCastFunctor::Exec, - NullHandling::COMPUTED_NO_PREALLOCATE)); + DCHECK_OK(func->AddKernel( + OutType::type_id, {in_ty}, out_ty, + TrivialScalarUnaryAsArraysExec(BinaryToBinaryCastFunctor::Exec), + NullHandling::COMPUTED_NO_PREALLOCATE)); } } // namespace diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index 10b38d75095..99c1401d1b8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -126,7 +126,6 @@ struct CastFunctor< enable_if_t<(is_timestamp_type::value && is_timestamp_type::value) || (is_duration_type::value && is_duration_type::value)>> { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -147,7 +146,6 @@ struct CastFunctor< template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const ArrayData& input = *batch[0].array(); @@ -170,7 +168,6 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); const CastOptions& options = checked_cast(*ctx->state()).options; @@ -224,7 +221,6 @@ struct CastFunctor::value && is_time_type:: using out_t = typename O::c_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); 
const ArrayData& input = *batch[0].array(); @@ -245,7 +241,6 @@ struct CastFunctor::value && is_time_type:: template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); ShiftTime(ctx, util::MULTIPLY, kMillisecondsInDay, @@ -256,7 +251,6 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO: Make this work on scalar inputs DCHECK_EQ(batch[0].kind(), Datum::ARRAY); ShiftTime(ctx, util::DIVIDE, kMillisecondsInDay, *batch[0].array(), @@ -264,6 +258,40 @@ struct CastFunctor { } }; +// ---------------------------------------------------------------------- +// date32, date64 to timestamp + +template <> +struct CastFunctor { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const auto& out_type = checked_cast(*out->type()); + // get conversion SECOND -> unit + auto conversion = util::GetTimestampConversion(TimeUnit::SECOND, out_type.unit()); + DCHECK_EQ(conversion.first, util::MULTIPLY); + + // multiply to achieve days -> unit + conversion.second *= kMillisecondsInDay / 1000; + ShiftTime(ctx, util::MULTIPLY, conversion.second, *batch[0].array(), + out->mutable_array()); + } +}; + +template <> +struct CastFunctor { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + + const auto& out_type = checked_cast(*out->type()); + + // date64 is ms since epoch + auto conversion = util::GetTimestampConversion(TimeUnit::MILLI, out_type.unit()); + ShiftTime(ctx, conversion.first, conversion.second, + *batch[0].array(), out->mutable_array()); + } +}; + // ---------------------------------------------------------------------- // String to Timestamp @@ -272,7 +300,8 @@ struct ParseTimestamp { OutValue Call(KernelContext* ctx, 
Arg0Value val) const { OutValue result = 0; if (ARROW_PREDICT_FALSE(!ParseValue(type, val.data(), val.size(), &result))) { - ctx->SetStatus(Status::Invalid("Failed to parse string: ", val)); + ctx->SetStatus(Status::Invalid("Failed to parse string: '", val, + "' as a scalar of type ", type.ToString())); } return result; } @@ -307,11 +336,11 @@ std::shared_ptr GetDate32Cast() { AddZeroCopyCast(Type::INT32, int32(), date32(), func.get()); // date64 -> date32 - AddSimpleCast(date64(), date32(), func.get()); + AddSimpleArrayOnlyCast(date64(), date32(), func.get()); // timestamp -> date32 - AddSimpleCast(InputType(Type::TIMESTAMP), date32(), - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIMESTAMP), date32(), + func.get()); return func; } @@ -324,11 +353,11 @@ std::shared_ptr GetDate64Cast() { AddZeroCopyCast(Type::INT64, int64(), date64(), func.get()); // date32 -> date64 - AddSimpleCast(date32(), date64(), func.get()); + AddSimpleArrayOnlyCast(date32(), date64(), func.get()); // timestamp -> date64 - AddSimpleCast(InputType(Type::TIMESTAMP), date64(), - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIMESTAMP), date64(), + func.get()); return func; } @@ -358,8 +387,8 @@ std::shared_ptr GetTime32Cast() { AddZeroCopyCast(Type::INT32, /*in_type=*/int32(), kOutputTargetType, func.get()); // time64 -> time32 - AddSimpleCast(InputType(Type::TIME64), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIME64), + kOutputTargetType, func.get()); // time32 -> time32 AddCrossUnitCast(func.get()); @@ -375,8 +404,8 @@ std::shared_ptr GetTime64Cast() { AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); // time32 -> time64 - AddSimpleCast(InputType(Type::TIME32), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(InputType(Type::TIME32), + kOutputTargetType, func.get()); // Between durations AddCrossUnitCast(func.get()); @@ -392,17 +421,18 @@ std::shared_ptr GetTimestampCast() { 
AddZeroCopyCast(Type::INT64, /*in_type=*/int64(), kOutputTargetType, func.get()); // From date types - // TODO: ARROW-8876, these casts are not implemented - // AddSimpleCast(InputType(Type::DATE32), - // kOutputTargetType, func.get()); - // AddSimpleCast(InputType(Type::DATE64), - // kOutputTargetType, func.get()); + // TODO: ARROW-8876, these casts are not directly tested + AddSimpleArrayOnlyCast(InputType(Type::DATE32), + kOutputTargetType, func.get()); + AddSimpleArrayOnlyCast(InputType(Type::DATE64), + kOutputTargetType, func.get()); // string -> timestamp - AddSimpleCast(utf8(), kOutputTargetType, func.get()); + AddSimpleArrayOnlyCast(utf8(), kOutputTargetType, + func.get()); // large_string -> timestamp - AddSimpleCast(large_utf8(), kOutputTargetType, - func.get()); + AddSimpleArrayOnlyCast(large_utf8(), kOutputTargetType, + func.get()); // From one timestamp to another AddCrossUnitCast(func.get()); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index ec612e497f4..fda8073beb9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1801,6 +1801,10 @@ TYPED_TEST(TestDictionaryCast, Basic) { // TODO: Should casting dictionary scalars work? 
this->CheckPass(*dict_arr, *expected, expected->type(), CastOptions::Safe(), /*check_scalar=*/false); + + auto opts = CastOptions::Safe(); + opts.to_type = expected->type(); + CheckScalarUnary("cast", dict_arr, expected, &opts); } } @@ -1892,5 +1896,48 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { ASSERT_OK(UnregisterExtensionType("smallint")); } +class TestStruct : public TestBase { + public: + Result Struct(std::vector args) { + StructOptions opts{field_names}; + return CallFunction("struct", args, &opts); + } + + std::vector field_names; +}; + +TEST_F(TestStruct, Scalar) { + std::shared_ptr expected(new StructScalar{{}, struct_({})}); + ASSERT_OK_AND_EQ(Datum(expected), Struct({})); + + auto i32 = MakeScalar(1); + auto f64 = MakeScalar(2.5); + auto str = MakeScalar("yo"); + + expected.reset(new StructScalar{ + {i32, f64, str}, + struct_({field("i", i32->type), field("f", f64->type), field("s", str->type)})}); + field_names = {"i", "f", "s"}; + ASSERT_OK_AND_EQ(Datum(expected), Struct({i32, f64, str})); + + // Three field names but one input value + ASSERT_RAISES(Invalid, Struct({str})); +} + +TEST_F(TestStruct, Array) { + field_names = {"i", "s"}; + auto i32 = ArrayFromJSON(int32(), "[42, 13, 7]"); + auto str = ArrayFromJSON(utf8(), R"(["aa", "aa", "aa"])"); + ASSERT_OK_AND_ASSIGN(Datum expected, StructArray::Make({i32, str}, field_names)); + + ASSERT_OK_AND_EQ(expected, Struct({i32, str})); + + // Scalars are broadcast to the length of the arrays + ASSERT_OK_AND_EQ(expected, Struct({i32, MakeScalar("aa")})); + + // Array length mismatch + ASSERT_RAISES(Invalid, Struct({i32->Slice(1), str})); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index d43083c7538..722f3173a34 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -19,10 +19,10 @@ #include 
"arrow/array/builder_primitive.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/hashing.h" -#include "arrow/util/optional.h" #include "arrow/visitor_inline.h" namespace arrow { @@ -255,25 +255,8 @@ struct IndexInVisitor { } }; -void ExecArrayOrScalar(KernelContext* ctx, const Datum& in, Datum* out, - std::function array_impl) { - if (in.is_array()) { - KERNEL_RETURN_IF_ERROR(ctx, array_impl(*in.array())); - return; - } - - std::shared_ptr in_array; - std::shared_ptr out_scalar; - KERNEL_RETURN_IF_ERROR(ctx, MakeArrayFromScalar(*in.scalar(), 1).Value(&in_array)); - KERNEL_RETURN_IF_ERROR(ctx, array_impl(*in_array->data())); - KERNEL_RETURN_IF_ERROR(ctx, out->make_array()->GetScalar(0).Value(&out_scalar)); - *out = std::move(out_scalar); -} - void ExecIndexIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExecArrayOrScalar(ctx, batch[0], out, [&](const ArrayData& in) { - return IndexInVisitor(ctx, in, out).Execute(); - }); + KERNEL_RETURN_IF_ERROR(ctx, IndexInVisitor(ctx, *batch[0].array(), out).Execute()); } // ---------------------------------------------------------------------- @@ -360,9 +343,7 @@ struct IsInVisitor { }; void ExecIsIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ExecArrayOrScalar(ctx, batch[0], out, [&](const ArrayData& in) { - return IsInVisitor(ctx, in, out).Execute(); - }); + KERNEL_RETURN_IF_ERROR(ctx, IsInVisitor(ctx, *batch[0].array(), out).Execute()); } // Unary set lookup kernels available for the following input types @@ -451,7 +432,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel isin_base; isin_base.init = InitSetLookup; - isin_base.exec = ExecIsIn; + isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn); auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); 
AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); @@ -468,7 +449,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel index_in_base; index_in_base.init = InitSetLookup; - index_in_base.exec = ExecIndexIn; + index_in_base.exec = TrivialScalarUnaryAsArraysExec(ExecIndexIn); index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto index_in = diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index eff60b80481..2f4a0d45ed2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1207,7 +1207,9 @@ struct ParseStrptime { int64_t Call(KernelContext* ctx, util::string_view val) const { int64_t result = 0; if (!(*parser)(val.data(), val.size(), unit, &result)) { - ctx->SetStatus(Status::Invalid("Failed to parse string ", val)); + ctx->SetStatus(Status::Invalid("Failed to parse string: '", val, + "' as a scalar of type ", + TimestampType(unit).ToString())); } return result; } diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc index 32c6317a104..3d21f5b1494 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ -57,6 +57,28 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { return arg; } +ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec) { + return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (out->is_array()) { + return exec(ctx, batch, out); + } + + if (!batch[0].scalar()->is_valid) { + out->scalar()->is_valid = false; + return; + } + + KERNEL_ASSIGN_OR_RAISE(Datum array_in, ctx, + MakeArrayFromScalar(*batch[0].scalar(), 1)); + + KERNEL_ASSIGN_OR_RAISE(Datum array_out, ctx, MakeArrayFromScalar(*out->scalar(), 1)); + + exec(ctx, ExecBatch{{std::move(array_in)}, 1}, 
&array_out); + + KERNEL_ASSIGN_OR_RAISE(*out, ctx, array_out.make_array()->GetScalar(0)); + }; +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h index 7ab59965752..4aad3804366 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.h +++ b/cpp/src/arrow/compute/kernels/util_internal.h @@ -18,8 +18,12 @@ #pragma once #include +#include +#include "arrow/array/util.h" #include "arrow/buffer.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/type_fwd.h" namespace arrow { namespace compute { @@ -50,6 +54,12 @@ int GetBitWidth(const DataType& type); // rather than duplicating compiled code to do all these in each kernel. PrimitiveArg GetPrimitiveArg(const ArrayData& arr); +// Augment a unary ArrayKernelExec which supports only array-like inputs with support for +// scalar inputs. Scalars will be transformed to 1-long arrays which are passed to the +// original exec. This could be far more efficient, but instead of optimizing this it'd be +// better to support scalar inputs "upstream" in original exec. 
+ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 4b05ceccb5a..9888e610aa7 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -20,9 +20,13 @@ namespace arrow { struct Datum; +struct ValueDescr; namespace compute { +class Function; +struct FunctionOptions; + struct CastOptions; class ExecContext; @@ -33,5 +37,7 @@ struct ScalarKernel; struct ScalarAggregateKernel; struct VectorKernel; +struct KernelState; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/CMakeLists.txt b/cpp/src/arrow/dataset/CMakeLists.txt index b693c48049c..aa9cde2dbe9 100644 --- a/cpp/src/arrow/dataset/CMakeLists.txt +++ b/cpp/src/arrow/dataset/CMakeLists.txt @@ -22,9 +22,9 @@ arrow_install_all_headers("arrow/dataset") set(ARROW_DATASET_SRCS dataset.cc discovery.cc + expression.cc file_base.cc file_ipc.cc - filter.cc partition.cc projector.cc scanner.cc) @@ -106,9 +106,9 @@ endfunction() add_arrow_dataset_test(dataset_test) add_arrow_dataset_test(discovery_test) +add_arrow_dataset_test(expression_test) add_arrow_dataset_test(file_ipc_test) add_arrow_dataset_test(file_test) -add_arrow_dataset_test(filter_test) add_arrow_dataset_test(partition_test) add_arrow_dataset_test(scanner_test) diff --git a/cpp/src/arrow/dataset/api.h b/cpp/src/arrow/dataset/api.h index c4c90088e8c..da9f5ed371e 100644 --- a/cpp/src/arrow/dataset/api.h +++ b/cpp/src/arrow/dataset/api.h @@ -21,9 +21,9 @@ #include "arrow/dataset/dataset.h" #include "arrow/dataset/discovery.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_csv.h" #include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" diff --git a/cpp/src/arrow/dataset/dataset.cc 
b/cpp/src/arrow/dataset/dataset.cc index 71755aaf566..e2386dddec7 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -21,7 +21,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/table.h" #include "arrow/util/bit_util.h" @@ -32,12 +31,10 @@ namespace arrow { namespace dataset { -Fragment::Fragment(std::shared_ptr partition_expression, +Fragment::Fragment(Expression partition_expression, std::shared_ptr physical_schema) : partition_expression_(std::move(partition_expression)), - physical_schema_(std::move(physical_schema)) { - DCHECK_NE(partition_expression_, nullptr); -} + physical_schema_(std::move(physical_schema)) {} Result> Fragment::ReadPhysicalSchema() { { @@ -61,14 +58,14 @@ Result> InMemoryFragment::ReadPhysicalSchemaImpl() { InMemoryFragment::InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - std::shared_ptr partition_expression) + Expression partition_expression) : Fragment(std::move(partition_expression), std::move(schema)), record_batches_(std::move(record_batches)) { DCHECK_NE(physical_schema_, nullptr); } InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches, - std::shared_ptr partition_expression) + Expression partition_expression) : InMemoryFragment(record_batches.empty() ? 
schema({}) : record_batches[0]->schema(), std::move(record_batches), std::move(partition_expression)) {} @@ -95,11 +92,9 @@ Result InMemoryFragment::Scan(std::shared_ptr opt return MakeMapIterator(fn, std::move(batches_it)); } -Dataset::Dataset(std::shared_ptr schema, - std::shared_ptr partition_expression) - : schema_(std::move(schema)), partition_expression_(std::move(partition_expression)) { - DCHECK_NE(partition_expression_, nullptr); -} +Dataset::Dataset(std::shared_ptr schema, Expression partition_expression) + : schema_(std::move(schema)), + partition_expression_(std::move(partition_expression)) {} Result> Dataset::NewScan( std::shared_ptr context) { @@ -110,10 +105,16 @@ Result> Dataset::NewScan() { return NewScan(std::make_shared()); } -FragmentIterator Dataset::GetFragments(std::shared_ptr predicate) { - predicate = predicate->Assume(*partition_expression_); - return predicate->IsSatisfiable() ? GetFragmentsImpl(std::move(predicate)) - : MakeEmptyIterator>(); +Result Dataset::GetFragments() { + ARROW_ASSIGN_OR_RAISE(auto predicate, literal(true).Bind(*schema_)); + return GetFragments(std::move(predicate)); +} + +Result Dataset::GetFragments(Expression predicate) { + ARROW_ASSIGN_OR_RAISE( + predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); + return predicate.IsSatisfiable() ? 
GetFragmentsImpl(std::move(predicate)) + : MakeEmptyIterator>(); } struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator { @@ -153,7 +154,7 @@ Result> InMemoryDataset::ReplaceSchema( return std::make_shared(std::move(schema), get_batches_); } -FragmentIterator InMemoryDataset::GetFragmentsImpl(std::shared_ptr) { +Result InMemoryDataset::GetFragmentsImpl(Expression) { auto schema = this->schema(); auto create_fragment = @@ -194,7 +195,7 @@ Result> UnionDataset::ReplaceSchema( new UnionDataset(std::move(schema), std::move(children))); } -FragmentIterator UnionDataset::GetFragmentsImpl(std::shared_ptr predicate) { +Result UnionDataset::GetFragmentsImpl(Expression predicate) { return GetFragmentsFromDatasets(children_, predicate); } diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index be138dd277d..c92381d78c5 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/dataset/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/util/macros.h" @@ -71,21 +72,19 @@ class ARROW_DS_EXPORT Fragment { /// \brief An expression which evaluates to true for all data viewed by this /// Fragment. 
- const std::shared_ptr& partition_expression() const { - return partition_expression_; - } + const Expression& partition_expression() const { return partition_expression_; } virtual ~Fragment() = default; protected: Fragment() = default; - explicit Fragment(std::shared_ptr partition_expression, + explicit Fragment(Expression partition_expression, std::shared_ptr physical_schema); virtual Result> ReadPhysicalSchemaImpl() = 0; util::Mutex physical_schema_mutex_; - std::shared_ptr partition_expression_ = scalar(true); + Expression partition_expression_ = literal(true); std::shared_ptr physical_schema_; }; @@ -94,9 +93,8 @@ class ARROW_DS_EXPORT Fragment { class ARROW_DS_EXPORT InMemoryFragment : public Fragment { public: InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, - std::shared_ptr = scalar(true)); - explicit InMemoryFragment(RecordBatchVector record_batches, - std::shared_ptr = scalar(true)); + Expression = literal(true)); + explicit InMemoryFragment(RecordBatchVector record_batches, Expression = literal(true)); Result Scan(std::shared_ptr options, std::shared_ptr context) override; @@ -123,15 +121,14 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. - FragmentIterator GetFragments(std::shared_ptr predicate = scalar(true)); + Result GetFragments(Expression predicate); + Result GetFragments(); const std::shared_ptr& schema() const { return schema_; } /// \brief An expression which evaluates to true for all data viewed by this Dataset. /// May be null, which indicates no information is available. 
- const std::shared_ptr& partition_expression() const { - return partition_expression_; - } + const Expression& partition_expression() const { return partition_expression_; } /// \brief The name identifying the kind of Dataset virtual std::string type_name() const = 0; @@ -148,13 +145,12 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { protected: explicit Dataset(std::shared_ptr schema) : schema_(std::move(schema)) {} - Dataset(std::shared_ptr schema, - std::shared_ptr partition_expression); + Dataset(std::shared_ptr schema, Expression partition_expression); - virtual FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) = 0; + virtual Result GetFragmentsImpl(Expression predicate) = 0; std::shared_ptr schema_; - std::shared_ptr partition_expression_ = scalar(true); + Expression partition_expression_ = literal(true); }; /// \brief A Source which yields fragments wrapping a stream of record batches. @@ -183,7 +179,7 @@ class ARROW_DS_EXPORT InMemoryDataset : public Dataset { std::shared_ptr schema) const override; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl(Expression predicate) override; std::shared_ptr get_batches_; }; @@ -207,7 +203,7 @@ class ARROW_DS_EXPORT UnionDataset : public Dataset { std::shared_ptr schema) const override; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl(Expression predicate) override; explicit UnionDataset(std::shared_ptr schema, DatasetVector children) : Dataset(std::move(schema)), children_(std::move(children)) {} diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 489339e7907..cb6d406fb70 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -35,18 +35,18 @@ namespace dataset { /// \brief GetFragmentsFromDatasets transforms a vector into a /// flattened FragmentIterator. 
-inline FragmentIterator GetFragmentsFromDatasets(const DatasetVector& datasets, - std::shared_ptr predicate) { +inline Result GetFragmentsFromDatasets(const DatasetVector& datasets, + Expression predicate) { // Iterator auto datasets_it = MakeVectorIterator(datasets); // Dataset -> Iterator - auto fn = [predicate](std::shared_ptr dataset) -> FragmentIterator { + auto fn = [predicate](std::shared_ptr dataset) -> Result { return dataset->GetFragments(predicate); }; // Iterator> - auto fragments_it = MakeMapIterator(fn, std::move(datasets_it)); + auto fragments_it = MakeMaybeMapIterator(fn, std::move(datasets_it)); // Iterator return MakeFlattenIterator(std::move(fragments_it)); diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index 0b138caafe7..82a5a63c2c2 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -505,7 +505,8 @@ TEST_F(TestEndToEnd, EndToEndSingleDataset) { // The following filter tests both predicate pushdown and post filtering // without partition information because `year` is a partition and `sales` is // not. 
- auto filter = ("year"_ == 2019 && "sales"_ > 100.0); + auto filter = and_(equal(field_ref("year"), literal(2019)), + greater(field_ref("sales"), literal(100.0))); ASSERT_OK(scanner_builder->Filter(filter)); ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish()); @@ -701,8 +702,10 @@ TEST_F(TestSchemaUnification, SelectPhysicalColumnsFilterPartitionColumn) { // when some of the columns may not be materialized ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); ASSERT_OK(scan_builder->Project({"phy_2", "phy_3", "phy_4"})); - ASSERT_OK(scan_builder->Filter(("part_df"_ == 1 && "phy_2"_ == 211) || - ("part_ds"_ == 2 && "phy_4"_ != 422))); + ASSERT_OK(scan_builder->Filter(or_(and_(equal(field_ref("part_df"), literal(1)), + equal(field_ref("phy_2"), literal(211))), + and_(equal(field_ref("part_ds"), literal(2)), + not_equal(field_ref("phy_4"), literal(422)))))); using TupleType = std::tuple; std::vector rows = { @@ -733,7 +736,7 @@ TEST_F(TestSchemaUnification, SelectPartitionColumns) { TEST_F(TestSchemaUnification, SelectPartitionColumnsFilterPhysicalColumn) { // Selects re-ordered virtual columns with a filter on a physical columns ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_1"_ == 111)); + ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111)))); ASSERT_OK(scan_builder->Project({"part_df", "part_ds"})); using TupleType = std::tuple; @@ -747,7 +750,7 @@ TEST_F(TestSchemaUnification, SelectMixedColumnsAndFilter) { // Selects mix of physical/virtual with a different order and uses a filter on // a physical column not selected. 
ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset_->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_2"_ >= 212)); + ASSERT_OK(scan_builder->Filter(greater_equal(field_ref("phy_2"), literal(212)))); ASSERT_OK(scan_builder->Project({"part_df", "phy_3", "part_ds", "phy_1"})); using TupleType = std::tuple; @@ -785,8 +788,7 @@ TEST(TestDictPartitionColumn, SelectPartitionColumnFilterPhysicalColumn) { // Selects re-ordered virtual column with a filter on a physical column ASSERT_OK_AND_ASSIGN(auto scan_builder, dataset->NewScan()); - ASSERT_OK(scan_builder->Filter("phy_1"_ == 111)); - + ASSERT_OK(scan_builder->Filter(equal(field_ref("phy_1"), literal(111)))); ASSERT_OK(scan_builder->Project({"part"})); ASSERT_OK_AND_ASSIGN(auto scanner, scan_builder->Finish()); diff --git a/cpp/src/arrow/dataset/discovery.cc b/cpp/src/arrow/dataset/discovery.cc index b7af6dcbf0f..079cc6d4065 100644 --- a/cpp/src/arrow/dataset/discovery.cc +++ b/cpp/src/arrow/dataset/discovery.cc @@ -35,7 +35,7 @@ namespace arrow { namespace dataset { -DatasetFactory::DatasetFactory() : root_partition_(scalar(true)) {} +DatasetFactory::DatasetFactory() : root_partition_(literal(true)) {} Result> DatasetFactory::Inspect(InspectOptions options) { ARROW_ASSIGN_OR_RAISE(auto schemas, InspectSchemas(std::move(options))); diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 5176f4f53a7..b7786cd305f 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -90,9 +90,9 @@ class ARROW_DS_EXPORT DatasetFactory { virtual Result> Finish(FinishOptions options) = 0; /// \brief Optional root partition for the resulting Dataset. 
- const std::shared_ptr& root_partition() const { return root_partition_; } - Status SetRootPartition(std::shared_ptr partition) { - root_partition_ = partition; + const Expression& root_partition() const { return root_partition_; } + Status SetRootPartition(Expression partition) { + root_partition_ = std::move(partition); return Status::OK(); } @@ -101,7 +101,7 @@ class ARROW_DS_EXPORT DatasetFactory { protected: DatasetFactory(); - std::shared_ptr root_partition_; + Expression root_partition_; }; /// \brief DatasetFactory provides a way to inspect/discover a Dataset's diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index bc330edb53b..a951e827fa4 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -23,7 +23,6 @@ #include #include -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/filesystem/test_util.h" @@ -139,7 +138,8 @@ class FileSystemDatasetFactoryTest : public DatasetFactoryTest { } options_ = ScanOptions::Make(schema); ASSERT_OK_AND_ASSIGN(dataset_, factory_->Finish(schema)); - AssertFragmentsAreFromPath(dataset_->GetFragments(), paths); + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset_->GetFragments()); + AssertFragmentsAreFromPath(std::move(fragment_it), paths); } protected: @@ -368,10 +368,13 @@ TEST_F(FileSystemDatasetFactoryTest, FilenameNotPartOfPartitions) { // column. In such case, the filename should not be used. 
MakeFactory({fs::File("one/file.parquet")}); + auto expected = equal(field_ref("first"), literal("one")); + ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish()); - for (const auto& maybe_fragment : dataset->GetFragments()) { + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); + for (const auto& maybe_fragment : fragment_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); - ASSERT_TRUE(fragment->partition_expression()->Equals(("first"_ == "one"))); + EXPECT_EQ(fragment->partition_expression(), expected); } } diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc new file mode 100644 index 00000000000..7f08788de29 --- /dev/null +++ b/cpp/src/arrow/dataset/expression.cc @@ -0,0 +1,1188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/dataset/expression.h" + +#include +#include + +#include "arrow/chunked_array.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/dataset/expression_internal.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/util/atomic_shared_ptr.h" +#include "arrow/util/key_value_metadata.h" +#include "arrow/util/logging.h" +#include "arrow/util/optional.h" +#include "arrow/util/string.h" +#include "arrow/util/value_parsing.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace dataset { + +Expression::Expression(Call call) : impl_(std::make_shared(std::move(call))) {} + +Expression::Expression(Datum literal) + : impl_(std::make_shared(std::move(literal))) {} + +Expression::Expression(Parameter parameter) + : impl_(std::make_shared(std::move(parameter))) {} + +Expression literal(Datum lit) { return Expression(std::move(lit)); } + +Expression field_ref(FieldRef ref) { + return Expression(Expression::Parameter{std::move(ref), {}}); +} + +Expression call(std::string function, std::vector arguments, + std::shared_ptr options) { + Expression::Call call; + call.function_name = std::move(function); + call.arguments = std::move(arguments); + call.options = std::move(options); + return Expression(std::move(call)); +} + +const Datum* Expression::literal() const { return util::get_if(impl_.get()); } + +const FieldRef* Expression::field_ref() const { + if (auto parameter = util::get_if(impl_.get())) { + return ¶meter->ref; + } + return nullptr; +} + +const Expression::Call* Expression::call() const { + return util::get_if(impl_.get()); +} + +ValueDescr Expression::descr() const { + if (impl_ == nullptr) return {}; + + if (auto lit = literal()) { + return lit->descr(); + } + + if (auto parameter = util::get_if(impl_.get())) { + return parameter->descr; + } + + return CallNotNull(*this)->descr; +} + +std::string Expression::ToString() const { + if 
(auto lit = literal()) { + if (lit->is_scalar()) { + switch (lit->type()->id()) { + case Type::STRING: + case Type::LARGE_STRING: + return '"' + + Escape(util::string_view(*lit->scalar_as().value)) + + '"'; + + case Type::BINARY: + case Type::FIXED_SIZE_BINARY: + case Type::LARGE_BINARY: + return '"' + lit->scalar_as().value->ToHexString() + '"'; + + default: + break; + } + return lit->scalar()->ToString(); + } + return lit->ToString(); + } + + if (auto ref = field_ref()) { + if (auto name = ref->name()) { + return *name; + } + if (auto path = ref->field_path()) { + return path->ToString(); + } + return ref->ToString(); + } + + auto call = CallNotNull(*this); + auto binary = [&](std::string op) { + return "(" + call->arguments[0].ToString() + " " + op + " " + + call->arguments[1].ToString() + ")"; + }; + + if (auto cmp = Comparison::Get(call->function_name)) { + return binary(Comparison::GetOp(*cmp)); + } + + constexpr util::string_view kleene = "_kleene"; + if (util::string_view{call->function_name}.ends_with(kleene)) { + auto op = call->function_name.substr(0, call->function_name.size() - kleene.size()); + return binary(std::move(op)); + } + + if (auto options = GetStructOptions(*call)) { + std::string out = "{"; + auto argument = call->arguments.begin(); + for (const auto& field_name : options->field_names) { + out += field_name + "=" + argument++->ToString() + ", "; + } + out.resize(out.size() - 1); + out.back() = '}'; + return out; + } + + std::string out = call->function_name + "("; + for (const auto& arg : call->arguments) { + out += arg.ToString() + ", "; + } + + if (call->options == nullptr) { + out.resize(out.size() - 1); + out.back() = ')'; + return out; + } + + if (auto options = GetSetLookupOptions(*call)) { + DCHECK_EQ(options->value_set.kind(), Datum::ARRAY); + out += "value_set=" + options->value_set.make_array()->ToString(); + if (options->skip_nulls) { + out += ", skip_nulls"; + } + return out + ")"; + } + + if (auto options = 
GetCastOptions(*call)) { + if (options->to_type == nullptr) { + return out + "to_type=)"; + } + out += "to_type=" + options->to_type->ToString(); + if (options->allow_int_overflow) out += ", allow_int_overflow"; + if (options->allow_time_truncate) out += ", allow_time_truncate"; + if (options->allow_time_overflow) out += ", allow_time_overflow"; + if (options->allow_decimal_truncate) out += ", allow_decimal_truncate"; + if (options->allow_float_truncate) out += ", allow_float_truncate"; + if (options->allow_invalid_utf8) out += ", allow_invalid_utf8"; + return out + ")"; + } + + if (auto options = GetStrptimeOptions(*call)) { + return out + "format=" + options->format + + ", unit=" + internal::ToString(options->unit) + ")"; + } + + return out + "{NON-REPRESENTABLE OPTIONS})"; +} + +void PrintTo(const Expression& expr, std::ostream* os) { + *os << expr.ToString(); + if (expr.IsBound()) { + *os << "[bound]"; + } +} + +bool Expression::Equals(const Expression& other) const { + if (Identical(*this, other)) return true; + + if (impl_->index() != other.impl_->index()) { + return false; + } + + if (auto lit = literal()) { + return lit->Equals(*other.literal()); + } + + if (auto ref = field_ref()) { + return ref->Equals(*other.field_ref()); + } + + auto call = CallNotNull(*this); + auto other_call = CallNotNull(other); + + if (call->function_name != other_call->function_name || + call->kernel != other_call->kernel) { + return false; + } + + for (size_t i = 0; i < call->arguments.size(); ++i) { + if (!call->arguments[i].Equals(other_call->arguments[i])) { + return false; + } + } + + if (call->options == other_call->options) return true; + + if (auto options = GetSetLookupOptions(*call)) { + auto other_options = GetSetLookupOptions(*other_call); + return options->value_set == other_options->value_set && + options->skip_nulls == other_options->skip_nulls; + } + + if (auto options = GetCastOptions(*call)) { + auto other_options = GetCastOptions(*other_call); + for (auto 
safety_opt : { + &compute::CastOptions::allow_int_overflow, + &compute::CastOptions::allow_time_truncate, + &compute::CastOptions::allow_time_overflow, + &compute::CastOptions::allow_decimal_truncate, + &compute::CastOptions::allow_float_truncate, + &compute::CastOptions::allow_invalid_utf8, + }) { + if (options->*safety_opt != other_options->*safety_opt) return false; + } + return options->to_type->Equals(other_options->to_type); + } + + if (auto options = GetStructOptions(*call)) { + auto other_options = GetStructOptions(*other_call); + return options->field_names == other_options->field_names; + } + + if (auto options = GetStrptimeOptions(*call)) { + auto other_options = GetStrptimeOptions(*other_call); + return options->format == other_options->format && + options->unit == other_options->unit; + } + + ARROW_LOG(WARNING) << "comparing unknown FunctionOptions for function " + << call->function_name; + return false; +} + +bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; } + +size_t Expression::hash() const { + if (auto lit = literal()) { + if (lit->is_scalar()) { + return Scalar::Hash::hash(*lit->scalar()); + } + return 0; + } + + if (auto ref = field_ref()) { + return ref->hash(); + } + + auto call = CallNotNull(*this); + if (call->hash != nullptr) { + return call->hash->load(); + } + + size_t out = std::hash{}(call->function_name); + for (const auto& arg : call->arguments) { + out ^= arg.hash(); + } + + std::shared_ptr> expected = nullptr; + internal::atomic_compare_exchange_strong(&const_cast(call)->hash, &expected, + std::make_shared>(out)); + return out; +} + +bool Expression::IsBound() const { + if (descr().type == nullptr) return false; + + if (auto lit = literal()) return true; + + if (auto ref = field_ref()) return true; + + auto call = CallNotNull(*this); + + for (const Expression& arg : call->arguments) { + if (!arg.IsBound()) return false; + } + + return call->kernel != nullptr; +} + +bool 
Expression::IsScalarExpression() const { + if (auto lit = literal()) { + return lit->is_scalar(); + } + + // FIXME handle case where a list's item field is referenced + if (auto ref = field_ref()) return true; + + auto call = CallNotNull(*this); + + for (const Expression& arg : call->arguments) { + if (!arg.IsScalarExpression()) return false; + } + + if (call->function) { + return call->function->kind() == compute::Function::SCALAR; + } + + // this expression is not bound; make a best guess based on + // the default function registry + if (auto function = compute::GetFunctionRegistry() + ->GetFunction(call->function_name) + .ValueOr(nullptr)) { + return function->kind() == compute::Function::SCALAR; + } + + // unknown function or other error; conservatively return false + return false; +} + +bool Expression::IsNullLiteral() const { + if (auto lit = literal()) { + if (lit->null_count() == lit->length()) { + return true; + } + } + + return false; +} + +bool Expression::IsSatisfiable() const { + if (descr().type && descr().type->id() == Type::NA) { + return false; + } + + if (auto lit = literal()) { + if (lit->null_count() == lit->length()) { + return false; + } + + if (lit->is_scalar() && lit->type()->id() == Type::BOOL) { + return lit->scalar_as().value; + } + } + + if (auto ref = field_ref()) { + return true; + } + + return true; +} + +inline bool KernelStateIsImmutable(const std::string& function) { + // XXX maybe just add Kernel::state_is_immutable or so? 
+ + // known functions with non-null but nevertheless immutable KernelState + static std::unordered_set names = { + "is_in", "index_in", "cast", "struct", "strptime", + }; + + return names.find(function) != names.end(); +} + +Result> InitKernelState( + const Expression::Call& call, compute::ExecContext* exec_context) { + if (!call.kernel->init) return nullptr; + + compute::KernelContext kernel_context(exec_context); + compute::KernelInitArgs kernel_init_args{call.kernel, GetDescriptors(call.arguments), + call.options.get()}; + + auto kernel_state = call.kernel->init(&kernel_context, kernel_init_args); + RETURN_NOT_OK(kernel_context.status()); + return std::move(kernel_state); +} + +Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { + if (expr->descr().type->Equals(to_type)) { + return Status::OK(); + } + + if (auto lit = expr->literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, to_type)); + *expr = literal(std::move(new_lit)); + return Status::OK(); + } + + // FIXME the resulting cast Call must be bound but this is a hack + auto with_cast = call("cast", {literal(MakeNullScalar(expr->descr().type))}, + compute::CastOptions::Safe(to_type)); + + static ValueDescr ignored_descr; + ARROW_ASSIGN_OR_RAISE(with_cast, with_cast.Bind(ignored_descr)); + + auto call_with_cast = *CallNotNull(with_cast); + call_with_cast.arguments[0] = std::move(*expr); + call_with_cast.descr = ValueDescr{std::move(to_type), expr->descr().shape}; + + *expr = Expression(std::move(call_with_cast)); + return Status::OK(); +} + +Status InsertImplicitCasts(Expression::Call* call) { + DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression& argument) { return argument.IsBound(); })); + + if (IsSameTypesBinary(call->function_name)) { + for (auto&& argument : call->arguments) { + if (auto value_type = GetDictionaryValueType(argument.descr().type)) { + RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); + } + } + + if 
(call->arguments[0].descr().shape == ValueDescr::SCALAR) { + // argument 0 is scalar so casting is cheap + return MaybeInsertCast(call->arguments[1].descr().type, &call->arguments[0]); + } + + // cast argument 1 unconditionally + return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); + } + + if (auto options = GetSetLookupOptions(*call)) { + if (auto value_type = GetDictionaryValueType(call->arguments[0].descr().type)) { + // DICTIONARY input is not supported; decode it. + RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &call->arguments[0])); + } + + if (options->value_set.type()->id() == Type::DICTIONARY) { + // DICTIONARY value_set is not supported; decode it. + auto new_options = std::make_shared(*options); + RETURN_NOT_OK(EnsureNotDictionary(&new_options->value_set)); + options = new_options.get(); + call->options = std::move(new_options); + } + + if (!options->value_set.type()->Equals(call->arguments[0].descr().type)) { + // The value_set is assumed smaller than inputs, casting it should be cheaper. + auto new_options = std::make_shared(*options); + ARROW_ASSIGN_OR_RAISE(new_options->value_set, + compute::Cast(std::move(new_options->value_set), + call->arguments[0].descr().type)); + options = new_options.get(); + call->options = std::move(new_options); + } + + return Status::OK(); + } + + return Status::OK(); +} + +Result Expression::Bind(ValueDescr in, + compute::ExecContext* exec_context) const { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return Bind(std::move(in), &exec_context); + } + + if (literal()) return *this; + + if (auto ref = field_ref()) { + ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); + auto descr = field ? 
ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); + return Expression{Parameter{*ref, std::move(descr)}}; + } + + auto bound_call = *CallNotNull(*this); + + ARROW_ASSIGN_OR_RAISE(bound_call.function, GetFunction(bound_call, exec_context)); + + for (auto&& argument : bound_call.arguments) { + ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); + } + RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); + + auto descrs = GetDescriptors(bound_call.arguments); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel, bound_call.function->DispatchExact(descrs)); + + compute::KernelContext kernel_context(exec_context); + ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, + InitKernelState(bound_call, exec_context)); + kernel_context.SetState(bound_call.kernel_state.get()); + + ARROW_ASSIGN_OR_RAISE( + bound_call.descr, + bound_call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); + + return Expression(std::move(bound_call)); +} + +Result Expression::Bind(const Schema& in_schema, + compute::ExecContext* exec_context) const { + return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); +} + +Result ExecuteScalarExpression(const Expression& expr, const Datum& input, + compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return ExecuteScalarExpression(expr, input, &exec_context); + } + + if (!expr.IsBound()) { + return Status::Invalid("Cannot Execute unbound expression."); + } + + if (!expr.IsScalarExpression()) { + return Status::Invalid( + "ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString()); + } + + if (auto lit = expr.literal()) return *lit; + + if (auto ref = expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(Datum field, GetDatumField(*ref, input)); + + if (field.descr() != expr.descr()) { + // Refernced field was present but didn't have the expected type. + // Should we just error here? For now, pay dispatch cost and just cast. 
+ ARROW_ASSIGN_OR_RAISE( + field, compute::Cast(field, expr.descr().type, compute::CastOptions::Safe(), + exec_context)); + } + + return field; + } + + auto call = CallNotNull(expr); + + std::vector arguments(call->arguments.size()); + for (size_t i = 0; i < arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE( + arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context)); + } + + auto executor = compute::detail::KernelExecutor::MakeScalar(); + + compute::KernelContext kernel_context(exec_context); + kernel_context.SetState(call->kernel_state.get()); + + auto kernel = call->kernel; + auto descrs = GetDescriptors(arguments); + auto options = call->options.get(); + RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options})); + + auto listener = std::make_shared(); + RETURN_NOT_OK(executor->Execute(arguments, listener.get())); + return executor->WrapResults(arguments, listener->values()); +} + +std::array, 2> +ArgumentsAndFlippedArguments(const Expression::Call& call) { + DCHECK_EQ(call.arguments.size(), 2); + return {std::pair{call.arguments[0], + call.arguments[1]}, + std::pair{call.arguments[1], + call.arguments[0]}}; +} + +template ::value_type> +util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { + if (begin == end) return util::nullopt; + + Out folded = std::move(*begin++); + while (begin != end) { + folded = bin_op(std::move(folded), std::move(*begin++)); + } + return folded; +} + +util::optional GetNullHandling( + const Expression::Call& call) { + if (call.function && call.function->kind() == compute::Function::SCALAR) { + return static_cast(call.kernel)->null_handling; + } + return util::nullopt; +} + +bool DefinitelyNotNull(const Expression& expr) { + DCHECK(expr.IsBound()); + + if (expr.literal()) { + return !expr.IsNullLiteral(); + } + + if (expr.field_ref()) return false; + + auto call = CallNotNull(expr); + if (auto null_handling = GetNullHandling(*call)) { + if (null_handling == 
compute::NullHandling::OUTPUT_NOT_NULL) { + return true; + } + if (null_handling == compute::NullHandling::INTERSECTION) { + return std::all_of(call->arguments.begin(), call->arguments.end(), + DefinitelyNotNull); + } + } + + return false; +} + +std::vector FieldsInExpression(const Expression& expr) { + if (auto lit = expr.literal()) return {}; + + if (auto ref = expr.field_ref()) { + return {*ref}; + } + + std::vector fields; + for (const Expression& arg : CallNotNull(expr)->arguments) { + auto argument_fields = FieldsInExpression(arg); + std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields)); + } + return fields; +} + +Result FoldConstants(Expression expr) { + return Modify( + std::move(expr), [](Expression expr) { return expr; }, + [](Expression expr, ...) -> Result { + auto call = CallNotNull(expr); + if (std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression& argument) { return argument.literal(); })) { + // all arguments are literal; we can evaluate this subexpression *now* + static const Datum ignored_input; + ARROW_ASSIGN_OR_RAISE(Datum constant, + ExecuteScalarExpression(expr, ignored_input)); + + return literal(std::move(constant)); + } + + // XXX the following should probably be in a registry of passes instead + // of inline + + if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { + // kernels which always produce intersected validity can be resolved + // to null *now* if any of their inputs is a null literal + for (const auto& argument : call->arguments) { + if (argument.IsNullLiteral()) { + return argument; + } + } + } + + if (call->function_name == "and_kleene") { + for (auto args : ArgumentsAndFlippedArguments(*call)) { + // true and x == x + if (args.first == literal(true)) return args.second; + + // false and x == false + if (args.first == literal(false)) return args.first; + + // x and x == x + if (args.first == args.second) return args.first; + } + return expr; + } + + if 
(call->function_name == "or_kleene") { + for (auto args : ArgumentsAndFlippedArguments(*call)) { + // false or x == x + if (args.first == literal(false)) return args.second; + + // true or x == true + if (args.first == literal(true)) return args.first; + + // x or x == x + if (args.first == args.second) return args.first; + } + return expr; + } + + return expr; + }); +} + +inline std::vector GuaranteeConjunctionMembers( + const Expression& guaranteed_true_predicate) { + auto guarantee = guaranteed_true_predicate.call(); + if (!guarantee || guarantee->function_name != "and_kleene") { + return {guaranteed_true_predicate}; + } + return FlattenedAssociativeChain(guaranteed_true_predicate).fringe; +} + +// Conjunction members which are represented in known_values are erased from +// conjunction_members +Status ExtractKnownFieldValuesImpl( + std::vector* conjunction_members, + std::unordered_map* known_values) { + auto unconsumed_end = + std::partition(conjunction_members->begin(), conjunction_members->end(), + [](const Expression& expr) { + // search for an equality conditions between a field and a literal + auto call = expr.call(); + if (!call) return true; + + if (call->function_name == "equal") { + auto ref = call->arguments[0].field_ref(); + auto lit = call->arguments[1].literal(); + return !(ref && lit); + } + + return true; + }); + + for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) { + auto call = CallNotNull(*it); + + auto ref = call->arguments[0].field_ref(); + auto lit = call->arguments[1].literal(); + + auto it_success = known_values->emplace(*ref, *lit); + if (it_success.second) continue; + + // A value was already known for ref; check it + auto ref_lit = it_success.first; + if (*lit != ref_lit->second) { + return Status::Invalid("Conflicting guarantees: (", ref->ToString(), + " == ", lit->ToString(), ") vs (", ref->ToString(), + " == ", ref_lit->second.ToString()); + } + } + + conjunction_members->erase(unconsumed_end, 
conjunction_members->end()); + + return Status::OK(); +} + +Result> ExtractKnownFieldValues( + const Expression& guaranteed_true_predicate) { + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); + std::unordered_map known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); + return known_values; +} + +Result ReplaceFieldsWithKnownValues( + const std::unordered_map& known_values, + Expression expr) { + if (!expr.IsBound()) { + return Status::Invalid( + "ReplaceFieldsWithKnownValues called on an unbound Expression"); + } + + return Modify( + std::move(expr), + [&known_values](Expression expr) -> Result { + if (auto ref = expr.field_ref()) { + auto it = known_values.find(*ref); + if (it != known_values.end()) { + ARROW_ASSIGN_OR_RAISE(Datum lit, + compute::Cast(it->second, expr.descr().type)); + return literal(std::move(lit)); + } + } + return expr; + }, + [](Expression expr, ...) { return expr; }); +} + +inline bool IsBinaryAssociativeCommutative(const Expression::Call& call) { + static std::unordered_set binary_associative_commutative{ + "and", "or", "and_kleene", "or_kleene", "xor", + "multiply", "add", "multiply_checked", "add_checked"}; + + auto it = binary_associative_commutative.find(call.function_name); + return it != binary_associative_commutative.end(); +} + +Result Canonicalize(Expression expr, compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return Canonicalize(std::move(expr), &exec_context); + } + + // If potentially reconstructing more deeply than a call's immediate arguments + // (for example, when reorganizing an associative chain), add expressions to this set to + // avoid unnecessary work + struct { + std::unordered_set set_; + + bool operator()(const Expression& expr) const { + return set_.find(expr) != set_.end(); + } + + void Add(std::vector exprs) { + std::move(exprs.begin(), exprs.end(), std::inserter(set_, 
set_.end())); + } + } AlreadyCanonicalized; + + return Modify( + std::move(expr), + [&AlreadyCanonicalized, exec_context](Expression expr) -> Result { + auto call = expr.call(); + if (!call) return expr; + + if (AlreadyCanonicalized(expr)) return expr; + + if (IsBinaryAssociativeCommutative(*call)) { + struct { + int Priority(const Expression& operand) const { + // order literals first, starting with nulls + if (operand.IsNullLiteral()) return 0; + if (operand.literal()) return 1; + return 2; + } + bool operator()(const Expression& l, const Expression& r) const { + return Priority(l) < Priority(r); + } + } CanonicalOrdering; + + FlattenedAssociativeChain chain(expr); + if (chain.was_left_folded && + std::is_sorted(chain.fringe.begin(), chain.fringe.end(), + CanonicalOrdering)) { + AlreadyCanonicalized.Add(std::move(chain.exprs)); + return expr; + } + + std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering); + + // fold the chain back up + auto folded = + FoldLeft(chain.fringe.begin(), chain.fringe.end(), + [call, &AlreadyCanonicalized](Expression l, Expression r) { + auto canonicalized_call = *call; + canonicalized_call.arguments = {std::move(l), std::move(r)}; + Expression expr(std::move(canonicalized_call)); + AlreadyCanonicalized.Add({expr}); + return expr; + }); + return std::move(*folded); + } + + if (auto cmp = Comparison::Get(call->function_name)) { + if (call->arguments[0].literal() && !call->arguments[1].literal()) { + // ensure that literals are on comparisons' RHS + auto flipped_call = *call; + flipped_call.function_name = + Comparison::GetName(Comparison::GetFlipped(*cmp)); + // look up the flipped kernel + // TODO extract a helper for use here and in Bind + ARROW_ASSIGN_OR_RAISE( + auto function, + exec_context->func_registry()->GetFunction(flipped_call.function_name)); + + auto descrs = GetDescriptors(flipped_call.arguments); + ARROW_ASSIGN_OR_RAISE(flipped_call.kernel, function->DispatchExact(descrs)); + + 
std::swap(flipped_call.arguments[0], flipped_call.arguments[1]); + return Expression(std::move(flipped_call)); + } + } + + return expr; + }, + [](Expression expr, ...) { return expr; }); +} + +Result DirectComparisonSimplification(Expression expr, + const Expression::Call& guarantee) { + return Modify( + std::move(expr), [](Expression expr) { return expr; }, + [&guarantee](Expression expr, ...) -> Result { + auto call = expr.call(); + if (!call) return expr; + + // Ensure both calls are comparisons with equal LHS and scalar RHS + auto cmp = Comparison::Get(expr); + auto cmp_guarantee = Comparison::Get(guarantee.function_name); + if (!cmp || !cmp_guarantee) return expr; + + if (call->arguments[0] != guarantee.arguments[0]) return expr; + + auto rhs = call->arguments[1].literal(); + auto guarantee_rhs = guarantee.arguments[1].literal(); + if (!rhs || !guarantee_rhs) return expr; + + if (!rhs->is_scalar() || !guarantee_rhs->is_scalar()) { + return expr; + } + + ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs, + Comparison::Execute(*rhs, *guarantee_rhs)); + DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA); + + if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) { + // RHS of filter is equal to RHS of guarantee + + if ((*cmp_guarantee & *cmp) == *cmp_guarantee) { + // guarantee is a subset of filter, so all data will be included + return literal(true); + } + + if ((*cmp_guarantee & *cmp) == 0) { + // guarantee disjoint with filter, so all data will be excluded + return literal(false); + } + + return expr; + } + + if (*cmp_guarantee & cmp_rhs_guarantee_rhs) { + // unusable guarantee + return expr; + } + + if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) { + // x > 1, x >= 1, x != 1 guaranteed by x >= 3 + return literal(true); + } else { + // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3 + return literal(false); + } + }); +} + +Result SimplifyWithGuarantee(Expression expr, + const Expression& guaranteed_true_predicate) { + auto conjunction_members = 
GuaranteeConjunctionMembers(guaranteed_true_predicate); + + std::unordered_map known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); + + ARROW_ASSIGN_OR_RAISE(expr, + ReplaceFieldsWithKnownValues(known_values, std::move(expr))); + + auto CanonicalizeAndFoldConstants = [&expr] { + ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); + ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr))); + return Status::OK(); + }; + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + + for (const auto& guarantee : conjunction_members) { + if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { + ARROW_ASSIGN_OR_RAISE( + auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); + + if (Identical(simplified, expr)) continue; + + expr = std::move(simplified); + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + } + } + + return expr; +} + +// Serialization is accomplished by converting expressions to KeyValueMetadata and storing +// this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its +// columns. Finally, the RecordBatch is written to an IPC file. 
+Result> Serialize(const Expression& expr) { + struct { + std::shared_ptr metadata_ = std::make_shared(); + ArrayVector columns_; + + Result AddScalar(const Scalar& scalar) { + auto ret = columns_.size(); + ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(scalar, 1)); + columns_.push_back(std::move(array)); + return std::to_string(ret); + } + + Status Visit(const Expression& expr) { + if (auto lit = expr.literal()) { + if (!lit->is_scalar()) { + return Status::NotImplemented("Serialization of non-scalar literals"); + } + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*lit->scalar())); + metadata_->Append("literal", std::move(value)); + return Status::OK(); + } + + if (auto ref = expr.field_ref()) { + if (!ref->name()) { + return Status::NotImplemented("Serialization of non-name field_refs"); + } + metadata_->Append("field_ref", *ref->name()); + return Status::OK(); + } + + auto call = CallNotNull(expr); + metadata_->Append("call", call->function_name); + + for (const auto& argument : call->arguments) { + RETURN_NOT_OK(Visit(argument)); + } + + if (call->options) { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, FunctionOptionsToStructScalar(*call)); + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*options_scalar)); + metadata_->Append("options", std::move(value)); + } + + metadata_->Append("end", call->function_name); + return Status::OK(); + } + + Result> operator()(const Expression& expr) { + RETURN_NOT_OK(Visit(expr)); + FieldVector fields(columns_.size()); + for (size_t i = 0; i < fields.size(); ++i) { + fields[i] = field("", columns_[i]->type()); + } + return RecordBatch::Make(schema(std::move(fields), std::move(metadata_)), 1, + std::move(columns_)); + } + } ToRecordBatch; + + ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr)); + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); + ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + 
RETURN_NOT_OK(writer->Close()); + return stream->Finish(); +} + +Result Deserialize(const Buffer& buffer) { + io::BufferReader stream(buffer); + ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); + ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); + if (batch->schema()->metadata() == nullptr) { + return Status::Invalid("serialized Expression's batch repr had null metadata"); + } + if (batch->num_rows() != 1) { + return Status::Invalid( + "serialized Expression's batch repr was not a single row - had ", + batch->num_rows()); + } + + struct FromRecordBatch { + const RecordBatch& batch_; + int index_; + + const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); } + + Result> GetScalar(const std::string& i) { + int32_t column_index; + if (!internal::ParseValue(i.data(), i.length(), &column_index)) { + return Status::Invalid("Couldn't parse column_index"); + } + if (column_index >= batch_.num_columns()) { + return Status::Invalid("column_index out of bounds"); + } + return batch_.column(column_index)->GetScalar(0); + } + + Result GetOne() { + if (index_ >= metadata().size()) { + return Status::Invalid("unterminated serialized Expression"); + } + + const std::string& key = metadata().key(index_); + const std::string& value = metadata().value(index_); + ++index_; + + if (key == "literal") { + ARROW_ASSIGN_OR_RAISE(auto scalar, GetScalar(value)); + return literal(std::move(scalar)); + } + + if (key == "field_ref") { + return field_ref(value); + } + + if (key != "call") { + return Status::Invalid("Unrecognized serialized Expression key ", key); + } + + std::vector arguments; + while (metadata().key(index_) != "end") { + if (metadata().key(index_) == "options") { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_))); + auto expr = call(value, std::move(arguments)); + RETURN_NOT_OK(FunctionOptionsFromStructScalar( + checked_cast(options_scalar.get()), + const_cast(expr.call()))); + index_ += 
2; + return expr; + } + + ARROW_ASSIGN_OR_RAISE(auto argument, GetOne()); + arguments.push_back(std::move(argument)); + } + + ++index_; + return call(value, std::move(arguments)); + } + }; + + return FromRecordBatch{*batch, 0}.GetOne(); +} + +Expression project(std::vector values, std::vector names) { + return call("struct", std::move(values), compute::StructOptions{std::move(names)}); +} + +Expression equal(Expression lhs, Expression rhs) { + return call("equal", {std::move(lhs), std::move(rhs)}); +} + +Expression not_equal(Expression lhs, Expression rhs) { + return call("not_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression less(Expression lhs, Expression rhs) { + return call("less", {std::move(lhs), std::move(rhs)}); +} + +Expression less_equal(Expression lhs, Expression rhs) { + return call("less_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression greater(Expression lhs, Expression rhs) { + return call("greater", {std::move(lhs), std::move(rhs)}); +} + +Expression greater_equal(Expression lhs, Expression rhs) { + return call("greater_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression and_(Expression lhs, Expression rhs) { + return call("and_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression and_(const std::vector& operands) { + auto folded = FoldLeft(operands.begin(), + operands.end(), and_); + if (folded) { + return std::move(*folded); + } + return literal(true); +} + +Expression or_(Expression lhs, Expression rhs) { + return call("or_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression or_(const std::vector& operands) { + auto folded = + FoldLeft(operands.begin(), operands.end(), or_); + if (folded) { + return std::move(*folded); + } + return literal(false); +} + +Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } + +Expression operator&&(Expression lhs, Expression rhs) { + return and_(std::move(lhs), std::move(rhs)); +} + +Expression operator||(Expression lhs, Expression rhs) { + 
return or_(std::move(lhs), std::move(rhs)); +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h new file mode 100644 index 00000000000..e597ffb7fcd --- /dev/null +++ b/cpp/src/arrow/dataset/expression.h @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/datum.h" +#include "arrow/type_fwd.h" +#include "arrow/util/variant.h" + +namespace arrow { +namespace dataset { + +/// An unbound expression which maps a single Datum to another Datum. +/// An expression is one of +/// - A literal Datum. +/// - A reference to a single (potentially nested) field of the input Datum. +/// - A call to a compute function, with arguments specified by other Expressions. 
+class ARROW_DS_EXPORT Expression { + public: + struct Call { + std::string function_name; + std::vector arguments; + std::shared_ptr options; + std::shared_ptr> hash; + + // post-Bind properties: + const compute::Kernel* kernel = NULLPTR; + std::shared_ptr function; + std::shared_ptr kernel_state; + ValueDescr descr; + }; + + std::string ToString() const; + bool Equals(const Expression& other) const; + size_t hash() const; + struct Hash { + size_t operator()(const Expression& expr) const { return expr.hash(); } + }; + + /// Bind this expression to the given input type, looking up Kernels and field types. + /// Some expression simplification may be performed and implicit casts will be inserted. + /// Any state necessary for execution will be initialized and returned. + Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, compute::ExecContext* = NULLPTR) const; + + // XXX someday + // Clone all KernelState in this bound expression. If any function referenced by this + // expression has mutable KernelState, it is not safe to execute or apply simplification + // passes to it (or copies of it!) from multiple threads. Cloning state produces new + // KernelStates where necessary to ensure that Expressions may be manipulated safely + // on multiple threads. + // Result CloneState() const; + // Status SetState(ExpressionState); + + /// Return true if all an expression's field references have explicit ValueDescr and all + /// of its functions' kernels are looked up. + bool IsBound() const; + + /// Return true if this expression is composed only of Scalar literals, field + /// references, and calls to ScalarFunctions. + bool IsScalarExpression() const; + + /// Return true if this expression is literal and entirely null. + bool IsNullLiteral() const; + + /// Return true if this expression could evaluate to true. 
+ bool IsSatisfiable() const; + + // XXX someday + // Result GetPipelines(); + + const Call* call() const; + const Datum* literal() const; + const FieldRef* field_ref() const; + + ValueDescr descr() const; + // XXX someday + // NullGeneralization::type nullable() const; + + struct Parameter { + FieldRef ref; + ValueDescr descr; + }; + + Expression() = default; + explicit Expression(Call call); + explicit Expression(Datum literal); + explicit Expression(Parameter parameter); + + private: + using Impl = util::Variant; + std::shared_ptr impl_; + + ARROW_DS_EXPORT friend bool Identical(const Expression& l, const Expression& r); + + ARROW_DS_EXPORT friend void PrintTo(const Expression&, std::ostream*); +}; + +inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); } +inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); } + +// Factories + +ARROW_DS_EXPORT +Expression literal(Datum lit); + +template +Expression literal(Arg&& arg) { + return literal(Datum(std::forward(arg))); +} + +ARROW_DS_EXPORT +Expression field_ref(FieldRef ref); + +ARROW_DS_EXPORT +Expression call(std::string function, std::vector arguments, + std::shared_ptr options = NULLPTR); + +template ::value>::type> +Expression call(std::string function, std::vector arguments, + Options options) { + return call(std::move(function), std::move(arguments), + std::make_shared(std::move(options))); +} + +ARROW_DS_EXPORT +std::vector FieldsInExpression(const Expression&); + +ARROW_DS_EXPORT +Result> ExtractKnownFieldValues( + const Expression& guaranteed_true_predicate); + +/// \defgroup expression-passes Functions for modification of Expressions +/// +/// @{ +/// +/// These operate on bound expressions. + +/// Weak canonicalization which establishes guarantees for subsequent passes. Even +/// equivalent Expressions may result in different canonicalized expressions. 
+/// TODO this could be a strong canonicalization +ARROW_DS_EXPORT +Result Canonicalize(Expression, compute::ExecContext* = NULLPTR); + +/// Simplify Expressions based on literal arguments (for example, add(null, x) will always +/// be null so replace the call with a null literal). Includes early evaluation of all +/// calls whose arguments are entirely literal. +ARROW_DS_EXPORT +Result FoldConstants(Expression); + +ARROW_DS_EXPORT +Result ReplaceFieldsWithKnownValues( + const std::unordered_map& known_values, Expression); + +/// Simplify an expression by replacing subexpressions based on a guarantee: +/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is +/// used to remove redundant function calls from a filter expression or to replace a +/// reference to a constant-value field with a literal. +ARROW_DS_EXPORT +Result SimplifyWithGuarantee(Expression, + const Expression& guaranteed_true_predicate); + +/// @} + +// Execution + +/// Execute a scalar expression against the provided state and input Datum. This +/// expression must be bound. 
+ARROW_DS_EXPORT +Result ExecuteScalarExpression(const Expression&, const Datum& input, + compute::ExecContext* = NULLPTR); + +// Serialization + +ARROW_DS_EXPORT +Result> Serialize(const Expression&); + +ARROW_DS_EXPORT +Result Deserialize(const Buffer&); + +// Convenience aliases for factories + +ARROW_DS_EXPORT Expression project(std::vector values, + std::vector names); + +ARROW_DS_EXPORT Expression equal(Expression lhs, Expression rhs); + +ARROW_DS_EXPORT Expression not_equal(Expression lhs, Expression rhs); + +ARROW_DS_EXPORT Expression less(Expression lhs, Expression rhs); + +ARROW_DS_EXPORT Expression less_equal(Expression lhs, Expression rhs); + +ARROW_DS_EXPORT Expression greater(Expression lhs, Expression rhs); + +ARROW_DS_EXPORT Expression greater_equal(Expression lhs, Expression rhs); + +ARROW_DS_EXPORT Expression and_(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression and_(const std::vector&); +ARROW_DS_EXPORT Expression or_(Expression lhs, Expression rhs); +ARROW_DS_EXPORT Expression or_(const std::vector&); +ARROW_DS_EXPORT Expression not_(Expression operand); + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h new file mode 100644 index 00000000000..c3fd49dc347 --- /dev/null +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/dataset/expression.h" + +#include +#include +#include + +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" +#include "arrow/compute/registry.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/util/logging.h" + +namespace arrow { + +using internal::checked_cast; + +namespace dataset { + +const Expression::Call* CallNotNull(const Expression& expr) { + auto call = expr.call(); + DCHECK_NE(call, nullptr); + return call; +} + +inline void GetAllFieldRefs(const Expression& expr, + std::unordered_set* refs) { + if (auto lit = expr.literal()) return; + + if (auto ref = expr.field_ref()) { + refs->emplace(*ref); + return; + } + + for (const Expression& arg : CallNotNull(expr)->arguments) { + GetAllFieldRefs(arg, refs); + } +} + +inline std::vector GetDescriptors(const std::vector& exprs) { + std::vector descrs(exprs.size()); + for (size_t i = 0; i < exprs.size(); ++i) { + DCHECK(exprs[i].IsBound()); + descrs[i] = exprs[i].descr(); + } + return descrs; +} + +inline std::vector GetDescriptors(const std::vector& values) { + std::vector descrs(values.size()); + for (size_t i = 0; i < values.size(); ++i) { + descrs[i] = values[i].descr(); + } + return descrs; +} + +struct FieldPathGetDatumImpl { + template ()))> + Result operator()(const std::shared_ptr& ptr) { + return path_.Get(*ptr).template As(); + } + + template + Result operator()(const T&) { + return Status::NotImplemented("FieldPath::Get() into Datum ", datum_.ToString()); + } + + const Datum& datum_; + const FieldPath& path_; +}; + 
+// Resolve `ref` against `input` and return the referenced field as a Datum.
+//
+// The ref is looked up in input's type when it has one, otherwise in its schema;
+// a datum with neither yields NotImplemented. When the ref matches no field
+// (or the lookup produces an empty Datum) a null scalar is returned instead of
+// an error.
+inline Result<Datum> GetDatumField(const FieldRef& ref, const Datum& input) {
+  Datum field;
+
+  FieldPath path;
+  if (auto type = input.type()) {
+    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*type));
+  } else if (auto schema = input.schema()) {
+    ARROW_ASSIGN_OR_RAISE(path, ref.FindOneOrNone(*schema));
+  } else {
+    return Status::NotImplemented("retrieving fields from datum ", input.ToString());
+  }
+
+  if (path) {
+    ARROW_ASSIGN_OR_RAISE(field,
+                          util::visit(FieldPathGetDatumImpl{input, path}, input.value));
+  }
+
+  if (field == Datum{}) {
+    // absent field: substitute a null scalar
+    field = Datum(std::make_shared<NullScalar>());
+  }
+
+  return field;
+}
+
+// Utilities for reasoning about the six comparison functions.
+struct Comparison {
+  // Bit flags: the compound comparisons are unions of EQUAL, LESS and GREATER.
+  // NA is the result of a comparison which evaluated to null.
+  enum type {
+    NA = 0,
+    EQUAL = 1,
+    LESS = 2,
+    GREATER = 4,
+    NOT_EQUAL = LESS | GREATER,
+    LESS_EQUAL = LESS | EQUAL,
+    GREATER_EQUAL = GREATER | EQUAL,
+  };
+
+  // Map a function name to its comparison flag.
+  // Returns nullptr if `function` is not one of the comparison functions.
+  static const type* Get(const std::string& function) {
+    // never modified after construction, so keep it const
+    static const std::unordered_map<std::string, type> map{
+        {"equal", EQUAL},     {"not_equal", NOT_EQUAL},
+        {"less", LESS},       {"less_equal", LESS_EQUAL},
+        {"greater", GREATER}, {"greater_equal", GREATER_EQUAL},
+    };
+
+    auto it = map.find(function);
+    return it != map.end() ? &it->second : nullptr;
+  }
+
+  // Return the comparison flag of `expr` if it is a call to a comparison
+  // function, otherwise nullptr.
+  static const type* Get(const Expression& expr) {
+    if (auto call = expr.call()) {
+      return Comparison::Get(call->function_name);
+    }
+    return nullptr;
+  }
+
+  // Execute a simple Comparison between scalars, casting the RHS if types disagree
+  //
+  // Returns EQUAL, LESS or GREATER, or NA if either the "equal" or the "less"
+  // kernel produced a null result. Non-scalar arguments are rejected with Invalid.
+  static Result<type> Execute(Datum l, Datum r) {
+    if (!l.is_scalar() || !r.is_scalar()) {
+      return Status::Invalid("Cannot Execute Comparison on non-scalars");
+    }
+
+    if (!l.type()->Equals(r.type())) {
+      ARROW_ASSIGN_OR_RAISE(r, compute::Cast(r, l.type()));
+    }
+
+    std::vector<Datum> arguments{std::move(l), std::move(r)};
+
+    ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments));
+
+    if (!equal.scalar()->is_valid) return NA;
+    if (equal.scalar_as<BooleanScalar>().value) return EQUAL;
+
+    // not equal: a single "less" call distinguishes LESS from GREATER
+    ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments));
+
+    if (!less.scalar()->is_valid) return NA;
+    return less.scalar_as<BooleanScalar>().value ? LESS : GREATER;
+  }
+
+  // Return the comparison which is equivalent to `op` after its operands have
+  // been swapped (LESS <-> GREATER, LESS_EQUAL <-> GREATER_EQUAL; the symmetric
+  // comparisons and NA are unchanged).
+  static type GetFlipped(type op) {
+    switch (op) {
+      case NA:
+        return NA;
+      case EQUAL:
+        return EQUAL;
+      case LESS:
+        return GREATER;
+      case GREATER:
+        return LESS;
+      case NOT_EQUAL:
+        return NOT_EQUAL;
+      case LESS_EQUAL:
+        return GREATER_EQUAL;
+      case GREATER_EQUAL:
+        return LESS_EQUAL;
+    }
+    DCHECK(false);
+    return NA;
+  }
+
+  // Return the name of the compute function which implements `op`.
+  // NA has no function; passing it trips a DCHECK and yields "na".
+  static std::string GetName(type op) {
+    switch (op) {
+      case NA:
+        DCHECK(false) << "unreachable";
+        break;
+      case EQUAL:
+        return "equal";
+      case LESS:
+        return "less";
+      case GREATER:
+        return "greater";
+      case NOT_EQUAL:
+        return "not_equal";
+      case LESS_EQUAL:
+        return "less_equal";
+      case GREATER_EQUAL:
+        return "greater_equal";
+    }
+    DCHECK(false);
+    return "na";
+  }
+
+  // Return the infix operator spelling of `op` ("==", "<", ...).
+  // NA has no operator; passing it trips a DCHECK and yields "".
+  static std::string GetOp(type op) {
+    switch (op) {
+      case NA:
+        DCHECK(false) << "unreachable";
+        break;
+      case EQUAL:
+        return "==";
+      case LESS:
+        return "<";
+      case GREATER:
+        return ">";
+      case NOT_EQUAL:
+        return "!=";
+      case LESS_EQUAL:
+        return "<=";
+      case GREATER_EQUAL:
+        return ">=";
+    }
+    DCHECK(false);
+    return "";
+  }
+};
+
+// Return the CastOptions of `call`, or nullptr if it is not a call to "cast".
+inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) {
+  if (call.function_name != "cast") return nullptr;
+  return checked_cast<const compute::CastOptions*>(call.options.get());
+}
+
+// True for the functions which take SetLookupOptions.
+inline bool IsSetLookup(const std::string& function) {
+  return function == "is_in" || function == "index_in";
+}
+
+// True for the comparisons and the four arithmetic functions listed below.
+// NOTE(review): the name suggests these are the binary functions whose two
+// arguments must share a type -- confirm against the callers.
+inline bool IsSameTypesBinary(const std::string& function) {
+  if (Comparison::Get(function)) return true;
+
+  // never modified after construction, so keep it const
+  static const std::unordered_set<std::string> set{"add", "subtract", "multiply",
+                                                   "divide"};
+
+  return set.find(function) != set.end();
+}
+
+// Return the SetLookupOptions of `call`, or nullptr if it is not a set lookup.
+inline const compute::SetLookupOptions* GetSetLookupOptions(
+    const Expression::Call& call) {
+  if (!IsSetLookup(call.function_name)) return nullptr;
+  return checked_cast<const compute::SetLookupOptions*>(call.options.get());
+}
+
+// Return the StructOptions of `call`, or nullptr if it is not a call to "struct".
+inline const compute::StructOptions* GetStructOptions(const Expression::Call& call) {
+  if (call.function_name != "struct") return nullptr;
+  return checked_cast<const compute::StructOptions*>(call.options.get());
+}
+
+// Return the StrptimeOptions of `call`, or nullptr if it is not a call to
+// "strptime".
+inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call& call) {
+  if (call.function_name != "strptime") return nullptr;
+  return checked_cast<const compute::StrptimeOptions*>(call.options.get());
+}
+
+// Return the value type of a dictionary type, or nullptr if `type` is null or
+// is not a dictionary type.
+inline std::shared_ptr<DataType> GetDictionaryValueType(
+    const std::shared_ptr<DataType>& type) {
+  if (type && type->id() == Type::DICTIONARY) {
+    return checked_cast<const DictionaryType&>(*type).value_type();
+  }
+  return nullptr;
+}
+
+// If `descr` has a dictionary type, replace that type with its value type.
+// Always returns OK.
+inline Status EnsureNotDictionary(ValueDescr* descr) {
+  if (auto value_type = GetDictionaryValueType(descr->type)) {
+    descr->type = std::move(value_type);
+  }
+  return Status::OK();
+}
+
+// If `datum` has a dictionary type, replace it with the result of casting it
+// to the dictionary's value type. Fails only if that cast fails.
+inline Status EnsureNotDictionary(Datum* datum) {
+  if (datum->type()->id() == Type::DICTIONARY) {
+    const auto& type = checked_cast<const DictionaryType&>(*datum->type()).value_type();
+    ARROW_ASSIGN_OR_RAISE(*datum, compute::Cast(*datum, type));
+  }
+  return Status::OK();
+}
+
+// For set lookup calls, replace the options with a copy whose value_set has
+// been passed through EnsureNotDictionary(Datum*). Other calls are left untouched.
+inline Status EnsureNotDictionary(Expression::Call* call) {
+  if (auto options = GetSetLookupOptions(*call)) {
+    // copy rather than mutate: only this call's options pointer is replaced
+    auto new_options = *options;
+    RETURN_NOT_OK(EnsureNotDictionary(&new_options.value_set));
+    call->options.reset(new compute::SetLookupOptions(std::move(new_options)));
+  }
+  return Status::OK();
+}
+
+// Represent the options of `call` as a StructScalar.
+//
+// - null options are represented by a null pointer
+// - set lookup options become {value_set, skip_nulls}; only array value sets
+//   are supported
+// - cast options become {to_type_holder, allow_*...}, where to_type_holder is a
+//   null scalar whose only purpose is to carry the target type
+//
+// Options of any other function yield NotImplemented.
+inline Result<std::shared_ptr<StructScalar>> FunctionOptionsToStructScalar(
+    const Expression::Call& call) {
+  if (call.options == nullptr) {
+    return nullptr;
+  }
+
+  // Assemble a StructScalar whose field types are taken from the given values.
+  auto Finish = [](ScalarVector values, std::vector<std::string> names) {
+    FieldVector fields(names.size());
+    for (size_t i = 0; i < fields.size(); ++i) {
+      fields[i] = field(std::move(names[i]), values[i]->type);
+    }
+    return std::make_shared<StructScalar>(std::move(values), struct_(std::move(fields)));
+  };
+
+  if (auto options = GetSetLookupOptions(call)) {
+    if (!options->value_set.is_array()) {
+      return Status::NotImplemented("chunked value_set");
+    }
+    return Finish(
+        {
+            std::make_shared<ListScalar>(options->value_set.make_array()),
+            MakeScalar(options->skip_nulls),
+        },
+        {"value_set", "skip_nulls"});
+  }
+
+  if (call.function_name == "cast") {
+    auto options = checked_cast<const compute::CastOptions*>(call.options.get());
+    return Finish(
+        {
+            MakeNullScalar(options->to_type),
+            MakeScalar(options->allow_int_overflow),
+            MakeScalar(options->allow_time_truncate),
+            MakeScalar(options->allow_time_overflow),
+            MakeScalar(options->allow_decimal_truncate),
+            MakeScalar(options->allow_float_truncate),
+            MakeScalar(options->allow_invalid_utf8),
+        },
+        {
+            "to_type_holder",
+            "allow_int_overflow",
+            "allow_time_truncate",
+            "allow_time_overflow",
+            "allow_decimal_truncate",
+            "allow_float_truncate",
+            "allow_invalid_utf8",
+        });
+  }
+
+  return Status::NotImplemented("conversion of options for ", call.function_name);
+}
+
+// Inverse of FunctionOptionsToStructScalar: rebuild the options of `call` (whose
+// function_name must already be set) from their StructScalar representation.
+//
+// A null `repr` resets the options to null. Functions other than the set
+// lookups and "cast" yield NotImplemented.
+inline Status FunctionOptionsFromStructScalar(const StructScalar* repr,
+                                              Expression::Call* call) {
+  if (repr == nullptr) {
+    call->options = nullptr;
+    return Status::OK();
+  }
+
+  if (IsSetLookup(call->function_name)) {
+    ARROW_ASSIGN_OR_RAISE(auto value_set, repr->field("value_set"));
+    ARROW_ASSIGN_OR_RAISE(auto skip_nulls, repr->field("skip_nulls"));
+    call->options = std::make_shared<compute::SetLookupOptions>(
+        checked_cast<const ListScalar&>(*value_set).value,
+        checked_cast<const BooleanScalar&>(*skip_nulls).value);
+    return Status::OK();
+  }
+
+  if (call->function_name == "cast") {
+    auto options = std::make_shared<compute::CastOptions>();
+    ARROW_ASSIGN_OR_RAISE(auto to_type_holder, repr->field("to_type_holder"));
+    options->to_type = to_type_holder->type;
+
+    // The flags are read by position, not by name: this relies on `repr` having
+    // exactly the field order written by FunctionOptionsToStructScalar
+    // (index 0 is to_type_holder, the six flags follow in the order below).
+    int i = 1;
+    for (bool* opt : {
+             &options->allow_int_overflow,
+             &options->allow_time_truncate,
+             &options->allow_time_overflow,
+             &options->allow_decimal_truncate,
+             &options->allow_float_truncate,
+             &options->allow_invalid_utf8,
+         }) {
+      *opt = checked_cast<const BooleanScalar&>(*repr->value[i++]).value;
+    }
+
+    call->options = std::move(options);
+    return Status::OK();
+  }
+
+  return Status::NotImplemented("conversion of options for ", call->function_name);
+}
+
+// Flattens a tree of nested calls to one function (for example and(and(a, b), c))
+// into the list of its leaf arguments.
+//
+// - exprs:  the root expression followed by every nested call to the same
+//           function which was expanded
+// - fringe: the arguments which are not themselves calls to that function, in
+//           left to right order
+// - was_left_folded: false if any expanded call was found anywhere other than
+//           the first position of the fringe
+//
+// Every expanded call is expected to have null options (checked with DCHECK).
+struct FlattenedAssociativeChain {
+  bool was_left_folded = true;
+  std::vector<Expression> exprs, fringe;
+
+  explicit FlattenedAssociativeChain(Expression expr) : exprs{std::move(expr)} {
+    auto call = CallNotNull(exprs.back());
+    fringe = call->arguments;
+
+    auto it = fringe.begin();
+
+    while (it != fringe.end()) {
+      auto sub_call = it->call();
+      if (!sub_call || sub_call->function_name != call->function_name) {
+        ++it;
+        continue;
+      }
+
+      if (it != fringe.begin()) {
+        was_left_folded = false;
+      }
+
+      exprs.push_back(std::move(*it));
+      it = fringe.erase(it);
+
+      // insert() may reallocate, so remember the position as an index and
+      // rebuild the iterator afterwards
+      auto index = it - fringe.begin();
+      fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end());
+      it = fringe.begin() + index;
+      // NB: no increment so we hit sub_call's first argument next iteration
+    }
+
+    DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression& expr) {
+      return CallNotNull(expr)->options == nullptr;
+    }));
+  }
+};
+
+// Look up the compute function invoked by `call`. Everything except "cast" is
+// found by name in exec_context's function registry; the cast function is
+// selected by the target type stored in the call's CastOptions.
+inline Result<std::shared_ptr<compute::Function>> GetFunction(
+    const Expression::Call& call, compute::ExecContext* exec_context) {
+  if (call.function_name != "cast") {
+    return exec_context->func_registry()->GetFunction(call.function_name);
+  }
+  // XXX this special case is strange; why not make "cast" a ScalarFunction?
+  const auto& to_type = checked_cast<const compute::CastOptions&>(*call.options).to_type;
+  return compute::GetCastFunction(to_type);
+}
+
+// Recursively rewrite an expression.
+//
+// `pre` is applied to each expression before its arguments are visited; its
+// result may be an Expression or a Result<Expression>. If that result is not a
+// call it is returned as is. Otherwise each argument is modified recursively,
+// then `post_call` is invoked with
+// - the call rebuilt from the modified arguments and a pointer to the
+//   unmodified call, if at least one argument is no longer Identical, or
+// - the call itself and nullptr, if no argument changed.
+template <typename PreVisit, typename PostVisitCall>
+Result<Expression> Modify(Expression expr, const PreVisit& pre,
+                          const PostVisitCall& post_call) {
+  ARROW_ASSIGN_OR_RAISE(expr, Result<Expression>(pre(std::move(expr))));
+
+  auto call = expr.call();
+  if (!call) return expr;
+
+  bool at_least_one_modified = false;
+  auto modified_call = *call;
+  auto modified_argument = modified_call.arguments.begin();
+
+  for (const auto& argument : call->arguments) {
+    ARROW_ASSIGN_OR_RAISE(*modified_argument, Modify(argument, pre, post_call));
+
+    if (!Identical(*modified_argument, argument)) {
+      at_least_one_modified = true;
+    }
+    ++modified_argument;
+  }
+
+  if (at_least_one_modified) {
+    // reconstruct the call expression with the modified arguments
+    auto modified_expr = Expression(std::move(modified_call));
+
+    return post_call(std::move(modified_expr), &expr);
+  }
+
+  return post_call(std::move(expr), nullptr);
+}
+
+}  // namespace dataset
+}  // namespace arrow
diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc
new file mode 100644
index 00000000000..e245b0d7093
--- /dev/null
+++ b/cpp/src/arrow/dataset/expression_test.cc
@@ -0,0 +1,1035 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/dataset/expression.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/registry.h"
+#include "arrow/dataset/expression_internal.h"
+#include "arrow/dataset/test_util.h"
+#include "arrow/testing/gtest_util.h"
+
+using testing::UnorderedElementsAreArray;
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace dataset {
+
+#define EXPECT_OK ARROW_EXPECT_OK
+
+// Shorthand for a call to "cast" with safe cast options targeting `to_type`.
+Expression cast(Expression argument, std::shared_ptr<DataType> to_type) {
+  return call("cast", {std::move(argument)},
+              compute::CastOptions::Safe(std::move(to_type)));
+}
+
+TEST(Expression, ToString) {
+  EXPECT_EQ(field_ref("alpha").ToString(), "alpha");
+
+  EXPECT_EQ(literal(3).ToString(), "3");
+  EXPECT_EQ(literal("a").ToString(), "\"a\"");
+  EXPECT_EQ(literal("a\nb").ToString(), "\"a\\nb\"");
+  EXPECT_EQ(literal(std::make_shared<NullScalar>()).ToString(), "null");
+  EXPECT_EQ(literal(std::make_shared<BinaryScalar>(Buffer::FromString("az"))).ToString(),
+            "\"617A\"");
+
+  auto ts = *MakeScalar("1990-10-23 10:23:33")->CastTo(timestamp(TimeUnit::NANO));
+  EXPECT_EQ(literal(ts).ToString(), "656677413000000000");
+
+  EXPECT_EQ(call("add", {literal(3), field_ref("beta")}).ToString(), "add(3, beta)");
+
+  auto in_12 = call("index_in", {field_ref("beta")},
+                    compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")});
+
+  // NOTE(review): runs of spaces inside this expected string may have been
+  // collapsed in this copy of the patch -- confirm against the array printer.
+  EXPECT_EQ(in_12.ToString(), "index_in(beta, value_set=[\n 1,\n 2\n])");
+
+  EXPECT_EQ(and_(field_ref("a"), field_ref("b")).ToString(), "(a and b)");
+  EXPECT_EQ(or_(field_ref("a"), field_ref("b")).ToString(), "(a or b)");
+  EXPECT_EQ(not_(field_ref("a")).ToString(), "invert(a)");
+
+  EXPECT_EQ(cast(field_ref("a"), int32()).ToString(), "cast(a, to_type=int32)");
+  // NOTE(review): the expected text after "to_type=" looks truncated in this
+  // copy of the patch -- confirm the placeholder printed for a null target type.
+  EXPECT_EQ(cast(field_ref("a"), nullptr).ToString(),
+            "cast(a, to_type=)");
+
+  // options of a function which ToString() knows nothing about
+  struct WidgetifyOptions : compute::FunctionOptions {
+    bool really;
+  };
+
+  // NB: corrupted for nullary functions but we don't have any of those
+  EXPECT_EQ(call("widgetify", {}).ToString(), "widgetif)");
+  EXPECT_EQ(
+      call("widgetify", {literal(1)}, std::make_shared<WidgetifyOptions>()).ToString(),
+      "widgetify(1, {NON-REPRESENTABLE OPTIONS})");
+
+  EXPECT_EQ(equal(field_ref("a"), literal(1)).ToString(), "(a == 1)");
+  EXPECT_EQ(less(field_ref("a"), literal(2)).ToString(), "(a < 2)");
+  EXPECT_EQ(greater(field_ref("a"), literal(3)).ToString(), "(a > 3)");
+  EXPECT_EQ(not_equal(field_ref("a"), literal("a")).ToString(), "(a != \"a\")");
+  EXPECT_EQ(less_equal(field_ref("a"), literal("b")).ToString(), "(a <= \"b\")");
+  EXPECT_EQ(greater_equal(field_ref("a"), literal("c")).ToString(), "(a >= \"c\")");
+
+  EXPECT_EQ(project(
+                {
+                    field_ref("a"),
+                    field_ref("a"),
+                    literal(3),
+                    in_12,
+                },
+                {
+                    "a",
+                    "renamed_a",
+                    "three",
+                    "b",
+                })
+                .ToString(),
+            "{a=a, renamed_a=a, three=3, b=" + in_12.ToString() + "}");
+}
+
+TEST(Expression, Equality) {
+  EXPECT_EQ(literal(1), literal(1));
+  EXPECT_NE(literal(1), literal(2));
+
+  EXPECT_EQ(field_ref("a"), field_ref("a"));
+  EXPECT_NE(field_ref("a"), field_ref("b"));
+  EXPECT_NE(field_ref("a"), literal(2));
+
+  EXPECT_EQ(call("add", {literal(3), field_ref("a")}),
+            call("add", {literal(3), field_ref("a")}));
+  EXPECT_NE(call("add", {literal(3), field_ref("a")}),
+            call("add", {literal(2), field_ref("a")}));
+  // argument order is significant
+  EXPECT_NE(call("add", {field_ref("a"), literal(3)}),
+            call("add", {literal(3), field_ref("a")}));
+
+  auto in_123 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]")};
+  EXPECT_EQ(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}),
+            call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}));
+
+  // options participate in equality
+  auto in_12 = compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2]")};
+  EXPECT_NE(call("add", {literal(3), call("index_in", {field_ref("beta")}, in_12)}),
+            call("add", {literal(3), call("index_in", {field_ref("beta")}, in_123)}));
+
+  EXPECT_EQ(cast(field_ref("a"), int32()), cast(field_ref("a"), int32()));
+  EXPECT_NE(cast(field_ref("a"), int32()), cast(field_ref("a"), int64()));
+  EXPECT_NE(cast(field_ref("a"), int32()),
+            call("cast", {field_ref("a")}, compute::CastOptions::Unsafe(int32())));
+}
+
+TEST(Expression, Hash) {
+  std::unordered_set<Expression, Expression::Hash> set;
+
+  EXPECT_TRUE(set.emplace(field_ref("alpha")).second);
+  EXPECT_TRUE(set.emplace(field_ref("beta")).second);
+  EXPECT_FALSE(set.emplace(field_ref("beta")).second) << "already inserted";
+  EXPECT_TRUE(set.emplace(literal(1)).second);
+  EXPECT_FALSE(set.emplace(literal(1)).second) << "already inserted";
+  EXPECT_TRUE(set.emplace(literal(3)).second);
+
+  // NB: no validation on construction; we couldn't execute
+  //     add with zero arguments
+  EXPECT_TRUE(set.emplace(call("add", {})).second);
+  EXPECT_FALSE(set.emplace(call("add", {})).second) << "already inserted";
+
+  // NB: unbound expressions don't check for availability in any registry
+  EXPECT_TRUE(set.emplace(call("widgetify", {})).second);
+
+  // two field_refs, two literals and two calls were inserted
+  EXPECT_EQ(set.size(), 6);
+}
+
+TEST(Expression, IsScalarExpression) {
+  EXPECT_TRUE(literal(true).IsScalarExpression());
+
+  auto arr = ArrayFromJSON(int8(), "[]");
+  EXPECT_FALSE(literal(arr).IsScalarExpression());
+
+  EXPECT_TRUE(field_ref("a").IsScalarExpression());
+
+  EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsScalarExpression());
+
+  EXPECT_FALSE(equal(field_ref("a"), literal(arr)).IsScalarExpression());
+
+  EXPECT_TRUE(call("is_in", {field_ref("a")}, compute::SetLookupOptions{arr, true})
+                  .IsScalarExpression());
+
+  // non scalar function
+  EXPECT_FALSE(call("take", {field_ref("a"), literal(arr)}).IsScalarExpression());
+}
+
+TEST(Expression, IsSatisfiable) {
+  EXPECT_TRUE(literal(true).IsSatisfiable());
+  EXPECT_FALSE(literal(false).IsSatisfiable());
+
+  auto null = std::make_shared<BooleanScalar>();
+  EXPECT_FALSE(literal(null).IsSatisfiable());
+
+  EXPECT_TRUE(field_ref("a").IsSatisfiable());
+
+  EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsSatisfiable());
+
+  // NB: no constant folding here
+  EXPECT_TRUE(equal(literal(0), literal(1)).IsSatisfiable());
+
+  // When a top level conjunction contains an Expression which is certain to evaluate to
+  // null, it can only evaluate to null or false.
+  auto null_or_false = and_(literal(null), field_ref("a"));
+  // This may appear in satisfiable filters if coalesced
+  EXPECT_TRUE(call("is_null", {null_or_false}).IsSatisfiable());
+  // ... but at the top level it is not satisfiable.
+  // This special case arises when (for example) an absent column has made
+  // one member of the conjunction always-null. This is fairly common and
+  // would be a worthwhile optimization to support.
+  // EXPECT_FALSE(null_or_false.IsSatisfiable());
+}
+
+TEST(Expression, FieldsInExpression) {
+  auto ExpectFieldsAre = [](Expression expr, std::vector<FieldRef> expected) {
+    EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected));
+  };
+
+  ExpectFieldsAre(literal(true), {});
+
+  ExpectFieldsAre(field_ref("a"), {"a"});
+
+  ExpectFieldsAre(equal(field_ref("a"), literal(1)), {"a"});
+
+  ExpectFieldsAre(equal(field_ref("a"), field_ref("b")), {"a", "b"});
+
+  // repeated references are reported once per occurrence
+  ExpectFieldsAre(
+      or_(equal(field_ref("a"), literal(1)), equal(field_ref("a"), literal(2))),
+      {"a", "a"});
+
+  ExpectFieldsAre(
+      or_(equal(field_ref("a"), literal(1)), equal(field_ref("b"), literal(2))),
+      {"a", "b"});
+
+  ExpectFieldsAre(or_(and_(not_(equal(field_ref("a"), literal(1))),
+                           equal(field_ref("b"), literal(2))),
+                      not_(less(field_ref("c"), literal(3)))),
+                  {"a", "b", "c"});
+}
+
+TEST(Expression, BindLiteral) {
+  for (Datum dat : {
+           Datum(3),
+           Datum(3.5),
+           Datum(ArrayFromJSON(int32(), "[1,2,3]")),
+       }) {
+    // literals are always considered bound
+    auto expr = literal(dat);
+    EXPECT_EQ(expr.descr(), dat.descr());
+    EXPECT_TRUE(expr.IsBound());
+  }
+}
+
+// Bind both `expr` and `expected` to kBoringSchema and expect the results to be
+// equal. If `expected` is not provided, `expr` is expected to bind to (the bound
+// form of) itself. The bound expression is stored in `bound_out` if provided.
+void ExpectBindsTo(Expression expr, util::optional<Expression> expected,
+                   Expression* bound_out = nullptr) {
+  if (!expected) {
+    expected = expr;
+  }
+
+  ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema));
+  EXPECT_TRUE(bound.IsBound());
+
+  ASSERT_OK_AND_ASSIGN(expected, expected->Bind(*kBoringSchema));
+  EXPECT_EQ(bound, *expected) << " unbound: " << expr.ToString();
+
+  if (bound_out) {
+    *bound_out = bound;
+  }
+}
+
+const auto no_change = util::nullopt;
+
+TEST(Expression, BindFieldRef) {
+  // an unbound field_ref does not have the output ValueDescr set
+  auto expr = field_ref("alpha");
+  EXPECT_EQ(expr.descr(), ValueDescr{});
+  EXPECT_FALSE(expr.IsBound());
+
+  ExpectBindsTo(field_ref("i32"), no_change, &expr);
+  EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+
+  // if the field is not found, a null scalar will be emitted
+  ExpectBindsTo(field_ref("no such field"), no_change, &expr);
+  EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null()));
+
+  // referencing a field by name is not supported if that name is not unique
+  // in the input schema
+  ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema(
+                             {field("alpha", int32()), field("alpha", float32())})));
+
+  // referencing nested fields is supported
+  ASSERT_OK_AND_ASSIGN(expr,
+                       field_ref(FieldRef("a", "b"))
+                           .Bind(Schema({field("a", struct_({field("b", int32())}))})));
+  EXPECT_TRUE(expr.IsBound());
+  EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+}
+
+TEST(Expression, BindCall) {
+  auto expr = call("add", {field_ref("i32"), field_ref("i32_req")});
+  EXPECT_FALSE(expr.IsBound());
+
+  ExpectBindsTo(expr, no_change, &expr);
+  EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+
+  // literal(3) may be safely cast to float32, so binding this expr casts that literal:
+  ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}),
+                call("add", {field_ref("f32"), literal(3.0F)}));
+
+  // literal(3.5) may not be safely cast to int32, so binding this expr fails:
+  ASSERT_RAISES(Invalid,
+                call("add", {field_ref("i32"), literal(3.5)}).Bind(*kBoringSchema));
+}
+
+TEST(Expression, BindWithImplicitCasts) {
+  for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) {
+    // cast arguments to same type
+    ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")),
+                  cmp(field_ref("i32"), cast(field_ref("i64"), int32())));
+    // NB: RHS is cast unless LHS is scalar.
+
+    // cast dictionary to value type
+    ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")),
+                  cmp(cast(field_ref("dict_str"), utf8()), field_ref("str")));
+  }
+
+  // scalars are directly cast when possible:
+  auto ts_scalar = MakeScalar("1990-10-23")->CastTo(timestamp(TimeUnit::NANO));
+  ExpectBindsTo(equal(field_ref("ts_ns"), literal("1990-10-23")),
+                equal(field_ref("ts_ns"), literal(*ts_scalar)));
+
+  // cast value_set to argument type
+  auto Opts = [](std::shared_ptr<DataType> type) {
+    return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")};
+  };
+  ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(binary())),
+                call("is_in", {field_ref("str")}, Opts(utf8())));
+
+  // dictionary decode set then cast to argument type
+  ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))),
+                call("is_in", {field_ref("str")}, Opts(utf8())));
+}
+
+TEST(Expression, BindNestedCall) {
+  auto expr =
+      call("add", {field_ref("a"),
+                   call("subtract", {call("multiply", {field_ref("b"), field_ref("c")}),
+                                     field_ref("d")})});
+  EXPECT_FALSE(expr.IsBound());
+
+  ASSERT_OK_AND_ASSIGN(expr,
+                       expr.Bind(Schema({field("a", int32()), field("b", int32()),
+                                         field("c", int32()), field("d", int32())})));
+  EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+  EXPECT_TRUE(expr.IsBound());
+}
+
+TEST(Expression, ExecuteFieldRef) {
+  // bind a field_ref to `in`, execute it and compare the result to `expected`
+  auto AssertRefIs = [](FieldRef ref, Datum in, Datum expected) {
+    auto expr = field_ref(ref);
+
+    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr()));
+    ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in));
+
+    AssertDatumsEqual(actual, expected, /*verbose=*/true);
+  };
+
+  AssertRefIs("a", ArrayFromJSON(struct_({field("a", float64())}), R"([
+    {"a": 6.125},
+    {"a": 0.0},
+    {"a": -1}
+  ])"),
+              ArrayFromJSON(float64(), R"([6.125, 0.0, -1])"));
+
+  // more nested:
+  AssertRefIs(FieldRef{"a", "a"},
+              ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([
+    {"a": {"a": 6.125}},
+    {"a": {"a": 0.0}},
+    {"a": {"a": -1}}
+  ])"),
+              ArrayFromJSON(float64(), R"([6.125, 0.0, -1])"));
+
+  // absent fields are resolved as a null scalar:
+  AssertRefIs(FieldRef{"b"}, ArrayFromJSON(struct_({field("a", float64())}), R"([
+    {"a": 6.125},
+    {"a": 0.0},
+    {"a": -1}
+  ])"),
+              MakeNullScalar(null()));
+}
+
+// Reference implementation of expression execution: evaluate the arguments of
+// each call recursively, then execute the function found by GetFunction() on
+// them. Also expects the kernel stored in each bound call to be the one
+// selected by DispatchExact() for the argument descriptors.
+Result<Datum> NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) {
+  auto call = expr.call();
+  if (call == nullptr) {
+    // already tested execution of field_ref, execution of literal is trivial
+    return ExecuteScalarExpression(expr, input);
+  }
+
+  std::vector<Datum> arguments(call->arguments.size());
+  for (size_t i = 0; i < arguments.size(); ++i) {
+    ARROW_ASSIGN_OR_RAISE(arguments[i],
+                          NaiveExecuteScalarExpression(call->arguments[i], input));
+  }
+
+  compute::ExecContext exec_context;
+  ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(*call, &exec_context));
+
+  auto descrs = GetDescriptors(call->arguments);
+  ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs));
+
+  EXPECT_EQ(call->kernel, expected_kernel);
+  return function->Execute(arguments, call->options.get(), &exec_context);
+}
+
+// Bind `expr` to `in` (to its descr if it is a value, otherwise to the schema of
+// its record batch), execute it and expect the same result as the naive
+// executor. The result is stored in `actual_out` if provided.
+void AssertExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) {
+  if (in.is_value()) {
+    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr()));
+  } else {
+    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.record_batch()->schema()));
+  }
+
+  ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in));
+
+  ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in));
+
+  AssertDatumsEqual(actual, expected, /*verbose=*/true);
+
+  if (actual_out) {
+    *actual_out = actual;
+  }
+}
+
+TEST(Expression, ExecuteCall) {
+  AssertExecute(call("add", {field_ref("a"), literal(3.5)}),
+                ArrayFromJSON(struct_({field("a", float64())}), R"([
+    {"a": 6.125},
+    {"a": 0.0},
+    {"a": -1}
+  ])"));
+
+  AssertExecute(
+      call("add", {field_ref("a"), call("subtract", {literal(3.5), field_ref("b")})}),
+      ArrayFromJSON(struct_({field("a", float64()), field("b", float64())}), R"([
+    {"a": 6.125, "b": 3.375},
+    {"a": 0.0,   "b": 1},
+    {"a": -1,    "b": 4.75}
+  ])"));
+
+  AssertExecute(call("strptime", {field_ref("a")},
+                     compute::StrptimeOptions("%m/%d/%Y", TimeUnit::MICRO)),
+                ArrayFromJSON(struct_({field("a", utf8())}), R"([
+    {"a": "5/1/2020"},
+    {"a": null},
+    {"a": "12/11/1900"}
+  ])"));
+
+  AssertExecute(project({call("add", {field_ref("a"), literal(3.5)})}, {"a + 3.5"}),
+                ArrayFromJSON(struct_({field("a", float64())}), R"([
+    {"a": 6.125},
+    {"a": 0.0},
+    {"a": -1}
+  ])"));
+}
+
+TEST(Expression, ExecuteDictionaryTransparent) {
+  // a dictionary encoded column may be compared to a plain column
+  AssertExecute(
+      equal(field_ref("a"), field_ref("b")),
+      ArrayFromJSON(
+          struct_({field("a", dictionary(int32(), utf8())), field("b", utf8())}), R"([
+    {"a": "hi", "b": "hi"},
+    {"a": "",   "b": ""},
+    {"a": "hi", "b": "hello"}
+  ])"));
+
+  // a dictionary encoded value_set may be used to look up a plain column
+  Datum dict_set = ArrayFromJSON(dictionary(int32(), utf8()), R"(["a"])");
+  AssertExecute(call("is_in", {field_ref("a")},
+                     compute::SetLookupOptions{dict_set,
+                                               /*skip_nulls=*/false}),
+                ArrayFromJSON(struct_({field("a", utf8())}), R"([
+    {"a": "a"},
+    {"a": "good"},
+    {"a": null}
+  ])"));
+}
+
+// A pass which did not change an expression must return the very same
+// expression, not merely an equal one.
+void ExpectIdenticalIfUnchanged(Expression modified, Expression original) {
+  if (modified == original) {
+    // no change -> must be identical
+    EXPECT_TRUE(Identical(modified, original)) << " " << original.ToString();
+  }
+}
+
+// Bind `expr` and `expected` to kBoringSchema, then expect FoldConstants(expr)
+// to equal `expected`.
+struct {
+  void operator()(Expression expr, Expression expected) {
+    ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
+    ASSERT_OK_AND_ASSIGN(expected, expected.Bind(*kBoringSchema));
+
+    ASSERT_OK_AND_ASSIGN(auto folded, FoldConstants(expr));
+
+    EXPECT_EQ(folded, expected);
+    ExpectIdenticalIfUnchanged(folded, expr);
+  }
+} ExpectFoldsTo;
+
+TEST(Expression, FoldConstants) {
+  // literals are unchanged
+  ExpectFoldsTo(literal(3), literal(3));
+
+  // field_refs are unchanged
+  ExpectFoldsTo(field_ref("i32"), field_ref("i32"));
+
+  // call against literals (3 + 2 == 5)
+  ExpectFoldsTo(call("add", {literal(3), literal(2)}), literal(5));
+
+  // call against literal and field_ref
+  ExpectFoldsTo(call("add", {literal(3), field_ref("i32")}),
+                call("add", {literal(3), field_ref("i32")}));
+
+  // nested call against literals ((8 - (2 * 3)) + 2 == 4)
+  ExpectFoldsTo(call("add",
+                     {
+                         call("subtract",
+                              {
+                                  literal(8),
+                                  call("multiply", {literal(2), literal(3)}),
+                              }),
+                         literal(2),
+                     }),
+                literal(4));
+
+  // nested call against literals with one field_ref
+  // (i32 - (2 * 3)) + 2 == (i32 - 6) + 2
+  // NB this could be improved further by using associativity of addition; another pass
+  ExpectFoldsTo(call("add",
+                     {
+                         call("subtract",
+                              {
+                                  field_ref("i32"),
+                                  call("multiply", {literal(2), literal(3)}),
+                              }),
+                         literal(2),
+                     }),
+                call("add", {
+                                call("subtract",
+                                     {
+                                         field_ref("i32"),
+                                         literal(6),
+                                     }),
+                                literal(2),
+                            }));
+
+  compute::SetLookupOptions in_123(ArrayFromJSON(int32(), "[1,2,3]"));
+
+  ExpectFoldsTo(call("is_in", {literal(2)}, in_123), literal(true));
+
+  ExpectFoldsTo(
+      call("is_in",
+           {call("add", {field_ref("i32"), call("multiply", {literal(2), literal(3)})})},
+           in_123),
+      call("is_in", {call("add", {field_ref("i32"), literal(6)})}, in_123));
+}
+
+TEST(Expression, FoldConstantsBoolean) {
+  // test and_kleene/or_kleene-specific optimizations
+  auto one = literal(1);
+  auto two = literal(2);
+  auto whatever = equal(call("add", {one, field_ref("i32")}), two);
+
+  auto true_ = literal(true);
+  auto false_ = literal(false);
+
+  ExpectFoldsTo(and_(false_, whatever), false_);
+  ExpectFoldsTo(and_(true_, whatever), whatever);
+  ExpectFoldsTo(and_(whatever, whatever), whatever);
+
+  ExpectFoldsTo(or_(true_, whatever), true_);
+  ExpectFoldsTo(or_(false_, whatever), whatever);
+  ExpectFoldsTo(or_(whatever, whatever), whatever);
+}
+
+TEST(Expression, ExtractKnownFieldValues) { + struct { + void operator()(Expression guarantee, + std::unordered_map expected) { + ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); + EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) + << " guarantee: " << guarantee.ToString(); + } + } ExpectKnown; + + ExpectKnown(equal(field_ref("i32"), literal(3)), {{"i32", Datum(3)}}); + + ExpectKnown(greater(field_ref("i32"), literal(3)), {}); + + // FIXME known null should be expressed with is_null rather than equality + auto null_int32 = std::make_shared(); + ExpectKnown(equal(field_ref("i32"), literal(null_int32)), {{"i32", Datum(null_int32)}}); + + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(1.5F))}), + {{"i32", Datum(3)}, {"f32", Datum(1.5F)}}); + + // NB: guarantees are *not* automatically canonicalized + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(literal(1.5F), field_ref("f32"))}), + {{"i32", Datum(3)}}); + + // NB: guarantees are *not* automatically simplified + // (the below could be constant folded to a usable guarantee) + ExpectKnown(or_({equal(field_ref("i32"), literal(3)), literal(false)}), {}); + + // NB: guarantees are unbound; applying them may require casts + ExpectKnown(equal(field_ref("i32"), literal("1234324")), {{"i32", Datum("1234324")}}); + + ExpectKnown( + and_({equal(field_ref("i32"), literal(3)), equal(field_ref("f32"), literal(2.F)), + equal(field_ref("i32_req"), literal(1))}), + {{"i32", Datum(3)}, {"f32", Datum(2.F)}, {"i32_req", Datum(1)}}); + + ExpectKnown( + and_(or_(equal(field_ref("i32"), literal(3)), equal(field_ref("i32"), literal(4))), + equal(field_ref("f32"), literal(2.F))), + {{"f32", Datum(2.F)}}); + + ExpectKnown(and_({equal(field_ref("i32"), literal(3)), + equal(field_ref("f32"), field_ref("f32_req")), + equal(field_ref("i32_req"), literal(1))}), + {{"i32", Datum(3)}, {"i32_req", Datum(1)}}); +} + +TEST(Expression, 
ReplaceFieldsWithKnownValues) { + auto ExpectReplacesTo = + [](Expression expr, + std::unordered_map known_values, + Expression unbound_expected) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto replaced, + ReplaceFieldsWithKnownValues(known_values, expr)); + + EXPECT_EQ(replaced, expected); + ExpectIdenticalIfUnchanged(replaced, expr); + }; + + std::unordered_map i32_is_3{{"i32", Datum(3)}}; + + ExpectReplacesTo(literal(1), i32_is_3, literal(1)); + + ExpectReplacesTo(field_ref("i32"), i32_is_3, literal(3)); + + // NB: known_values will be cast + ExpectReplacesTo(field_ref("i32"), {{"i32", Datum("3")}}, literal(3)); + + ExpectReplacesTo(field_ref("b"), i32_is_3, field_ref("b")); + + ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3, + equal(literal(3), literal(1))); + + ExpectReplacesTo(call("add", + { + call("subtract", + { + field_ref("i32"), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + }), + i32_is_3, + call("add", { + call("subtract", + { + literal(3), + call("multiply", {literal(2), literal(3)}), + }), + literal(2), + })); +} + +struct { + void operator()(Expression expr, Expression unbound_expected) const { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto actual, Canonicalize(bound)); + + EXPECT_EQ(actual, expected); + ExpectIdenticalIfUnchanged(actual, bound); + } +} ExpectCanonicalizesTo; + +TEST(Expression, CanonicalizeTrivial) { + ExpectCanonicalizesTo(literal(1), literal(1)); + + ExpectCanonicalizesTo(field_ref("i32"), field_ref("i32")); + + ExpectCanonicalizesTo(equal(field_ref("i32"), field_ref("i32_req")), + equal(field_ref("i32"), field_ref("i32_req"))); +} + +TEST(Expression, CanonicalizeAnd) { + // some aliases for brevity: + auto true_ = literal(true); + auto null_ = 
literal(std::make_shared()); + + auto b = field_ref("bool"); + auto c = equal(literal(1), literal(2)); + + // no change possible: + ExpectCanonicalizesTo(and_(b, c), and_(b, c)); + + // literals are placed innermost + ExpectCanonicalizesTo(and_(b, true_), and_(true_, b)); + ExpectCanonicalizesTo(and_(true_, b), and_(true_, b)); + + ExpectCanonicalizesTo(and_(b, and_(true_, c)), and_(and_(true_, b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, true_), c)), + and_(and_(and_(true_, true_), b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), c)), + and_(and_(and_(null_, true_), b), c)); + ExpectCanonicalizesTo(and_(b, and_(and_(true_, null_), and_(c, null_))), + and_(and_(and_(and_(null_, null_), true_), b), c)); + + // catches and_kleene even when it's a subexpression + ExpectCanonicalizesTo(call("is_valid", {and_(b, true_)}), + call("is_valid", {and_(true_, b)})); +} + +TEST(Expression, CanonicalizeComparison) { + ExpectCanonicalizesTo(equal(literal(1), field_ref("i32")), + equal(field_ref("i32"), literal(1))); + + ExpectCanonicalizesTo(equal(field_ref("i32"), literal(1)), + equal(field_ref("i32"), literal(1))); + + ExpectCanonicalizesTo(less(literal(1), field_ref("i32")), + greater(field_ref("i32"), literal(1))); + + ExpectCanonicalizesTo(less(field_ref("i32"), literal(1)), + less(field_ref("i32"), literal(1))); +} + +struct Simplify { + Expression expr; + + struct Expectable { + Expression expr, guarantee; + + void Expect(Expression unbound_expected) { + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); + + ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); + EXPECT_EQ(simplified, expected) << " original: " << expr.ToString() << "\n" + << " guarantee: " << guarantee.ToString() << "\n" + << (simplified == bound ? 
" (no change)\n" : ""); + + ExpectIdenticalIfUnchanged(simplified, bound); + } + void ExpectUnchanged() { Expect(expr); } + void Expect(bool constant) { Expect(literal(constant)); } + }; + + Expectable WithGuarantee(Expression guarantee) { return {expr, guarantee}; } +}; + +TEST(Expression, SingleComparisonGuarantees) { + auto i32 = field_ref("i32"); + + // i32 is guaranteed equal to 3, so the projection can just materialize that constant + // and need not incur IO + Simplify{project({call("add", {i32, literal(1)})}, {"i32 + 1"})} + .WithGuarantee(equal(i32, literal(3))) + .Expect(literal( + std::make_shared(ScalarVector{std::make_shared(4)}, + struct_({field("i32 + 1", int32())})))); + + // i32 is guaranteed equal to 5 everywhere, so filtering i32==5 is redundant and the + // filter can be simplified to true (== select everything) + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + less_equal(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(true); + + Simplify{ + less(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(3))) + .Expect(true); + + Simplify{ + greater_equal(i32, literal(5)), + } + .WithGuarantee(greater(i32, literal(5))) + .Expect(true); + + // i32 is guaranteed less than 3 everywhere, so filtering i32==5 is redundant and the + // filter can be simplified to false (== select nothing) + Simplify{ + equal(i32, literal(5)), + } + .WithGuarantee(less(i32, literal(3))) + .Expect(false); + + Simplify{ + less(i32, literal(5)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(false); + + Simplify{ + less_equal(i32, literal(3)), + } + .WithGuarantee(equal(i32, literal(5))) + .Expect(false); + + // no simplification possible: + Simplify{ + not_equal(i32, literal(3)), + } + .WithGuarantee(less(i32, literal(5))) + .ExpectUnchanged(); + + // exhaustive coverage of 
all single comparison simplifications + for (std::string filter_op : + {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { + for (auto filter_rhs : {literal(5), literal(3), literal(7)}) { + auto filter = call(filter_op, {i32, filter_rhs}); + for (std::string guarantee_op : + {"equal", "less", "less_equal", "greater", "greater_equal"}) { + auto guarantee = call(guarantee_op, {i32, literal(5)}); + + // generate data which satisfies the guarantee + static std::unordered_map satisfying_i32{ + {"equal", "[5]"}, + {"less", "[4, 3, 2, 1]"}, + {"less_equal", "[5, 4, 3, 2, 1]"}, + {"greater", "[6, 7, 8, 9]"}, + {"greater_equal", "[5, 6, 7, 8, 9]"}, + }; + + ASSERT_OK_AND_ASSIGN( + Datum input, + StructArray::Make({ArrayFromJSON(int32(), satisfying_i32[guarantee_op])}, + {"i32"})); + + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(filter, input)); + + // ensure that the simplified filter is as simplified as it could be + // (this is always possible for single comparisons) + bool all = true, none = true; + for (int64_t i = 0; i < input.length(); ++i) { + if (evaluated.array_as()->Value(i)) { + none = false; + } else { + all = false; + } + } + Simplify{filter}.WithGuarantee(guarantee).Expect( + all ? literal(true) : none ? 
literal(false) : filter); + } + } + } +} + +TEST(Expression, SimplifyWithGuarantee) { + // drop both members of a conjunctive filter + Simplify{ + and_(equal(field_ref("i32"), literal(2)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), + less_equal(field_ref("i32"), literal(1)))) + .Expect(false); + + // drop one member of a conjunctive filter + Simplify{ + and_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(equal(field_ref("i32"), literal(0))) + .Expect(equal(field_ref("f32"), literal(3.5F))); + + // drop both members of a disjunctive filter + Simplify{ + or_(equal(field_ref("i32"), literal(0)), equal(field_ref("f32"), literal(3.5F)))} + .WithGuarantee(equal(field_ref("i32"), literal(0))) + .Expect(true); + + // drop one member of a disjunctive filter + Simplify{or_(equal(field_ref("i32"), literal(0)), equal(field_ref("i32"), literal(3)))} + .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), + less_equal(field_ref("i32"), literal(1)))) + .Expect(equal(field_ref("i32"), literal(0))); + Simplify{ + or_(equal(field_ref("f32"), literal("0")), equal(field_ref("i32"), literal(3)))} + .WithGuarantee(greater(field_ref("f32"), literal(0.0))) + .Expect(equal(field_ref("i32"), literal(3))); + + // simplification can see through implicit casts + Simplify{or_({equal(field_ref("f32"), literal("0")), + call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ + ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})})} + .WithGuarantee(greater(field_ref("f32"), literal(0.0))) + .Expect(call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); +} + +TEST(Expression, SimplifyThenExecute) { + auto filter = + or_({equal(field_ref("f32"), literal("0")), + call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ + ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})}); + + 
ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + auto guarantee = greater(field_ref("f32"), literal(0.0)); + + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee)); + + auto input = RecordBatchFromJSON(kBoringSchema, R"([ + {"i64": 0, "f32": 0.1}, + {"i64": 0, "f32": 0.3}, + {"i64": 1, "f32": 0.5}, + {"i64": 2, "f32": 0.1}, + {"i64": 0, "f32": 0.1}, + {"i64": 0, "f32": 0.4}, + {"i64": 0, "f32": 1.0} + ])"); + + Datum evaluated, simplified_evaluated; + AssertExecute(filter, input, &evaluated); + AssertExecute(simplified, input, &simplified_evaluated); + AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); +} + +TEST(Expression, Filter) { + auto ExpectFilter = [](Expression filter, std::string batch_json) { + ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean()))); + auto batch = RecordBatchFromJSON(s, batch_json); + auto expected_mask = batch->column(0); + + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, batch)); + + AssertDatumsEqual(expected_mask, mask); + }; + + ExpectFilter(equal(field_ref("i32"), literal(0)), R"([ + {"i32": 0, "f32": -0.1, "in": 1}, + {"i32": 0, "f32": 0.3, "in": 1}, + {"i32": 1, "f32": 0.2, "in": 0}, + {"i32": 2, "f32": -0.1, "in": 0}, + {"i32": 0, "f32": 0.1, "in": 1}, + {"i32": 0, "f32": null, "in": 1}, + {"i32": 0, "f32": 1.0, "in": 1} + ])"); + + ExpectFilter( + greater(call("multiply", {field_ref("f32"), field_ref("f64")}), literal(0)), R"([ + {"f64": 0.3, "f32": 0.1, "in": 1}, + {"f64": -0.1, "f32": 0.3, "in": 0}, + {"f64": 0.1, "f32": 0.2, "in": 1}, + {"f64": 0.0, "f32": -0.1, "in": 0}, + {"f64": 1.0, "f32": 0.1, "in": 1}, + {"f64": -2.0, "f32": null, "in": null}, + {"f64": 3.0, "f32": 1.0, "in": 1} + ])"); +} + +TEST(Expression, SerializationRoundTrips) { + auto ExpectRoundTrips = [](const Expression& expr) { + ASSERT_OK_AND_ASSIGN(auto serialized, Serialize(expr)); + 
ASSERT_OK_AND_ASSIGN(Expression roundtripped, Deserialize(*serialized)); + EXPECT_EQ(expr, roundtripped); + }; + + ExpectRoundTrips(literal(MakeNullScalar(null()))); + + ExpectRoundTrips(literal(MakeNullScalar(int32()))); + + ExpectRoundTrips( + literal(MakeNullScalar(struct_({field("i", int32()), field("s", utf8())})))); + + ExpectRoundTrips(literal(true)); + + ExpectRoundTrips(literal(false)); + + ExpectRoundTrips(literal(1)); + + ExpectRoundTrips(literal(1.125)); + + ExpectRoundTrips(literal("stringy strings")); + + ExpectRoundTrips(field_ref("field")); + + ExpectRoundTrips(greater(field_ref("a"), literal(0.25))); + + ExpectRoundTrips( + or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")), + equal(field_ref("b"), literal("foo bar"))})); + + ExpectRoundTrips(not_(field_ref("alpha"))); + + ExpectRoundTrips(call("is_in", {literal(1)}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1, 2, 3]")})); + + ExpectRoundTrips( + call("is_in", + {call("cast", {field_ref("version")}, compute::CastOptions::Safe(float64()))}, + compute::SetLookupOptions{ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]"), true})); + + ExpectRoundTrips(call("is_valid", {field_ref("validity")})); + + ExpectRoundTrips(and_({and_(greater_equal(field_ref("x"), literal(-1.5)), + less(field_ref("x"), literal(0.0))), + and_(greater_equal(field_ref("y"), literal(0.0)), + less(field_ref("y"), literal(1.5))), + and_(greater(field_ref("z"), literal(1.5)), + less_equal(field_ref("z"), literal(3.0)))})); + + ExpectRoundTrips(and_({equal(field_ref("year"), literal(int16_t(1999))), + equal(field_ref("month"), literal(int8_t(12))), + equal(field_ref("day"), literal(int8_t(31))), + equal(field_ref("hour"), literal(int8_t(0))), + equal(field_ref("alpha"), literal(int32_t(0))), + equal(field_ref("beta"), literal(3.25f))})); +} + +} // namespace dataset +} // namespace arrow diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 
ef044ea3a03..2c437ce8eec 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -24,7 +24,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/filesystem/filesystem.h" @@ -57,16 +56,16 @@ Result> FileSource::Open() const { Result> FileFormat::MakeFragment( FileSource source, std::shared_ptr physical_schema) { - return MakeFragment(std::move(source), scalar(true), std::move(physical_schema)); + return MakeFragment(std::move(source), literal(true), std::move(physical_schema)); } Result> FileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression) { + FileSource source, Expression partition_expression) { return MakeFragment(std::move(source), std::move(partition_expression), nullptr); } Result> FileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr( new FileFragment(std::move(source), shared_from_this(), @@ -83,7 +82,7 @@ Result FileFragment::Scan(std::shared_ptr options } FileSystemDataset::FileSystemDataset(std::shared_ptr schema, - std::shared_ptr root_partition, + Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) @@ -93,7 +92,7 @@ FileSystemDataset::FileSystemDataset(std::shared_ptr schema, fragments_(std::move(fragments)) {} Result> FileSystemDataset::Make( - std::shared_ptr schema, std::shared_ptr root_partition, + std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments) { return std::shared_ptr(new FileSystemDataset( @@ -129,20 +128,22 @@ std::string FileSystemDataset::ToString() const { repr += "\n" + fragment->source().path(); const auto& partition = fragment->partition_expression(); - if 
(!partition->Equals(true)) { - repr += ": " + partition->ToString(); + if (partition != literal(true)) { + repr += ": " + partition.ToString(); } } return repr; } -FragmentIterator FileSystemDataset::GetFragmentsImpl( - std::shared_ptr predicate) { +Result FileSystemDataset::GetFragmentsImpl(Expression predicate) { FragmentVector fragments; for (const auto& fragment : fragments_) { - if (predicate->IsSatisfiableWith(fragment->partition_expression())) { + ARROW_ASSIGN_OR_RAISE( + auto simplified, + SimplifyWithGuarantee(predicate, fragment->partition_expression())); + if (simplified.IsSatisfiable()) { fragments.push_back(fragment); } } @@ -273,7 +274,8 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio // // NB: neither of these will have any impact whatsoever on the common case of writing // an in-memory table to disk. - ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, scanner->GetFragments().ToVector()); + ARROW_ASSIGN_OR_RAISE(auto fragment_it, scanner->GetFragments()); + ARROW_ASSIGN_OR_RAISE(FragmentVector fragments, fragment_it.ToVector()); ScanTaskVector scan_tasks; std::vector fragment_for_task; @@ -313,8 +315,10 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio std::unordered_set need_flushed; for (size_t i = 0; i < groups.batches.size(); ++i) { - AndExpression partition_expression(std::move(groups.expressions[i]), - fragment->partition_expression()); + ARROW_ASSIGN_OR_RAISE( + auto partition_expression, + and_(std::move(groups.expressions[i]), fragment->partition_expression()) + .Bind(*scanner->schema())); auto batch = std::move(groups.batches[i]); ARROW_ASSIGN_OR_RAISE(auto part, diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h index 192921e7cf0..bb2aa86ba9b 100644 --- a/cpp/src/arrow/dataset/file_base.h +++ b/cpp/src/arrow/dataset/file_base.h @@ -143,11 +143,11 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this> MakeFragment( - 
FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema); - Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression); + Result> MakeFragment(FileSource source, + Expression partition_expression); Result> MakeFragment( FileSource source, std::shared_ptr physical_schema = NULLPTR); @@ -173,8 +173,7 @@ class ARROW_DS_EXPORT FileFragment : public Fragment { protected: FileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, - std::shared_ptr physical_schema) + Expression partition_expression, std::shared_ptr physical_schema) : Fragment(std::move(partition_expression), std::move(physical_schema)), source_(std::move(source)), format_(std::move(format)) {} @@ -207,7 +206,7 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { /// /// \return A constructed dataset. static Result> Make( - std::shared_ptr schema, std::shared_ptr root_partition, + std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); @@ -234,10 +233,9 @@ class ARROW_DS_EXPORT FileSystemDataset : public Dataset { std::string ToString() const; protected: - FragmentIterator GetFragmentsImpl(std::shared_ptr predicate) override; + Result GetFragmentsImpl(Expression predicate) override; - FileSystemDataset(std::shared_ptr schema, - std::shared_ptr root_partition, + FileSystemDataset(std::shared_ptr schema, Expression root_partition, std::shared_ptr format, std::shared_ptr filesystem, std::vector> fragments); diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index f889b12fddd..bc1a69066f7 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -28,12 +28,12 @@ #include "arrow/csv/reader.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include 
"arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/result.h" #include "arrow/type.h" #include "arrow/util/iterator.h" +#include "arrow/util/logging.h" namespace arrow { namespace dataset { @@ -90,13 +90,16 @@ static inline Result GetConvertOptions( } // FIXME(bkietz) also acquire types of fields materialized but not projected. - for (auto&& name : FieldsInExpression(scan_options->filter)) { - ARROW_ASSIGN_OR_RAISE(auto match, - FieldRef(name).FindOneOrNone(*scan_options->schema())); - if (match.indices().empty()) { - convert_options.include_columns.push_back(std::move(name)); + // This requires that scan_options include the full dataset schema (not just + // the projected schema). + for (const FieldRef& ref : FieldsInExpression(scan_options->filter2)) { + DCHECK(ref.name()); + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOneOrNone(*scan_options->schema())); + if (!match) { + convert_options.include_columns.push_back(*ref.name()); } } + return convert_options; } diff --git a/cpp/src/arrow/dataset/file_csv_test.cc b/cpp/src/arrow/dataset/file_csv_test.cc index 8e73be0ee8f..eb0e1bf9395 100644 --- a/cpp/src/arrow/dataset/file_csv_test.cc +++ b/cpp/src/arrow/dataset/file_csv_test.cc @@ -23,7 +23,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" @@ -148,7 +147,7 @@ N/A,bar schema_ = schema({field("f64", utf8()), field("str", utf8())}); ScannerBuilder builder(schema_, fragment, ctx_); // filter expression validated against declared schema - ASSERT_OK(builder.Filter("f64"_ == "str"_)); + ASSERT_OK(builder.Filter(equal(field_ref("f64"), field_ref("str")))); // project only "str" ASSERT_OK(builder.Project({"str"})); ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index 
25bbe559cf4..f49337b362e 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -24,7 +24,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/discovery.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" @@ -244,7 +243,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + opts_->filter2 = equal(field_ref("i32"), literal(0)); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -280,7 +279,7 @@ TEST_F(TestIpcFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + opts_->filter2 = equal(field_ref("i32"), literal(0)); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc index e06cee649ef..94d2115358f 100644 --- a/cpp/src/arrow/dataset/file_parquet.cc +++ b/cpp/src/arrow/dataset/file_parquet.cc @@ -24,7 +24,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/filesystem/path_util.h" #include "arrow/table.h" @@ -125,7 +124,7 @@ static Result> GetSchemaManifest( return manifest; } -static std::shared_ptr ColumnChunkStatisticsAsExpression( +static util::optional ColumnChunkStatisticsAsExpression( const SchemaField& schema_field, const parquet::RowGroupMetaData& 
metadata) { // For the remaining of this function, failure to extract/parse statistics // are ignored by returning nullptr. The goal is two fold. First @@ -134,13 +133,13 @@ static std::shared_ptr ColumnChunkStatisticsAsExpression( // For now, only leaf (primitive) types are supported. if (!schema_field.is_leaf()) { - return nullptr; + return util::nullopt; } auto column_metadata = metadata.ColumnChunk(schema_field.column_index); auto statistics = column_metadata->statistics(); if (statistics == nullptr) { - return nullptr; + return util::nullopt; } const auto& field = schema_field.field; @@ -148,12 +147,12 @@ static std::shared_ptr ColumnChunkStatisticsAsExpression( // Optimize for corner case where all values are nulls if (statistics->num_values() == statistics->null_count()) { - return equal(std::move(field_expr), scalar(MakeNullScalar(field->type()))); + return equal(std::move(field_expr), literal(MakeNullScalar(field->type()))); } std::shared_ptr min, max; if (!StatisticsAsScalars(*statistics, &min, &max).ok()) { - return nullptr; + return util::nullopt; } auto maybe_min = min->CastTo(field->type()); @@ -161,11 +160,11 @@ static std::shared_ptr ColumnChunkStatisticsAsExpression( if (maybe_min.ok() && maybe_max.ok()) { min = maybe_min.MoveValueUnsafe(); max = maybe_max.MoveValueUnsafe(); - return and_(greater_equal(field_expr, scalar(min)), - less_equal(field_expr, scalar(max))); + return and_(greater_equal(field_expr, literal(min)), + less_equal(field_expr, literal(max))); } - return nullptr; + return util::nullopt; } static void AddColumnIndices(const SchemaField& schema_field, @@ -291,7 +290,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptr row_groups; bool pre_filtered = false; - auto empty = [] { return MakeEmptyIterator>(); }; + auto MakeEmpty = [] { return MakeEmptyIterator>(); }; // If RowGroup metadata is cached completely we can pre-filter RowGroups before opening // a FileReader, potentially avoiding IO altogether if all RowGroups are excluded 
due to @@ -299,10 +298,10 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrmetadata() != nullptr) { ARROW_ASSIGN_OR_RAISE(row_groups, - parquet_fragment->FilterRowGroups(*options->filter)); + parquet_fragment->FilterRowGroups(options->filter2)); pre_filtered = true; - if (row_groups.empty()) empty(); + if (row_groups.empty()) MakeEmpty(); } // Open the reader and pay the real IO cost. @@ -315,9 +314,9 @@ Result ParquetFileFormat::ScanFile(std::shared_ptrFilterRowGroups(*options->filter)); + parquet_fragment->FilterRowGroups(options->filter2)); - if (row_groups.empty()) empty(); + if (row_groups.empty()) MakeEmpty(); } auto column_projection = InferColumnProjection(*reader, *options); @@ -332,7 +331,7 @@ Result ParquetFileFormat::ScanFile(std::shared_ptr> ParquetFileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema, std::vector row_groups) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -340,7 +339,7 @@ Result> ParquetFileFormat::MakeFragment( } Result> ParquetFileFormat::MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), @@ -395,7 +394,7 @@ Status ParquetFileWriter::Finish() { return parquet_writer_->Close(); } ParquetFileFragment::ParquetFileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, + Expression partition_expression, std::shared_ptr physical_schema, util::optional> row_groups) : FileFragment(std::move(source), std::move(format), std::move(partition_expression), @@ -442,7 +441,7 @@ Status ParquetFileFragment::SetMetadata( metadata_ = std::move(metadata); manifest_ = std::move(manifest); - 
statistics_expressions_.resize(row_groups_->size(), scalar(true)); + statistics_expressions_.resize(row_groups_->size(), literal(true)); statistics_expressions_complete_.resize(physical_schema_->num_fields(), false); for (int row_group : *row_groups_) { @@ -457,10 +456,9 @@ Status ParquetFileFragment::SetMetadata( return Status::OK(); } -Result ParquetFileFragment::SplitByRowGroup( - const std::shared_ptr& predicate) { +Result ParquetFileFragment::SplitByRowGroup(Expression predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); - ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(*predicate)); + ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); FragmentVector fragments(row_groups.size()); int i = 0; @@ -476,10 +474,9 @@ Result ParquetFileFragment::SplitByRowGroup( return fragments; } -Result> ParquetFileFragment::Subset( - const std::shared_ptr& predicate) { +Result> ParquetFileFragment::Subset(Expression predicate) { RETURN_NOT_OK(EnsureCompleteMetadata()); - ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(*predicate)); + ARROW_ASSIGN_OR_RAISE(auto row_groups, FilterRowGroups(predicate)); return Subset(std::move(row_groups)); } @@ -494,22 +491,26 @@ Result> ParquetFileFragment::Subset( return new_fragment; } -inline void FoldingAnd(std::shared_ptr* l, std::shared_ptr r) { - if ((*l)->Equals(true)) { +inline void FoldingAnd(Expression* l, Expression r) { + if (*l == literal(true)) { *l = std::move(r); } else { *l = and_(std::move(*l), std::move(r)); } } -Result> ParquetFileFragment::FilterRowGroups( - const Expression& predicate) { +Result> ParquetFileFragment::FilterRowGroups(Expression predicate) { auto lock = physical_schema_mutex_.Lock(); DCHECK_NE(metadata_, nullptr); - RETURN_NOT_OK(predicate.Validate(*physical_schema_)); + ARROW_ASSIGN_OR_RAISE( + predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); + + if (!predicate.IsSatisfiable()) { + return std::vector{}; + } - for (FieldRef ref : 
FieldsInExpression(predicate)) { + for (const FieldRef& ref : FieldsInExpression(predicate)) { ARROW_ASSIGN_OR_RAISE(auto path, ref.FindOneOrNone(*physical_schema_)); if (!path) continue; @@ -523,21 +524,20 @@ Result> ParquetFileFragment::FilterRowGroups( if (auto minmax = ColumnChunkStatisticsAsExpression(schema_field, *row_group_metadata)) { - FoldingAnd(&statistics_expressions_[i], std::move(minmax)); + FoldingAnd(&statistics_expressions_[i], std::move(*minmax)); + ARROW_ASSIGN_OR_RAISE(statistics_expressions_[i], + statistics_expressions_[i].Bind(*physical_schema_)); } ++i; } } - auto simplified_predicate = predicate.Assume(partition_expression_); - if (!simplified_predicate->IsSatisfiable()) { - return std::vector{}; - } - std::vector row_groups; for (size_t i = 0; i < row_groups_->size(); ++i) { - if (simplified_predicate->IsSatisfiableWith(statistics_expressions_[i])) { + ARROW_ASSIGN_OR_RAISE(auto row_group_predicate, + SimplifyWithGuarantee(predicate, statistics_expressions_[i])); + if (row_group_predicate.IsSatisfiable()) { row_groups.push_back(row_groups_->at(i)); } } @@ -661,7 +661,7 @@ ParquetDatasetFactory::CollectParquetFragments(const Partitioning& partitioning) auto partition_expression = partitioning.Parse(StripPrefixAndFilename(path, options_.partition_base_dir)) - .ValueOr(scalar(true)); + .ValueOr(literal(true)); ARROW_ASSIGN_OR_RAISE( auto fragment, @@ -712,7 +712,7 @@ Result> ParquetDatasetFactory::Finish(FinishOptions opt } ARROW_ASSIGN_OR_RAISE(auto fragments, CollectParquetFragments(*partitioning)); - return FileSystemDataset::Make(std::move(schema), scalar(true), format_, filesystem_, + return FileSystemDataset::Make(std::move(schema), literal(true), format_, filesystem_, std::move(fragments)); } diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index b82241a574b..ae0337994a0 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -117,12 +117,12 @@ class 
ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// \brief Create a Fragment targeting all RowGroups. Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema) override; /// \brief Create a Fragment, restricted to the specified row groups. Result> MakeFragment( - FileSource source, std::shared_ptr partition_expression, + FileSource source, Expression partition_expression, std::shared_ptr physical_schema, std::vector row_groups); /// \brief Return a FileReader on the given source. @@ -150,7 +150,7 @@ class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { /// significant performance boost when scanning high latency file systems. class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { public: - Result SplitByRowGroup(const std::shared_ptr& predicate); + Result SplitByRowGroup(Expression predicate); /// \brief Return the RowGroups selected by this fragment. const std::vector& row_groups() const { @@ -166,12 +166,12 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR); /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups. - Result> Subset(const std::shared_ptr& predicate); + Result> Subset(Expression predicate); Result> Subset(std::vector row_group_ids); private: ParquetFileFragment(FileSource source, std::shared_ptr format, - std::shared_ptr partition_expression, + Expression partition_expression, std::shared_ptr physical_schema, util::optional> row_groups); @@ -185,7 +185,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { } // Return a filtered subset of row group indices. 
- Result> FilterRowGroups(const Expression& predicate); + Result> FilterRowGroups(Expression predicate); ParquetFileFormat& parquet_format_; @@ -193,7 +193,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { // or util::nullopt if all row groups are selected. util::optional> row_groups_; - ExpressionVector statistics_expressions_; + std::vector statistics_expressions_; std::vector statistics_expressions_complete_; std::shared_ptr metadata_; std::shared_ptr manifest_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index c2e0d6b632d..67d1fb17120 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -22,7 +22,6 @@ #include #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" #include "arrow/table.h" @@ -160,6 +159,10 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { return Batches(std::move(scan_task_it)); } + void SetFilter(Expression filter) { + ASSERT_OK_AND_ASSIGN(opts_->filter2, filter.Bind(*schema_)); + } + std::shared_ptr SingleBatch(Fragment* fragment) { auto batches = IteratorToVector(Batches(fragment)); EXPECT_EQ(batches.size(), 1); @@ -187,10 +190,11 @@ class TestParquetFileFormat : public ArrowParquetWriterMixin { } void CountRowGroupsInFragment(const std::shared_ptr& fragment, - std::vector expected_row_groups, - const Expression& filter) { + std::vector expected_row_groups, Expression filter) { + schema_ = opts_->schema(); + ASSERT_OK_AND_ASSIGN(auto bound, filter.Bind(*schema_)); auto parquet_fragment = checked_pointer_cast(fragment); - ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(filter.Copy())) + ASSERT_OK_AND_ASSIGN(auto fragments, parquet_fragment->SplitByRowGroup(bound)) EXPECT_EQ(fragments.size(), expected_row_groups.size()); for (size_t i = 0; i < fragments.size(); i++) { @@ -214,6 
+218,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReader) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + SetFilter(literal(true)); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); int64_t row_count = 0; @@ -233,6 +238,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderDictEncoded) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + SetFilter(literal(true)); format_->reader_options.dict_columns = {"utf8"}; ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); @@ -275,7 +281,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjected) { opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + SetFilter(equal(field_ref("i32"), literal(0))); // NB: projector is applied by the scanner; FileFragment does not evaluate it so // we will not drop "i32" even though it is not in the projector's schema @@ -311,7 +317,7 @@ TEST_F(TestParquetFileFormat, ScanRecordBatchReaderProjectedMissingCols) { schema_ = reader->schema(); opts_ = ScanOptions::Make(schema_); opts_->projector = RecordBatchProjector(SchemaFromColumnNames(schema_, {"f64"})); - opts_->filter = equal(field_ref("i32"), scalar(0)); + SetFilter(equal(field_ref("i32"), literal(0))); auto readers = {reader.get(), reader_without_i32.get(), reader_without_f64.get()}; for (auto reader : readers) { @@ -404,34 +410,37 @@ TEST_F(TestParquetFileFormat, PredicatePushdown) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + schema_ = reader->schema(); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); - opts_->filter = scalar(true); + SetFilter(literal(true)); CountRowsAndBatchesInScan(fragment, kTotalNumRows, kNumRowGroups); for (int64_t i = 1; i <= kNumRowGroups; i++) { - opts_->filter = ("i64"_ == int64_t(i)).Copy(); + 
SetFilter(equal(field_ref("i64"), literal(i))); CountRowsAndBatchesInScan(fragment, i, 1); } // Out of bound filters should skip all RowGroups. - opts_->filter = scalar(false); + SetFilter(literal(false)); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(kNumRowGroups + 1)).Copy(); + SetFilter(equal(field_ref("i64"), literal(kNumRowGroups + 1))); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(-1)).Copy(); + SetFilter(equal(field_ref("i64"), literal(-1))); CountRowsAndBatchesInScan(fragment, 0, 0); // No rows match 1 and 2. - opts_->filter = ("i64"_ == int64_t(1) and "u8"_ == uint8_t(2)).Copy(); + SetFilter(and_(equal(field_ref("i64"), literal(1)), + equal(field_ref("u8"), literal(2)))); CountRowsAndBatchesInScan(fragment, 0, 0); - opts_->filter = ("i64"_ == int64_t(2) or "i64"_ == int64_t(4)).Copy(); + SetFilter(or_(equal(field_ref("i64"), literal(2)), + equal(field_ref("i64"), literal(4)))); CountRowsAndBatchesInScan(fragment, 2 + 4, 2); - opts_->filter = ("i64"_ < int64_t(6)).Copy(); + SetFilter(less(field_ref("i64"), literal(6))); CountRowsAndBatchesInScan(fragment, 5 * (5 + 1) / 2, 5); - opts_->filter = ("i64"_ >= int64_t(6)).Copy(); + SetFilter(greater_equal(field_ref("i64"), literal(6))); CountRowsAndBatchesInScan(fragment, kTotalNumRows - (5 * (5 + 1) / 2), kNumRowGroups - 5); } @@ -446,36 +455,45 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragments) { ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); auto all_row_groups = internal::Iota(static_cast(kNumRowGroups)); - CountRowGroupsInFragment(fragment, all_row_groups, *scalar(true)); - CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); + CountRowGroupsInFragment(fragment, all_row_groups, literal(true)); + + // FIXME this is only meaningful if "not here" is a virtual column + // CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); for (int i = 0; i < kNumRowGroups; ++i) { - 
CountRowGroupsInFragment(fragment, {i}, "i64"_ == int64_t(i + 1)); + CountRowGroupsInFragment(fragment, {i}, equal(field_ref("i64"), literal(i + 1))); } // Out of bound filters should skip all RowGroups. - CountRowGroupsInFragment(fragment, {}, *scalar(false)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(kNumRowGroups + 1)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(-1)); + CountRowGroupsInFragment(fragment, {}, literal(false)); + CountRowGroupsInFragment(fragment, {}, + equal(field_ref("i64"), literal(kNumRowGroups + 1))); + CountRowGroupsInFragment(fragment, {}, equal(field_ref("i64"), literal(-1))); // No rows match 1 and 2. - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(1) and "u8"_ == uint8_t(2)); - CountRowGroupsInFragment(fragment, {}, "i64"_ == int64_t(2) and "i64"_ == int64_t(4)); + CountRowGroupsInFragment( + fragment, {}, + and_(equal(field_ref("i64"), literal(1)), equal(field_ref("u8"), literal(2)))); + CountRowGroupsInFragment( + fragment, {}, + and_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4)))); - CountRowGroupsInFragment(fragment, {1, 3}, - "i64"_ == int64_t(2) or "i64"_ == int64_t(4)); + CountRowGroupsInFragment( + fragment, {1, 3}, + or_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4)))); // TODO(bkietz): better Assume support for InExpression // auto set = ArrayFromJSON(int64(), "[2, 4]"); - // CountRowGroupsInFragment(fragment, {1, 3}, "i64"_.In(set)); + // CountRowGroupsInFragment(fragment, {1, 3}, field_ref("i64").In(set)); - CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, "i64"_ < int64_t(6)); + CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, less(field_ref("i64"), literal(6))); CountRowGroupsInFragment(fragment, internal::Iota(5, static_cast(kNumRowGroups)), - "i64"_ >= int64_t(6)); + greater_equal(field_ref("i64"), literal(6))); CountRowGroupsInFragment(fragment, {5, 6}, - "i64"_ >= int64_t(6) and "i64"_ < int64_t(8)); + 
and_(greater_equal(field_ref("i64"), literal(6)), + less(field_ref("i64"), literal(8)))); } TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragmentsUsingStringColumn) { @@ -492,7 +510,7 @@ TEST_F(TestParquetFileFormat, PredicatePushdownRowGroupFragmentsUsingStringColum opts_ = ScanOptions::Make(reader.schema()); ASSERT_OK_AND_ASSIGN(auto fragment, format_->MakeFragment(*source)); - CountRowGroupsInFragment(fragment, {0, 3}, "x"_ == "a"); + CountRowGroupsInFragment(fragment, {0, 3}, equal(field_ref("x"), literal("a"))); } TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { @@ -503,10 +521,12 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { auto source = GetFileSource(reader.get()); opts_ = ScanOptions::Make(reader->schema()); + schema_ = reader->schema(); + SetFilter(literal(true)); auto row_groups_fragment = [&](std::vector row_groups) { EXPECT_OK_AND_ASSIGN(auto fragment, - format_->MakeFragment(*source, scalar(true), + format_->MakeFragment(*source, literal(true), /*physical_schema=*/nullptr, row_groups)); return fragment; }; @@ -514,7 +534,7 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { // select all row groups EXPECT_OK_AND_ASSIGN( auto all_row_groups_fragment, - format_->MakeFragment(*source, scalar(true)) + format_->MakeFragment(*source, literal(true)) .Map([](std::shared_ptr f) { return internal::checked_pointer_cast(f); })); @@ -532,17 +552,17 @@ TEST_F(TestParquetFileFormat, ExplicitRowGroupSelection) { for (int i = 0; i < kNumRowGroups; ++i) { // conflicting selection/filter - opts_->filter = ("i64"_ == int64_t(i)).Copy(); + SetFilter(equal(field_ref("i64"), literal(i))); CountRowsAndBatchesInScan(row_groups_fragment({i}), 0, 0); } for (int i = 0; i < kNumRowGroups; ++i) { // identical selection/filter - opts_->filter = ("i64"_ == int64_t(i + 1)).Copy(); + SetFilter(equal(field_ref("i64"), literal(i + 1))); CountRowsAndBatchesInScan(row_groups_fragment({i}), i + 1, 1); } - opts_->filter = ("i64"_ > 
int64_t(3)).Copy(); + SetFilter(greater(field_ref("i64"), literal(3))); CountRowsAndBatchesInScan(row_groups_fragment({2, 3, 4, 5}), 4 + 5 + 6, 3); EXPECT_RAISES_WITH_MESSAGE_THAT( diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index 8ee2613576f..f0799e07a3a 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -82,15 +82,15 @@ TEST(FileSource, BufferBased) { TEST_F(TestFileSystemDataset, Basic) { MakeDataset({}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {}); MakeDataset({fs::File("a"), fs::File("b"), fs::File("c")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"a", "b", "c"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b", "c"}); AssertFilesAre(dataset_, {"a", "b", "c"}); // Should not create fragment from directories. MakeDataset({fs::Dir("A"), fs::Dir("A/B"), fs::File("A/a"), fs::File("A/B/b")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"A/a", "A/B/b"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"A/a", "A/B/b"}); AssertFilesAre(dataset_, {"A/a", "A/B/b"}); } @@ -98,7 +98,7 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { auto schm = schema({field("i32", int32()), field("f64", float64())}); auto format = std::make_shared(schm); ASSERT_OK_AND_ASSIGN(auto dataset, - FileSystemDataset::Make(schm, scalar(true), format, nullptr, {})); + FileSystemDataset::Make(schm, literal(true), format, nullptr, {})); // drop field ASSERT_OK(dataset->ReplaceSchema(schema({field("i32", int32())})).status()); @@ -119,45 +119,64 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { } TEST_F(TestFileSystemDataset, RootPartitionPruning) { - auto root_partition = ("a"_ == 5).Copy(); + auto root_partition = equal(field_ref("i32"), literal(5)); MakeDataset({fs::File("a"), fs::File("b")}, root_partition); + auto GetFragments = [&](Expression filter) { + return 
*dataset_->GetFragments(*filter.Bind(*dataset_->schema())); + }; + // Default filter should always return all data. - AssertFragmentsAreFromPath(dataset_->GetFragments(), {"a", "b"}); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), {"a", "b"}); // filter == partition - AssertFragmentsAreFromPath(dataset_->GetFragments(root_partition), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(root_partition), {"a", "b"}); // Same partition key, but non matching filter - AssertFragmentsAreFromPath(dataset_->GetFragments(("a"_ == 6).Copy()), {}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("i32"), literal(6))), {}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("a"_ > 1).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(greater(field_ref("i32"), literal(1))), + {"a", "b"}); // different key shouldn't prune - AssertFragmentsAreFromPath(dataset_->GetFragments(("b"_ == 6).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), + {"a", "b"}); // No partition should match MakeDataset({fs::File("a"), fs::File("b")}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("b"_ == 6).Copy()), {"a", "b"}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), + {"a", "b"}); } TEST_F(TestFileSystemDataset, TreePartitionPruning) { - auto root_partition = ("country"_ == "US").Copy(); + auto root_partition = equal(field_ref("country"), literal("US")); + std::vector regions = { fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"), fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - ExpressionVector partitions = { - ("state"_ == "NY").Copy(), - ("state"_ == "NY" and "city"_ == "New York").Copy(), - ("state"_ == "NY" and "city"_ == "Franklin").Copy(), - ("state"_ == "CA").Copy(), - ("state"_ == "CA" and "city"_ == "San Francisco").Copy(), - ("state"_ == "CA" and "city"_ == "Franklin").Copy(), + std::vector partitions = { + 
equal(field_ref("state"), literal("NY")), + + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), + + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), + + equal(field_ref("state"), literal("CA")), + + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), + + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), }; - MakeDataset(regions, root_partition, partitions); + MakeDataset( + regions, root_partition, partitions, + schema({field("country", utf8()), field("state", utf8()), field("city", utf8())})); std::vector all_cities = {"CA/San Francisco", "CA/Franklin", "NY/New York", "NY/Franklin"}; @@ -165,52 +184,67 @@ TEST_F(TestFileSystemDataset, TreePartitionPruning) { std::vector franklins = {"CA/Franklin", "NY/Franklin"}; // Default filter should always return all data. - AssertFragmentsAreFromPath(dataset_->GetFragments(), all_cities); + AssertFragmentsAreFromPath(*dataset_->GetFragments(), all_cities); + + auto GetFragments = [&](Expression filter) { + return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); + }; // Dataset's partitions are respected - AssertFragmentsAreFromPath(dataset_->GetFragments(("country"_ == "US").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("US"))), all_cities); - AssertFragmentsAreFromPath(dataset_->GetFragments(("country"_ == "FR").Copy()), {}); + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("country"), literal("FR"))), + {}); - AssertFragmentsAreFromPath(dataset_->GetFragments(("state"_ == "CA").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("state"), literal("CA"))), ca_cities); // Filter where no decisions can be made on inner nodes when filter don't // apply to inner partitions. 
- AssertFragmentsAreFromPath(dataset_->GetFragments(("city"_ == "Franklin").Copy()), + AssertFragmentsAreFromPath(GetFragments(equal(field_ref("city"), literal("Franklin"))), franklins); } TEST_F(TestFileSystemDataset, FragmentPartitions) { - auto root_partition = ("country"_ == "US").Copy(); + auto root_partition = equal(field_ref("country"), literal("US")); std::vector regions = { fs::Dir("NY"), fs::File("NY/New York"), fs::File("NY/Franklin"), fs::Dir("CA"), fs::File("CA/San Francisco"), fs::File("CA/Franklin"), }; - ExpressionVector partitions = { - ("state"_ == "NY").Copy(), - ("state"_ == "NY" and "city"_ == "New York").Copy(), - ("state"_ == "NY" and "city"_ == "Franklin").Copy(), - ("state"_ == "CA").Copy(), - ("state"_ == "CA" and "city"_ == "San Francisco").Copy(), - ("state"_ == "CA" and "city"_ == "Franklin").Copy(), - }; + std::vector partitions = { + equal(field_ref("state"), literal("NY")), + + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), - MakeDataset(regions, root_partition, partitions); + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), - auto with_root = [&](const Expression& state, const Expression& city) { - return and_(state.Copy(), city.Copy()); + equal(field_ref("state"), literal("CA")), + + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), + + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), }; + MakeDataset( + regions, root_partition, partitions, + schema({field("country", utf8()), field("state", utf8()), field("city", utf8())})); + AssertFragmentsHavePartitionExpressions( - dataset_->GetFragments(), - { - with_root("state"_ == "CA", "city"_ == "San Francisco"), - with_root("state"_ == "CA", "city"_ == "Franklin"), - with_root("state"_ == "NY", "city"_ == "New York"), - with_root("state"_ == "NY", "city"_ == "Franklin"), - }); + dataset_, { + 
and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("San Francisco"))), + and_(equal(field_ref("state"), literal("CA")), + equal(field_ref("city"), literal("Franklin"))), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("New York"))), + and_(equal(field_ref("state"), literal("NY")), + equal(field_ref("city"), literal("Franklin"))), + }); } } // namespace dataset diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index caab28e3f7c..2357896dd7a 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -17,1756 +17,4 @@ #include "arrow/dataset/filter.h" -#include -#include -#include -#include -#include -#include -#include - -#include "arrow/array/builder_primitive.h" -#include "arrow/buffer.h" -#include "arrow/compute/api.h" -#include "arrow/dataset/dataset.h" -#include "arrow/io/memory.h" -#include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/int_util_internal.h" -#include "arrow/util/iterator.h" -#include "arrow/util/logging.h" -#include "arrow/util/string.h" -#include "arrow/visitor_inline.h" - -namespace arrow { - -using compute::CompareOperator; -using compute::ExecContext; - -namespace dataset { - -using arrow::internal::checked_cast; -using arrow::internal::checked_pointer_cast; - -inline std::shared_ptr NullExpression() { - return std::make_shared(std::make_shared()); -} - -inline Datum NullDatum() { return Datum(std::make_shared()); } - -bool IsNullDatum(const Datum& datum) { - if (datum.is_scalar()) { - auto scalar = datum.scalar(); - return !scalar->is_valid; - } - - auto array_data = datum.array(); - return array_data->GetNullCount() == array_data->length; -} - -struct Comparison { - enum type { - LESS, - EQUAL, - GREATER, - NULL_, - }; -}; - -Result> 
EnsureNotDictionary( - const std::shared_ptr& scalar) { - if (scalar->type->id() == Type::DICTIONARY) { - return checked_cast(*scalar).GetEncodedValue(); - } - return scalar; -} - -Result Compare(const Scalar& lhs, const Scalar& rhs); - -struct CompareVisitor { - template - using ScalarType = typename TypeTraits::ScalarType; - - Status Visit(const NullType&) { - result_ = Comparison::NULL_; - return Status::OK(); - } - - Status Visit(const BooleanType&) { return CompareValues(); } - - template - enable_if_physical_floating_point Visit(const T&) { - return CompareValues(); - } - - template - enable_if_physical_signed_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_physical_unsigned_integer Visit(const T&) { - return CompareValues(); - } - - template - enable_if_nested Visit(const T&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - template - enable_if_binary_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - template - enable_if_string_like Visit(const T&) { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - auto cmp = std::memcmp(lhs->data(), rhs->data(), std::min(lhs->size(), rhs->size())); - if (cmp == 0) { - return CompareValues(lhs->size(), rhs->size()); - } - return CompareValues(cmp, 0); - } - - Status Visit(const Decimal128Type&) { return CompareValues(); } - Status Visit(const Decimal256Type&) { return CompareValues(); } - - // Explicit because it falls under `physical_unsigned_integer`. 
- // TODO(bkietz) whenever we vendor a float16, this can be implemented - Status Visit(const HalfFloatType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const ExtensionType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - Status Visit(const DictionaryType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); - } - - // defer comparison to ScalarType::value - template - Status CompareValues() { - auto lhs = checked_cast&>(lhs_).value; - auto rhs = checked_cast&>(rhs_).value; - return CompareValues(lhs, rhs); - } - - // defer comparison to explicit values - template - Status CompareValues(Value lhs, Value rhs) { - result_ = lhs < rhs ? Comparison::LESS - : lhs == rhs ? Comparison::EQUAL : Comparison::GREATER; - return Status::OK(); - } - - Comparison::type result_; - const Scalar& lhs_; - const Scalar& rhs_; -}; - -// Compare two scalars -// if either is null, return is null -// TODO(bkietz) extract this to the scalar comparison kernels -Result Compare(const Scalar& lhs, const Scalar& rhs) { - if (!lhs.type->Equals(*rhs.type)) { - return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, - " vs ", *rhs.type); - } - if (!lhs.is_valid || !rhs.is_valid) { - return Comparison::NULL_; - } - CompareVisitor vis{Comparison::NULL_, lhs, rhs}; - RETURN_NOT_OK(VisitTypeInline(*lhs.type, &vis)); - return vis.result_; -} - -CompareOperator InvertCompareOperator(CompareOperator op) { - switch (op) { - case CompareOperator::EQUAL: - return CompareOperator::NOT_EQUAL; - - case CompareOperator::NOT_EQUAL: - return CompareOperator::EQUAL; - - case CompareOperator::GREATER: - return CompareOperator::LESS_EQUAL; - - case CompareOperator::GREATER_EQUAL: - return CompareOperator::LESS; - - case CompareOperator::LESS: - return CompareOperator::GREATER_EQUAL; - - case CompareOperator::LESS_EQUAL: - return CompareOperator::GREATER; - - 
default: - break; - } - - DCHECK(false); - return CompareOperator::EQUAL; -} - -template -std::shared_ptr InvertBoolean(const Boolean& expr) { - auto lhs = Invert(*expr.left_operand()); - auto rhs = Invert(*expr.right_operand()); - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - if (std::is_same::value) { - return std::make_shared(std::move(lhs), std::move(rhs)); - } - - return nullptr; -} - -std::shared_ptr Invert(const Expression& expr) { - switch (expr.type()) { - case ExpressionType::NOT: - return checked_cast(expr).operand(); - - case ExpressionType::AND: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::OR: - return InvertBoolean(checked_cast(expr)); - - case ExpressionType::COMPARISON: { - const auto& comparison = checked_cast(expr); - auto inverted_op = InvertCompareOperator(comparison.op()); - return std::make_shared( - inverted_op, comparison.left_operand(), comparison.right_operand()); - } - - default: - break; - } - return nullptr; -} - -std::shared_ptr Expression::Assume(const Expression& given) const { - std::shared_ptr out; - - DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) { - if (out != nullptr) { - return Status::OK(); - } - - if (given.type() != ExpressionType::COMPARISON) { - return Status::OK(); - } - - const auto& given_cmp = checked_cast(given); - if (given_cmp.op() != CompareOperator::EQUAL) { - return Status::OK(); - } - - if (this->Equals(given_cmp.left_operand())) { - out = given_cmp.right_operand(); - return Status::OK(); - } - - if (this->Equals(given_cmp.right_operand())) { - out = given_cmp.left_operand(); - return Status::OK(); - } - - return Status::OK(); - })); - - return out ? 
out : Copy(); -} - -std::shared_ptr ComparisonExpression::Assume(const Expression& given) const { - switch (given.type()) { - case ExpressionType::COMPARISON: { - return AssumeGivenComparison(checked_cast(given)); - } - - case ExpressionType::NOT: { - const auto& given_not = checked_cast(given); - if (auto inverted = Invert(*given_not.operand())) { - return Assume(*inverted); - } - return Copy(); - } - - case ExpressionType::OR: { - const auto& given_or = checked_cast(given); - - auto left_simplified = Assume(*given_or.left_operand()); - auto right_simplified = Assume(*given_or.right_operand()); - - // The result of simplification against the operands of an OrExpression - // cannot be used unless they are identical - if (left_simplified->Equals(right_simplified)) { - return left_simplified; - } - - return Copy(); - } - - case ExpressionType::AND: { - const auto& given_and = checked_cast(given); - - auto simplified = Copy(); - simplified = simplified->Assume(*given_and.left_operand()); - simplified = simplified->Assume(*given_and.right_operand()); - return simplified; - } - - // TODO(bkietz) we should be able to use ExpressionType::IN here - - default: - break; - } - - return Copy(); -} - -// Try to simplify one comparison against another comparison. -// For example, -// ("x"_ > 3) is a subset of ("x"_ > 2), so ("x"_ > 2).Assume("x"_ > 3) == (true) -// ("x"_ < 0) is disjoint with ("x"_ > 2), so ("x"_ > 2).Assume("x"_ < 0) == (false) -// If simplification to (true) or (false) is not possible, pass e through unchanged. 
-std::shared_ptr ComparisonExpression::AssumeGivenComparison( - const ComparisonExpression& given) const { - if (!left_operand_->Equals(given.left_operand_)) { - return Copy(); - } - - for (auto rhs : {right_operand_, given.right_operand_}) { - if (rhs->type() != ExpressionType::SCALAR) { - return Copy(); - } - } - - auto this_rhs = - EnsureNotDictionary(checked_cast(*right_operand_).value()) - .ValueOr(nullptr); - auto given_rhs = - EnsureNotDictionary( - checked_cast(*given.right_operand_).value()) - .ValueOr(nullptr); - - if (!this_rhs || !given_rhs) { - return Copy(); - } - - auto cmp = Compare(*this_rhs, *given_rhs).ValueOrDie(); - - if (cmp == Comparison::NULL_) { - // the RHS of e or given was null - return NullExpression(); - } - - static auto always = scalar(true); - static auto never = scalar(false); - - if (cmp == Comparison::GREATER) { - // the rhs of e is greater than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - if (cmp == Comparison::LESS) { - // the rhs of e is less than that of given - switch (op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - switch 
(given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - } - - DCHECK_EQ(cmp, Comparison::EQUAL); - - // the rhs of the comparisons are equal - switch (op_) { - case CompareOperator::EQUAL: - switch (given.op()) { - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::NOT_EQUAL: - switch (given.op()) { - case CompareOperator::EQUAL: - return never; - case CompareOperator::NOT_EQUAL: - case CompareOperator::GREATER: - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::LESS_EQUAL: - case CompareOperator::LESS: - return never; - case CompareOperator::GREATER: - return always; - default: - return Copy(); - } - case CompareOperator::GREATER_EQUAL: - switch (given.op()) { - case CompareOperator::LESS: - return never; - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return always; - default: - return Copy(); - } - case CompareOperator::LESS: - switch (given.op()) { - case CompareOperator::EQUAL: - case CompareOperator::GREATER: - case CompareOperator::GREATER_EQUAL: - return never; - case CompareOperator::LESS: - return always; - default: - return Copy(); - } - case CompareOperator::LESS_EQUAL: - switch (given.op()) { - case CompareOperator::GREATER: - return never; - case CompareOperator::EQUAL: - case CompareOperator::LESS: - case CompareOperator::LESS_EQUAL: - return always; - default: - return Copy(); - } - default: - return Copy(); - } - return Copy(); -} - -std::shared_ptr AndExpression::Assume(const Expression& given) const { - auto left_operand = 
left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially false then so is this AND - if (left_operand->Equals(false) || right_operand->Equals(false)) { - return scalar(false); - } - - // if either operand is trivially null then so is this AND - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially true then drop it - if (left_operand->Equals(true)) { - return right_operand; - } - if (right_operand->Equals(true)) { - return left_operand; - } - - // if neither of the operands is trivial, simply construct a new AND - return and_(std::move(left_operand), std::move(right_operand)); -} - -std::shared_ptr OrExpression::Assume(const Expression& given) const { - auto left_operand = left_operand_->Assume(given); - auto right_operand = right_operand_->Assume(given); - - // if either of the operands is trivially true then so is this OR - if (left_operand->Equals(true) || right_operand->Equals(true)) { - return scalar(true); - } - - // if either operand is trivially null then so is this OR - if (left_operand->IsNull() || right_operand->IsNull()) { - return NullExpression(); - } - - // if one of the operands is trivially false then drop it - if (left_operand->Equals(false)) { - return right_operand; - } - if (right_operand->Equals(false)) { - return left_operand; - } - - // if neither of the operands is trivial, simply construct a new OR - return or_(std::move(left_operand), std::move(right_operand)); -} - -std::shared_ptr NotExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - - if (operand->IsNull()) { - return NullExpression(); - } - if (operand->Equals(true)) { - return scalar(false); - } - if (operand->Equals(false)) { - return scalar(true); - } - - return Copy(); -} - -std::shared_ptr InExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if 
(operand->type() != ExpressionType::SCALAR) { - return std::make_shared(std::move(operand), set_); - } - - if (operand->IsNull()) { - return scalar(set_->null_count() > 0); - } - - Datum set, value; - if (set_->type_id() == Type::DICTIONARY) { - const auto& dict_set = checked_cast(*set_); - auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices()); - auto maybe_value = checked_cast( - *checked_cast(*operand).value()) - .GetEncodedValue(); - if (!maybe_decoded.ok() || !maybe_value.ok()) { - return std::make_shared(std::move(operand), set_); - } - set = *maybe_decoded; - value = *maybe_value; - } else { - set = set_; - value = checked_cast(*operand).value(); - } - - compute::CompareOptions eq(CompareOperator::EQUAL); - Result maybe_out = compute::Compare(set, value, eq); - if (!maybe_out.ok()) { - return std::make_shared(std::move(operand), set_); - } - - Datum out = maybe_out.ValueOrDie(); - - DCHECK(out.is_array()); - DCHECK_EQ(out.type()->id(), Type::BOOL); - auto out_array = checked_pointer_cast(out.make_array()); - - for (int64_t i = 0; i < out_array->length(); ++i) { - if (out_array->IsValid(i) && out_array->Value(i)) { - return scalar(true); - } - } - return scalar(false); -} - -std::shared_ptr IsValidExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if (operand->type() == ExpressionType::SCALAR) { - return scalar(!operand->IsNull()); - } - - return std::make_shared(std::move(operand)); -} - -std::shared_ptr CastExpression::Assume(const Expression& given) const { - auto operand = operand_->Assume(given); - if (arrow::util::holds_alternative>(to_)) { - auto to_type = arrow::util::get>(to_); - return std::make_shared(std::move(operand), std::move(to_type), - options_); - } - auto like = arrow::util::get>(to_)->Assume(given); - return std::make_shared(std::move(operand), std::move(like), options_); -} - -const std::shared_ptr& CastExpression::to_type() const { - if 
(arrow::util::holds_alternative>(to_)) { - return arrow::util::get>(to_); - } - static std::shared_ptr null; - return null; -} - -const std::shared_ptr& CastExpression::like_expr() const { - if (arrow::util::holds_alternative>(to_)) { - return arrow::util::get>(to_); - } - static std::shared_ptr null; - return null; -} - -std::string FieldExpression::ToString() const { return name_; } - -std::string OperatorName(compute::CompareOperator op) { - switch (op) { - case CompareOperator::EQUAL: - return "=="; - case CompareOperator::NOT_EQUAL: - return "!="; - case CompareOperator::LESS: - return "<"; - case CompareOperator::LESS_EQUAL: - return "<="; - case CompareOperator::GREATER: - return ">"; - case CompareOperator::GREATER_EQUAL: - return ">="; - default: - DCHECK(false); - } - return ""; -} - -std::string ScalarExpression::ToString() const { - auto type_repr = value_->type->ToString(); - if (!value_->is_valid) { - return "null:" + type_repr; - } - - return value_->ToString() + ":" + type_repr; -} - -using arrow::internal::JoinStrings; - -std::string AndExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " and ", right_operand_->ToString(), ")"}, ""); -} - -std::string OrExpression::ToString() const { - return JoinStrings( - {"(", left_operand_->ToString(), " or ", right_operand_->ToString(), ")"}, ""); -} - -std::string NotExpression::ToString() const { - if (operand_->type() == ExpressionType::IS_VALID) { - const auto& is_valid = checked_cast(*operand_); - return JoinStrings({"(", is_valid.operand()->ToString(), " is null)"}, ""); - } - return JoinStrings({"(not ", operand_->ToString(), ")"}, ""); -} - -std::string IsValidExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is not null)"}, ""); -} - -std::string InExpression::ToString() const { - return JoinStrings({"(", operand_->ToString(), " is in ", set_->ToString(), ")"}, ""); -} - -std::string CastExpression::ToString() const { - std::string 
to; - if (arrow::util::holds_alternative>(to_)) { - auto to_type = arrow::util::get>(to_); - to = " to " + to_type->ToString(); - } else { - auto like = arrow::util::get>(to_); - to = " like " + like->ToString(); - } - return JoinStrings({"(cast ", operand_->ToString(), std::move(to), ")"}, ""); -} - -std::string ComparisonExpression::ToString() const { - return JoinStrings({"(", left_operand_->ToString(), " ", OperatorName(op()), " ", - right_operand_->ToString(), ")"}, - ""); -} - -bool UnaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - operand_->Equals(checked_cast(other).operand_); -} - -bool BinaryExpression::Equals(const Expression& other) const { - return type_ == other.type() && - left_operand_->Equals( - checked_cast(other).left_operand_) && - right_operand_->Equals( - checked_cast(other).right_operand_); -} - -bool ComparisonExpression::Equals(const Expression& other) const { - return BinaryExpression::Equals(other) && - op_ == checked_cast(other).op_; -} - -bool ScalarExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::SCALAR && - value_->Equals(*checked_cast(other).value_); -} - -bool FieldExpression::Equals(const Expression& other) const { - return other.type() == ExpressionType::FIELD && - name_ == checked_cast(other).name_; -} - -bool Expression::Equals(const std::shared_ptr& other) const { - if (other == nullptr) { - return false; - } - return Equals(*other); -} - -bool Expression::IsNull() const { - if (type_ != ExpressionType::SCALAR) { - return false; - } - - const auto& scalar = checked_cast(*this).value(); - if (!scalar->is_valid) { - return true; - } - - return false; -} - -InExpression Expression::In(std::shared_ptr set) const { - return InExpression(Copy(), std::move(set)); -} - -IsValidExpression Expression::IsValid() const { return IsValidExpression(Copy()); } - -std::shared_ptr FieldExpression::Copy() const { - return std::make_shared(*this); -} - 
-std::shared_ptr ScalarExpression::Copy() const { - return std::make_shared(*this); -} - -std::shared_ptr and_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr and_(const ExpressionVector& subexpressions) { - auto acc = scalar(true); - for (const auto& next : subexpressions) { - if (next->Equals(false)) return next; - acc = acc->Equals(true) ? next : and_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr or_(std::shared_ptr lhs, - std::shared_ptr rhs) { - return std::make_shared(std::move(lhs), std::move(rhs)); -} - -std::shared_ptr or_(const ExpressionVector& subexpressions) { - auto acc = scalar(false); - for (const auto& next : subexpressions) { - if (next->Equals(true)) return next; - acc = acc->Equals(false) ? next : or_(std::move(acc), next); - } - return acc; -} - -std::shared_ptr not_(std::shared_ptr operand) { - return std::make_shared(std::move(operand)); -} - -AndExpression operator&&(const Expression& lhs, const Expression& rhs) { - return AndExpression(lhs.Copy(), rhs.Copy()); -} - -OrExpression operator||(const Expression& lhs, const Expression& rhs) { - return OrExpression(lhs.Copy(), rhs.Copy()); -} - -NotExpression operator!(const Expression& rhs) { return NotExpression(rhs.Copy()); } - -CastExpression Expression::CastTo(std::shared_ptr type, - compute::CastOptions options) const { - return CastExpression(Copy(), type, std::move(options)); -} - -CastExpression Expression::CastLike(std::shared_ptr expr, - compute::CastOptions options) const { - return CastExpression(Copy(), std::move(expr), std::move(options)); -} - -CastExpression Expression::CastLike(const Expression& expr, - compute::CastOptions options) const { - return CastLike(expr.Copy(), std::move(options)); -} - -Result> ComparisonExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto lhs_type, left_operand_->Validate(schema)); - ARROW_ASSIGN_OR_RAISE(auto rhs_type, 
right_operand_->Validate(schema)); - - if (lhs_type->id() == Type::NA || rhs_type->id() == Type::NA) { - return boolean(); - } - - if (!lhs_type->Equals(rhs_type)) { - return Status::TypeError("cannot compare expressions of differing type, ", *lhs_type, - " vs ", *rhs_type); - } - - return boolean(); -} - -Status EnsureNullOrBool(const std::string& msg_prefix, - const std::shared_ptr& type) { - if (type->id() == Type::BOOL || type->id() == Type::NA) { - return Status::OK(); - } - return Status::TypeError(msg_prefix, *type); -} - -Result> ValidateBoolean(const ExpressionVector& operands, - const Schema& schema) { - for (const auto& operand : operands) { - ARROW_ASSIGN_OR_RAISE(auto type, operand->Validate(schema)); - RETURN_NOT_OK( - EnsureNullOrBool("cannot combine expressions including one of type ", type)); - } - return boolean(); -} - -Result> AndExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> OrExpression::Validate(const Schema& schema) const { - return ValidateBoolean({left_operand_, right_operand_}, schema); -} - -Result> NotExpression::Validate(const Schema& schema) const { - return ValidateBoolean({operand_}, schema); -} - -Result> InExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - if (operand_type->id() == Type::NA || set_->type()->id() == Type::NA) { - return boolean(); - } - - if (!operand_type->Equals(set_->type())) { - return Status::TypeError("mismatch: set type ", *set_->type(), " vs operand type ", - *operand_type); - } - // TODO(bkietz) check if IsIn supports operand_type - return boolean(); -} - -Result> IsValidExpression::Validate( - const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(std::ignore, operand_->Validate(schema)); - return boolean(); -} - -Result> CastExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - 
std::shared_ptr to_type; - if (arrow::util::holds_alternative>(to_)) { - to_type = arrow::util::get>(to_); - } else { - auto like = arrow::util::get>(to_); - ARROW_ASSIGN_OR_RAISE(to_type, like->Validate(schema)); - } - - // Until expressions carry a shape, detect scalar and try to cast it. Works - // if the operand is a scalar leaf. - if (operand_->type() == ExpressionType::SCALAR) { - auto scalar_expr = checked_pointer_cast(operand_); - ARROW_ASSIGN_OR_RAISE(std::ignore, scalar_expr->value()->CastTo(to_type)); - return to_type; - } - - if (!compute::CanCast(*operand_type, *to_type)) { - return Status::Invalid("Cannot cast to ", to_type->ToString()); - } - - return to_type; -} - -Result> ScalarExpression::Validate(const Schema& schema) const { - return value_->type; -} - -Result> FieldExpression::Validate(const Schema& schema) const { - ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(name_).GetOneOrNone(schema)); - if (field != nullptr) { - return field->type(); - } - return null(); -} - -Result CastOrDictionaryEncode(const Datum& arr, - const std::shared_ptr& type, - const compute::CastOptions opts) { - if (type->id() == Type::DICTIONARY) { - const auto& dict_type = checked_cast(*type); - if (dict_type.index_type()->id() != Type::INT32) { - return Status::TypeError("cannot DictionaryEncode to index type ", - *dict_type.index_type()); - } - ARROW_ASSIGN_OR_RAISE(auto dense, compute::Cast(arr, dict_type.value_type(), opts)); - return compute::DictionaryEncode(dense); - } - - return compute::Cast(arr, type, opts); -} - -struct InsertImplicitCastsImpl { - struct ValidatedAndCast { - std::shared_ptr expr; - std::shared_ptr type; - }; - - Result InsertCastsAndValidate(const Expression& expr) { - ValidatedAndCast out; - ARROW_ASSIGN_OR_RAISE(out.expr, InsertImplicitCasts(expr, schema_)); - ARROW_ASSIGN_OR_RAISE(out.type, out.expr->Validate(schema_)); - return std::move(out); - } - - Result> Cast(std::shared_ptr type, - const Expression& expr) { - if (expr.type() != 
ExpressionType::SCALAR) { - return expr.CastTo(type).Copy(); - } - - // cast the scalar directly - const auto& value = checked_cast(expr).value(); - ARROW_ASSIGN_OR_RAISE(auto cast_value, value->CastTo(std::move(type))); - return scalar(cast_value); - } - - Result> operator()(const InExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); - auto set = expr.set(); - - if (!op.type->Equals(set->type())) { - // cast the set (which we assume to be small) to match op.type - ARROW_ASSIGN_OR_RAISE(auto encoded_set, CastOrDictionaryEncode(*set, op.type, {})); - set = encoded_set.make_array(); - } - - return std::make_shared(std::move(op.expr), std::move(set)); - } - - Result> operator()(const NotExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto op, InsertCastsAndValidate(*expr.operand())); - - if (op.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(op.expr, Cast(boolean(), *op.expr)); - } - return not_(std::move(op.expr)); - } - - Result> operator()(const AndExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); - } - return and_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const OrExpression& expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(boolean(), *lhs.expr)); - } - if (rhs.type->id() != Type::BOOL) { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(boolean(), *rhs.expr)); - } - return or_(std::move(lhs.expr), std::move(rhs.expr)); - } - - Result> operator()(const ComparisonExpression& 
expr) { - ARROW_ASSIGN_OR_RAISE(auto lhs, InsertCastsAndValidate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto rhs, InsertCastsAndValidate(*expr.right_operand())); - - if (lhs.type->Equals(rhs.type)) { - return expr.Copy(); - } - - if (lhs.expr->type() == ExpressionType::SCALAR) { - ARROW_ASSIGN_OR_RAISE(lhs.expr, Cast(rhs.type, *lhs.expr)); - } else { - ARROW_ASSIGN_OR_RAISE(rhs.expr, Cast(lhs.type, *rhs.expr)); - } - return std::make_shared(expr.op(), std::move(lhs.expr), - std::move(rhs.expr)); - } - - Result> operator()(const Expression& expr) const { - return expr.Copy(); - } - - const Schema& schema_; -}; - -Result> InsertImplicitCasts(const Expression& expr, - const Schema& schema) { - RETURN_NOT_OK(schema.CanReferenceFieldsByNames(FieldsInExpression(expr))); - return VisitExpression(expr, InsertImplicitCastsImpl{schema}); -} - -Status VisitConjunctionMembers(const Expression& expr, - const std::function& visitor) { - if (expr.type() == ExpressionType::AND) { - const auto& and_ = checked_cast(expr); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor)); - RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor)); - return Status::OK(); - } - - return visitor(expr); -} - -std::vector FieldsInExpression(const Expression& expr) { - struct { - void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); } - - void operator()(const UnaryExpression& expr) { - VisitExpression(*expr.operand(), *this); - } - - void operator()(const BinaryExpression& expr) { - VisitExpression(*expr.left_operand(), *this); - VisitExpression(*expr.right_operand(), *this); - } - - void operator()(const Expression&) const {} - - std::vector fields; - } visitor; - - VisitExpression(expr, visitor); - return std::move(visitor.fields); -} - -std::vector FieldsInExpression(const std::shared_ptr& expr) { - DCHECK_NE(expr, nullptr); - if (expr == nullptr) { - return {}; - } - - return FieldsInExpression(*expr); -} - -RecordBatchIterator 
ExpressionEvaluator::FilterBatches(RecordBatchIterator unfiltered, - std::shared_ptr filter, - MemoryPool* pool) { - auto filter_batches = [filter, pool, this](std::shared_ptr unfiltered) { - auto filtered = Evaluate(*filter, *unfiltered, pool).Map([&](Datum selection) { - return Filter(selection, unfiltered, pool); - }); - - if (filtered.ok() && (*filtered)->num_rows() == 0) { - // drop empty batches - return FilterIterator::Reject>(); - } - - return FilterIterator::MaybeAccept(std::move(filtered)); - }; - - return MakeFilterIterator(std::move(filter_batches), std::move(unfiltered)); -} - -std::shared_ptr ExpressionEvaluator::Null() { - struct Impl : ExpressionEvaluator { - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override { - ARROW_ASSIGN_OR_RAISE(auto type, expr.Validate(*batch.schema())); - return Datum(MakeNullScalar(type)); - } - - Result> Filter(const Datum& selection, - const std::shared_ptr& batch, - MemoryPool* pool) const override { - return batch; - } - }; - - return std::make_shared(); -} - -struct TreeEvaluator::Impl { - Result operator()(const ScalarExpression& expr) const { - return Datum(expr.value()); - } - - Result operator()(const FieldExpression& expr) const { - if (auto column = batch_.GetColumnByName(expr.name())) { - return std::move(column); - } - return NullDatum(); - } - - Result operator()(const AndExpression& expr) const { - return EvaluateBoolean(expr, compute::KleeneAnd); - } - - Result operator()(const OrExpression& expr) const { - return EvaluateBoolean(expr, compute::KleeneOr); - } - - Result EvaluateBoolean(const BinaryExpression& expr, - Result kernel(const Datum& left, - const Datum& right, - ExecContext* ctx)) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (lhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto lhs_array, - MakeArrayFromScalar(*lhs.scalar(), batch_.num_rows(), 
ctx_.memory_pool())); - lhs = Datum(std::move(lhs_array)); - } - - if (rhs.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto rhs_array, - MakeArrayFromScalar(*rhs.scalar(), batch_.num_rows(), ctx_.memory_pool())); - rhs = Datum(std::move(rhs_array)); - } - - return kernel(lhs, rhs, &ctx_); - } - - Result operator()(const NotExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum to_invert, Evaluate(*expr.operand())); - if (IsNullDatum(to_invert)) { - return NullDatum(); - } - - if (to_invert.is_scalar()) { - bool trivial_condition = - checked_cast(*to_invert.scalar()).value; - return Datum(std::make_shared(!trivial_condition)); - } - return compute::Invert(to_invert, &ctx_); - } - - Result operator()(const InExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(expr.set()->null_count() != 0); - } - - DCHECK(operand_values.is_array()); - return compute::IsIn(operand_values, expr.set(), &ctx_); - } - - Result operator()(const IsValidExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum operand_values, Evaluate(*expr.operand())); - if (IsNullDatum(operand_values)) { - return Datum(false); - } - - if (operand_values.is_scalar()) { - return Datum(true); - } - - DCHECK(operand_values.is_array()); - if (operand_values.array()->GetNullCount() == 0) { - return Datum(true); - } - - return Datum(std::make_shared(operand_values.array()->length, - operand_values.array()->buffers[0])); - } - - Result operator()(const CastExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto to_type, expr.Validate(*batch_.schema())); - - ARROW_ASSIGN_OR_RAISE(auto to_cast, Evaluate(*expr.operand())); - if (to_cast.is_scalar()) { - return to_cast.scalar()->CastTo(to_type); - } - - DCHECK(to_cast.is_array()); - return CastOrDictionaryEncode(to_cast, to_type, expr.options()); - } - - Result operator()(const ComparisonExpression& expr) const { - ARROW_ASSIGN_OR_RAISE(Datum lhs, 
Evaluate(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(Datum rhs, Evaluate(*expr.right_operand())); - - if (IsNullDatum(lhs) || IsNullDatum(rhs)) { - return Datum(std::make_shared()); - } - - if (lhs.type()->id() == Type::DICTIONARY && rhs.type()->id() == Type::DICTIONARY) { - if (lhs.is_array() && rhs.is_array()) { - // decode dictionary arrays - for (Datum* arg : {&lhs, &rhs}) { - auto dict = checked_pointer_cast(arg->make_array()); - ARROW_ASSIGN_OR_RAISE(*arg, compute::Take(dict->dictionary(), dict->indices(), - compute::TakeOptions::Defaults())); - } - } else if (lhs.is_array() || rhs.is_array()) { - auto dict = checked_pointer_cast( - (lhs.is_array() ? lhs : rhs).make_array()); - - ARROW_ASSIGN_OR_RAISE(auto scalar, checked_cast( - *(lhs.is_scalar() ? lhs : rhs).scalar()) - .GetEncodedValue()); - if (lhs.is_array()) { - lhs = dict->dictionary(); - rhs = std::move(scalar); - } else { - lhs = std::move(scalar); - rhs = dict->dictionary(); - } - ARROW_ASSIGN_OR_RAISE( - Datum out_dict, - compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_)); - - return compute::Take(out_dict, dict->indices(), compute::TakeOptions::Defaults()); - } - } - - return compute::Compare(lhs, rhs, compute::CompareOptions(expr.op()), &ctx_); - } - - Result operator()(const Expression& expr) const { - return Status::NotImplemented("evaluation of ", expr.ToString()); - } - - Result Evaluate(const Expression& expr) const { - return this_->Evaluate(expr, batch_, ctx_.memory_pool()); - } - - const TreeEvaluator* this_; - const RecordBatch& batch_; - mutable compute::ExecContext ctx_; -}; - -Result TreeEvaluator::Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const { - return VisitExpression(expr, Impl{this, batch, compute::ExecContext{pool}}); -} - -Result> TreeEvaluator::Filter( - const Datum& selection, const std::shared_ptr& batch, - MemoryPool* pool) const { - if (selection.is_array()) { - auto selection_array = selection.make_array(); - 
compute::ExecContext ctx(pool); - ARROW_ASSIGN_OR_RAISE(Datum filtered, - compute::Filter(batch, selection_array, - compute::FilterOptions::Defaults(), &ctx)); - return filtered.record_batch(); - } - - if (!selection.is_scalar() || selection.type()->id() != Type::BOOL) { - return Status::NotImplemented("Filtering batches against DatumKind::", - selection.kind(), " of type ", *selection.type()); - } - - if (BooleanScalar(true).Equals(*selection.scalar())) { - return batch; - } - - return batch->Slice(0, 0); -} - -const std::shared_ptr& scalar(bool value) { - static auto true_ = scalar(MakeScalar(true)); - static auto false_ = scalar(MakeScalar(false)); - return value ? true_ : false_; -} - -// Serialization is accomplished by converting expressions to single element StructArrays -// then writing that to an IPC file. The last field is always an int32 column containing -// ExpressionType, the rest store the Expression's members. -struct SerializeImpl { - Result> ToArray(const Expression& expr) const { - return VisitExpression(expr, *this); - } - - Result> TaggedWithChildren(const Expression& expr, - ArrayVector children) const { - children.emplace_back(); - ARROW_ASSIGN_OR_RAISE(children.back(), - MakeArrayFromScalar(Int32Scalar(expr.type()), 1)); - - return StructArray::Make(children, std::vector(children.size(), "")); - } - - Result> operator()(const FieldExpression& expr) const { - // store the field's name in a StringArray - ARROW_ASSIGN_OR_RAISE(auto name, MakeArrayFromScalar(StringScalar(expr.name()), 1)); - return TaggedWithChildren(expr, {name}); - } - - Result> operator()(const ScalarExpression& expr) const { - // store the scalar's value in a single element Array - ARROW_ASSIGN_OR_RAISE(auto value, MakeArrayFromScalar(*expr.value(), 1)); - return TaggedWithChildren(expr, {value}); - } - - Result> operator()(const UnaryExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, 
ToArray(*expr.operand())); - return TaggedWithChildren(expr, {operand}); - } - - Result> operator()(const CastExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the cast target and a discriminant - std::shared_ptr is_like_expr, to; - if (const auto& to_type = expr.to_type()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(false), 1)); - ARROW_ASSIGN_OR_RAISE(to, MakeArrayOfNull(to_type, 1)); - } - if (const auto& like_expr = expr.like_expr()) { - ARROW_ASSIGN_OR_RAISE(is_like_expr, MakeArrayFromScalar(BooleanScalar(true), 1)); - ARROW_ASSIGN_OR_RAISE(to, ToArray(*like_expr)); - } - - return TaggedWithChildren(expr, {operand, is_like_expr, to}); - } - - Result> operator()(const BinaryExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - return TaggedWithChildren(expr, {left_operand, right_operand}); - } - - Result> operator()( - const ComparisonExpression& expr) const { - // recurse to store the operands in single element StructArrays - ARROW_ASSIGN_OR_RAISE(auto left_operand, ToArray(*expr.left_operand())); - ARROW_ASSIGN_OR_RAISE(auto right_operand, ToArray(*expr.right_operand())); - // store the CompareOperator in a single element Int32Array - ARROW_ASSIGN_OR_RAISE(auto op, MakeArrayFromScalar(Int32Scalar(expr.op()), 1)); - return TaggedWithChildren(expr, {left_operand, right_operand, op}); - } - - Result> operator()(const InExpression& expr) const { - // recurse to store the operand in a single element StructArray - ARROW_ASSIGN_OR_RAISE(auto operand, ToArray(*expr.operand())); - - // store the set as a single element ListArray - auto set_type = list(expr.set()->type()); - - ARROW_ASSIGN_OR_RAISE(auto set_offsets, 
AllocateBuffer(sizeof(int32_t) * 2)); - reinterpret_cast(set_offsets->mutable_data())[0] = 0; - reinterpret_cast(set_offsets->mutable_data())[1] = - static_cast(expr.set()->length()); - - auto set_values = expr.set(); - - auto set = std::make_shared(std::move(set_type), 1, std::move(set_offsets), - std::move(set_values)); - return TaggedWithChildren(expr, {operand, set}); - } - - Result> operator()(const Expression& expr) const { - return Status::NotImplemented("serialization of ", expr.ToString()); - } - - Result> ToBuffer(const Expression& expr) const { - ARROW_ASSIGN_OR_RAISE(auto array, SerializeImpl{}.ToArray(expr)); - ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array)); - ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); - ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - RETURN_NOT_OK(writer->Close()); - return stream->Finish(); - } -}; - -Result> Expression::Serialize() const { - return SerializeImpl{}.ToBuffer(*this); -} - -struct DeserializeImpl { - Result> FromArray(const Array& array) const { - if (array.type_id() != Type::STRUCT || array.length() != 1) { - return Status::Invalid("can only deserialize expressions from unit-length", - " StructArray, got ", array); - } - const auto& struct_array = checked_cast(array); - - ARROW_ASSIGN_OR_RAISE(auto expression_type, GetExpressionType(struct_array)); - switch (expression_type) { - case ExpressionType::FIELD: { - ARROW_ASSIGN_OR_RAISE(auto name, GetView(struct_array, 0)); - return field_ref(std::string(name)); - } - - case ExpressionType::SCALAR: { - ARROW_ASSIGN_OR_RAISE(auto value, struct_array.field(0)->GetScalar(0)); - return scalar(std::move(value)); - } - - case ExpressionType::NOT: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return not_(std::move(operand)); - } - - case ExpressionType::CAST: { - ARROW_ASSIGN_OR_RAISE(auto operand, 
FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto is_like_expr, GetView(struct_array, 1)); - if (is_like_expr) { - ARROW_ASSIGN_OR_RAISE(auto like_expr, FromArray(*struct_array.field(2))); - return operand->CastLike(std::move(like_expr)).Copy(); - } - return operand->CastTo(struct_array.field(2)->type()).Copy(); - } - - case ExpressionType::AND: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return and_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::OR: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - return or_(std::move(left_operand), std::move(right_operand)); - } - - case ExpressionType::COMPARISON: { - ARROW_ASSIGN_OR_RAISE(auto left_operand, FromArray(*struct_array.field(0))); - ARROW_ASSIGN_OR_RAISE(auto right_operand, FromArray(*struct_array.field(1))); - ARROW_ASSIGN_OR_RAISE(auto op, GetView(struct_array, 2)); - return std::make_shared(static_cast(op), - std::move(left_operand), - std::move(right_operand)); - } - - case ExpressionType::IS_VALID: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - return std::make_shared(std::move(operand)); - } - - case ExpressionType::IN: { - ARROW_ASSIGN_OR_RAISE(auto operand, FromArray(*struct_array.field(0))); - if (struct_array.field(1)->type_id() != Type::LIST) { - return Status::TypeError("expected field 1 of ", struct_array, - " to have list type"); - } - auto set = checked_cast(*struct_array.field(1)).values(); - return std::make_shared(std::move(operand), std::move(set)); - } - - default: - break; - } - - return Status::Invalid("non-deserializable ExpressionType ", expression_type); - } - - template ::ArrayType> - static Result().GetView(0))> GetView(const StructArray& array, - int index) { - if (index >= 
array.num_fields()) { - return Status::IndexError("expected ", array, " to have a child at index ", index); - } - - const auto& child = *array.field(index); - if (child.type_id() != T::type_id) { - return Status::TypeError("expected child ", index, " of ", array, " to have type ", - T::type_id); - } - - return checked_cast(child).GetView(0); - } - - static Result GetExpressionType(const StructArray& array) { - if (array.struct_type()->num_fields() < 1) { - return Status::Invalid("StructArray didn't contain ExpressionType member"); - } - - ARROW_ASSIGN_OR_RAISE(auto expression_type, - GetView(array, array.num_fields() - 1)); - return static_cast(expression_type); - } - - Result> FromBuffer(const Buffer& serialized) { - io::BufferReader stream(serialized); - ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); - ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); - ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray()); - return FromArray(*array); - } -}; - -Result> Expression::Deserialize(const Buffer& serialized) { - return DeserializeImpl{}.FromBuffer(serialized); -} - -// Transform an array of counts to offsets which will divide a ListArray -// into an equal number of slices with corresponding lengths. -inline Result> CountsToOffsets( - std::shared_ptr counts) { - Int32Builder offset_builder; - RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1)); - offset_builder.UnsafeAppend(0); - - for (int64_t i = 0; i < counts->length(); ++i) { - DCHECK_NE(counts->Value(i), 0); - auto next_offset = static_cast(offset_builder[i] + counts->Value(i)); - offset_builder.UnsafeAppend(next_offset); - } - - std::shared_ptr offsets; - RETURN_NOT_OK(offset_builder.Finish(&offsets)); - return offsets; -} - -// Helper for simultaneous dictionary encoding of multiple arrays. -// -// The fused dictionary is the Cartesian product of the individual dictionaries. 
-// For example given two arrays A, B where A has unique values ["ex", "why"] -// and B has unique values [0, 1] the fused dictionary is the set of tuples -// [["ex", 0], ["ex", 1], ["why", 0], ["ex", 1]]. -// -// TODO(bkietz) this capability belongs in an Action of the hash kernels, where -// it can be used to group aggregates without materializing a grouped batch. -// For the purposes of writing we need the materialized grouped batch anyway -// since no Writers accept a selection vector. -class StructDictionary { - public: - struct Encoded { - std::shared_ptr indices; - std::shared_ptr dictionary; - }; - - static Result Encode(const ArrayVector& columns) { - Encoded out{nullptr, std::make_shared()}; - - for (const auto& column : columns) { - if (column->null_count() != 0) { - return Status::NotImplemented("Grouping on a field with nulls"); - } - - RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); - } - - return out; - } - - Result> Decode(std::shared_ptr fused_indices, - FieldVector fields) { - std::vector builders(dictionaries_.size()); - for (Int32Builder& b : builders) { - RETURN_NOT_OK(b.Resize(fused_indices->length())); - } - - std::vector codes(dictionaries_.size()); - for (int64_t i = 0; i < fused_indices->length(); ++i) { - Expand(fused_indices->Value(i), codes.data()); - - auto builder_it = builders.begin(); - for (int32_t index : codes) { - builder_it++->UnsafeAppend(index); - } - } - - ArrayVector columns(dictionaries_.size()); - for (size_t i = 0; i < dictionaries_.size(); ++i) { - std::shared_ptr indices; - RETURN_NOT_OK(builders[i].FinishInternal(&indices)); - - ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices)); - columns[i] = column.make_array(); - } - - return StructArray::Make(std::move(columns), std::move(fields)); - } - - private: - Status AddOne(Datum column, std::shared_ptr* fused_indices) { - ArrayData* encoded; - if (column.type()->id() != Type::DICTIONARY) { - ARROW_ASSIGN_OR_RAISE(column, 
compute::DictionaryEncode(column)); - } - encoded = column.mutable_array(); - - auto indices = - std::make_shared(encoded->length, std::move(encoded->buffers[1])); - - dictionaries_.push_back(MakeArray(std::move(encoded->dictionary))); - auto dictionary_size = static_cast(dictionaries_.back()->length()); - - if (*fused_indices == nullptr) { - *fused_indices = std::move(indices); - size_ = dictionary_size; - return Status::OK(); - } - - // It's useful to think about the case where each of dictionaries_ has size 10. - // In this case the decimal digit in the ones place is the code in dictionaries_[0], - // the tens place corresponds to dictionaries_[1], etc. - // The incumbent indices must be shifted to the hundreds place so as not to collide. - ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices, - compute::Multiply(indices, MakeScalar(size_))); - - ARROW_ASSIGN_OR_RAISE(new_fused_indices, - compute::Add(new_fused_indices, *fused_indices)); - - *fused_indices = checked_pointer_cast(new_fused_indices.make_array()); - - // XXX should probably cap this at 2**15 or so - ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_)); - return Status::OK(); - } - - // expand a fused code into component dict codes, order is in order of addition - void Expand(int32_t fused_code, int32_t* codes) { - for (size_t i = 0; i < dictionaries_.size(); ++i) { - auto dictionary_size = static_cast(dictionaries_[i]->length()); - codes[i] = fused_code % dictionary_size; - fused_code /= dictionary_size; - } - } - - int32_t size_; - ArrayVector dictionaries_; -}; - -Result> MakeGroupings(const StructArray& by) { - if (by.num_fields() == 0) { - return Status::NotImplemented("Grouping with no criteria"); - } - - ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields())); - - ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices)); - ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices)); - fused.indices = 
checked_pointer_cast(sorted.make_array()); - - ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values, - compute::ValueCounts(fused.indices)); - fused.indices.reset(); - - auto unique_fused_indices = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("values")); - ARROW_ASSIGN_OR_RAISE( - auto unique_rows, - fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields())); - - auto counts = - checked_pointer_cast(fused_counts_and_values->GetFieldByName("counts")); - ARROW_ASSIGN_OR_RAISE(auto offsets, CountsToOffsets(std::move(counts))); - - ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices, - ListArray::FromArrays(*offsets, *sort_indices)); - - return StructArray::Make( - ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)}, - std::vector{"values", "groupings"}); -} - -Result> ApplyGroupings(const ListArray& groupings, - const Array& array) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(array, groupings.data()->child_data[0])); - - return std::make_shared(list(array.type()), groupings.length(), - groupings.value_offsets(), sorted.make_array()); -} - -Result ApplyGroupings(const ListArray& groupings, - const std::shared_ptr& batch) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(batch, groupings.data()->child_data[0])); - - const auto& sorted_batch = *sorted.record_batch(); - - RecordBatchVector out(static_cast(groupings.length())); - for (size_t i = 0; i < out.size(); ++i) { - out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); - } - - return out; -} - -} // namespace dataset -} // namespace arrow +// FIXME remove diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index c6b55b419ff..9852f8c0808 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -15,658 +15,4 @@ // specific language governing permissions and limitations // under the License. -// This API is EXPERIMENTAL. 
- -#pragma once - -#include -#include -#include -#include -#include - -#include "arrow/compute/api_scalar.h" -#include "arrow/compute/cast.h" -#include "arrow/dataset/type_fwd.h" -#include "arrow/dataset/visibility.h" -#include "arrow/datum.h" -#include "arrow/result.h" -#include "arrow/scalar.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/variant.h" - -namespace arrow { -namespace dataset { - -using compute::CastOptions; -using compute::CompareOperator; - -struct ExpressionType { - enum type { - /// a reference to a column within a record batch, will evaluate to an array - FIELD, - - /// a literal singular value encapsulated in a Scalar - SCALAR, - - /// a literal Array - // TODO(bkietz) ARRAY, - - /// an inversion of another expression - NOT, - - /// cast an expression to a given DataType - CAST, - - /// a conjunction of multiple expressions (true if all operands are true) - AND, - - /// a disjunction of multiple expressions (true if any operand is true) - OR, - - /// a comparison of two other expressions - COMPARISON, - - /// replace nulls with other expressions - /// currently only boolean expressions may be coalesced - // TODO(bkietz) COALESCE, - - /// extract validity as a boolean expression - IS_VALID, - - /// check each element for membership in a set - IN, - - /// custom user defined expression - CUSTOM, - }; -}; - -class InExpression; -class CastExpression; -class IsValidExpression; - -/// Represents an expression tree -class ARROW_DS_EXPORT Expression { - public: - explicit Expression(ExpressionType::type type) : type_(type) {} - - virtual ~Expression() = default; - - /// Returns true iff the expressions are identical; does not check for equivalence. - /// For example, (A and B) is not equal to (B and A) nor is (A and not A) equal to - /// (false). 
- virtual bool Equals(const Expression& other) const = 0; - - bool Equals(const std::shared_ptr& other) const; - - /// Overload for the common case of checking for equality to a specific scalar. - template ()))> - bool Equals(T&& t) const; - - /// If true, this Expression is a ScalarExpression wrapping a null scalar. - bool IsNull() const; - - /// Validate this expression for execution against a schema. This will check that all - /// reference fields are present (fields not in the schema will be replaced with null) - /// and all subexpressions are executable. Returns the type to which this expression - /// will evaluate. - virtual Result> Validate(const Schema& schema) const = 0; - - /// \brief Simplify to an equivalent Expression given assumed constraints on input. - /// This can be used to do less filtering work using predicate push down. - /// - /// Both expressions must pass validation against a schema before Assume may be used. - /// - /// Two expressions can be considered equivalent for a given subset of possible inputs - /// if they yield identical results. Formally, if given.Evaluate(input).Equals(input) - /// then Assume guarantees that: - /// expr.Assume(given).Evaluate(input).Equals(expr.Evaluate(input)) - /// - /// For example if we are given that all inputs will - /// satisfy ("a"_ == 1) then the expression ("a"_ > 0 and "b"_ > 0) is equivalent to - /// ("b"_ > 0). It is impossible that the comparison ("a"_ > 0) will evaluate false - /// given ("a"_ == 1), so both expressions will yield identical results. Thus we can - /// write: - /// ("a"_ > 0 and "b"_ > 0).Assume("a"_ == 1).Equals("b"_ > 0) - /// - /// filter.Assume(partition) is trivial if filter and partition are disjoint or if - /// partition is a subset of filter. FIXME(bkietz) write this better - /// - If the two are disjoint, then (false) may be substituted for filter. - /// - If partition is a subset of filter then (true) may be substituted for filter. 
- /// - /// filter.Assume(partition) is straightforward if both filter and partition are simple - /// comparisons. - /// - filter may be a superset of partition, in which case the filter is - /// satisfied by all inputs: - /// ("a"_ > 0).Assume("a"_ == 1).Equals(true) - /// - filter may be disjoint with partition, in which case there are no inputs which - /// satisfy filter: - /// ("a"_ < 0).Assume("a"_ == 1).Equals(false) - /// - If neither of these is the case, partition provides no information which can - /// simplify filter: - /// ("a"_ == 1).Assume("a"_ > 0).Equals("a"_ == 1) - /// ("a"_ == 1).Assume("b"_ == 1).Equals("a"_ == 1) - /// - /// If filter is compound, Assume can be distributed across the boolean operator. To - /// prove this is valid, we again demonstrate that the simplified expression will yield - /// identical results. For conjunction of filters lhs and rhs: - /// (lhs.Assume(p) and rhs.Assume(p)).Evaluate(input) - /// == Intersection(lhs.Assume(p).Evaluate(input), rhs.Assume(p).Evaluate(input)) - /// == Intersection(lhs.Evaluate(input), rhs.Evaluate(input)) - /// == (lhs and rhs).Evaluate(input) - /// - The proof for disjunction is symmetric; just replace Intersection with Union. 
Thus - /// we can write: - /// (lhs and rhs).Assume(p).Equals(lhs.Assume(p) and rhs.Assume(p)) - /// (lhs or rhs).Assume(p).Equals(lhs.Assume(p) or rhs.Assume(p)) - /// - For negation: - /// (not e.Assume(p)).Evaluate(input) - /// == Difference(input, e.Assume(p).Evaluate(input)) - /// == Difference(input, e.Evaluate(input)) - /// == (not e).Evaluate(input) - /// - Thus we can write: - /// (not e).Assume(p).Equals(not e.Assume(p)) - /// - /// If the partition expression is a conjunction then each of its subexpressions is - /// true for all input and can be used independently: - /// filter.Assume(lhs).Assume(rhs).Evaluate(input) - /// == filter.Assume(lhs).Evaluate(input) - /// == filter.Evaluate(input) - /// - Thus we can write: - /// filter.Assume(lhs and rhs).Equals(filter.Assume(lhs).Assume(rhs)) - /// - /// FIXME(bkietz) disjunction proof - /// filter.Assume(lhs or rhs).Equals(filter.Assume(lhs) and filter.Assume(rhs)) - /// - This may not result in a simpler expression so it is only used when - /// filter.Assume(lhs).Equals(filter.Assume(rhs)) - /// - /// If the partition expression is a negation then we can use the above relations by - /// replacing comparisons with their complements and using the properties: - /// (not (a and b)).Equals(not a or not b) - /// (not (a or b)).Equals(not a and not b) - virtual std::shared_ptr Assume(const Expression& given) const; - - std::shared_ptr Assume(const std::shared_ptr& given) const { - return Assume(*given); - } - - /// Indicates if the expression is satisfiable. - /// - /// This is a shortcut to check if the expression is neither null nor false. - bool IsSatisfiable() const { return !IsNull() && !Equals(false); } - - /// Indicates if the expression is satisfiable given an other expression. - /// - /// This behaves like IsSatisfiable, but it simplifies the current expression - /// with the given `other` information. 
- bool IsSatisfiableWith(const Expression& other) const { - return Assume(other)->IsSatisfiable(); - } - - bool IsSatisfiableWith(const std::shared_ptr& other) const { - return Assume(other)->IsSatisfiable(); - } - - /// returns a debug string representing this expression - virtual std::string ToString() const = 0; - - /// serialize/deserialize an Expression. - Result> Serialize() const; - static Result> Deserialize(const Buffer&); - - /// \brief Return the expression's type identifier - ExpressionType::type type() const { return type_; } - - /// Copy this expression into a shared pointer. - virtual std::shared_ptr Copy() const = 0; - - InExpression In(std::shared_ptr set) const; - - IsValidExpression IsValid() const; - - CastExpression CastTo(std::shared_ptr type, - CastOptions options = CastOptions()) const; - - CastExpression CastLike(const Expression& expr, - CastOptions options = CastOptions()) const; - - CastExpression CastLike(std::shared_ptr expr, - CastOptions options = CastOptions()) const; - - protected: - ExpressionType::type type_; -}; - -/// Helper class which implements Copy and forwards construction -template -class ExpressionImpl : public Base { - public: - static constexpr ExpressionType::type expression_type = E; - - template - explicit ExpressionImpl(A0&& arg0, A&&... args) - : Base(expression_type, std::forward(arg0), std::forward(args)...) 
{} - - std::shared_ptr Copy() const override { - return std::make_shared(internal::checked_cast(*this)); - } -}; - -/// Base class for an expression with exactly one operand -class ARROW_DS_EXPORT UnaryExpression : public Expression { - public: - const std::shared_ptr& operand() const { return operand_; } - - bool Equals(const Expression& other) const override; - - protected: - UnaryExpression(ExpressionType::type type, std::shared_ptr operand) - : Expression(type), operand_(std::move(operand)) {} - - std::shared_ptr operand_; -}; - -/// Base class for an expression with exactly two operands -class ARROW_DS_EXPORT BinaryExpression : public Expression { - public: - const std::shared_ptr& left_operand() const { return left_operand_; } - - const std::shared_ptr& right_operand() const { return right_operand_; } - - bool Equals(const Expression& other) const override; - - protected: - BinaryExpression(ExpressionType::type type, std::shared_ptr left_operand, - std::shared_ptr right_operand) - : Expression(type), - left_operand_(std::move(left_operand)), - right_operand_(std::move(right_operand)) {} - - std::shared_ptr left_operand_, right_operand_; -}; - -class ARROW_DS_EXPORT ComparisonExpression final - : public ExpressionImpl { - public: - ComparisonExpression(CompareOperator op, std::shared_ptr left_operand, - std::shared_ptr right_operand) - : ExpressionImpl(std::move(left_operand), std::move(right_operand)), op_(op) {} - - std::string ToString() const override; - - bool Equals(const Expression& other) const override; - - std::shared_ptr Assume(const Expression& given) const override; - - CompareOperator op() const { return op_; } - - Result> Validate(const Schema& schema) const override; - - private: - std::shared_ptr AssumeGivenComparison( - const ComparisonExpression& given) const; - - CompareOperator op_; -}; - -class ARROW_DS_EXPORT AndExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const 
override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; -}; - -class ARROW_DS_EXPORT OrExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; -}; - -class ARROW_DS_EXPORT NotExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; -}; - -class ARROW_DS_EXPORT IsValidExpression final - : public ExpressionImpl { - public: - using ExpressionImpl::ExpressionImpl; - - std::string ToString() const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Assume(const Expression& given) const override; -}; - -class ARROW_DS_EXPORT InExpression final - : public ExpressionImpl { - public: - InExpression(std::shared_ptr operand, std::shared_ptr set) - : ExpressionImpl(std::move(operand)), set_(std::move(set)) {} - - std::string ToString() const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Assume(const Expression& given) const override; - - /// The set against which the operand will be compared - const std::shared_ptr& set() const { return set_; } - - private: - std::shared_ptr set_; -}; - -/// Explicitly cast an expression to a different type -class ARROW_DS_EXPORT CastExpression final - : public ExpressionImpl { - public: - CastExpression(std::shared_ptr operand, std::shared_ptr to, - CastOptions options) - : ExpressionImpl(std::move(operand)), - to_(std::move(to)), - options_(std::move(options)) {} - - /// The operand will be cast to whatever type `like` would evaluate to, given the same - /// schema. 
- CastExpression(std::shared_ptr operand, std::shared_ptr like, - CastOptions options) - : ExpressionImpl(std::move(operand)), - to_(std::move(like)), - options_(std::move(options)) {} - - std::string ToString() const override; - - std::shared_ptr Assume(const Expression& given) const override; - - Result> Validate(const Schema& schema) const override; - - const CastOptions& options() const { return options_; } - - /// Return the target type of this CastTo expression, or nullptr if this is a - /// CastLike expression. - const std::shared_ptr& to_type() const; - - /// Return the target expression of this CastLike expression, or nullptr if - /// this is a CastTo expression. - const std::shared_ptr& like_expr() const; - - private: - util::Variant, std::shared_ptr> to_; - CastOptions options_; -}; - -/// Represents a scalar value; thin wrapper around arrow::Scalar -class ARROW_DS_EXPORT ScalarExpression final : public Expression { - public: - explicit ScalarExpression(const std::shared_ptr& value) - : Expression(ExpressionType::SCALAR), value_(std::move(value)) {} - - const std::shared_ptr& value() const { return value_; } - - std::string ToString() const override; - - bool Equals(const Expression& other) const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Copy() const override; - - private: - std::shared_ptr value_; -}; - -/// Represents a reference to a field. 
Stores only the field's name (type and other -/// information is known only when a Schema is provided) -class ARROW_DS_EXPORT FieldExpression final : public Expression { - public: - explicit FieldExpression(std::string name) - : Expression(ExpressionType::FIELD), name_(std::move(name)) {} - - std::string name() const { return name_; } - - std::string ToString() const override; - - bool Equals(const Expression& other) const override; - - Result> Validate(const Schema& schema) const override; - - std::shared_ptr Copy() const override; - - private: - std::string name_; -}; - -class ARROW_DS_EXPORT CustomExpression : public Expression { - protected: - CustomExpression() : Expression(ExpressionType::CUSTOM) {} -}; - -ARROW_DS_EXPORT std::shared_ptr and_(std::shared_ptr lhs, - std::shared_ptr rhs); - -ARROW_DS_EXPORT std::shared_ptr and_(const ExpressionVector& subexpressions); - -ARROW_DS_EXPORT AndExpression operator&&(const Expression& lhs, const Expression& rhs); - -ARROW_DS_EXPORT std::shared_ptr or_(std::shared_ptr lhs, - std::shared_ptr rhs); - -ARROW_DS_EXPORT std::shared_ptr or_(const ExpressionVector& subexpressions); - -ARROW_DS_EXPORT OrExpression operator||(const Expression& lhs, const Expression& rhs); - -ARROW_DS_EXPORT std::shared_ptr not_(std::shared_ptr operand); - -ARROW_DS_EXPORT NotExpression operator!(const Expression& rhs); - -inline std::shared_ptr scalar(std::shared_ptr value) { - return std::make_shared(std::move(value)); -} - -template -auto scalar(T&& value) -> decltype(scalar(MakeScalar(std::forward(value)))) { - return scalar(MakeScalar(std::forward(value))); -} - -#define COMPARISON_FACTORY(NAME, FACTORY_NAME, OP) \ - inline std::shared_ptr FACTORY_NAME( \ - const std::shared_ptr& lhs, const std::shared_ptr& rhs) { \ - return std::make_shared(CompareOperator::NAME, lhs, rhs); \ - } \ - \ - template ::type>::value>::type> \ - ComparisonExpression operator OP(const Expression& lhs, T&& rhs) { \ - return 
ComparisonExpression(CompareOperator::NAME, lhs.Copy(), \ - scalar(std::forward(rhs))); \ - } \ - \ - inline ComparisonExpression operator OP(const Expression& lhs, \ - const Expression& rhs) { \ - return ComparisonExpression(CompareOperator::NAME, lhs.Copy(), rhs.Copy()); \ - } -COMPARISON_FACTORY(EQUAL, equal, ==) -COMPARISON_FACTORY(NOT_EQUAL, not_equal, !=) -COMPARISON_FACTORY(GREATER, greater, >) -COMPARISON_FACTORY(GREATER_EQUAL, greater_equal, >=) -COMPARISON_FACTORY(LESS, less, <) -COMPARISON_FACTORY(LESS_EQUAL, less_equal, <=) -#undef COMPARISON_FACTORY - -inline std::shared_ptr field_ref(std::string name) { - return std::make_shared(std::move(name)); -} - -inline namespace string_literals { -// clang-format off -inline FieldExpression operator"" _(const char* name, size_t name_length) { - // clang-format on - return FieldExpression({name, name_length}); -} -} // namespace string_literals - -template -bool Expression::Equals(T&& t) const { - if (type_ != ExpressionType::SCALAR) { - return false; - } - auto s = MakeScalar(std::forward(t)); - return internal::checked_cast(*this).value()->Equals(*s); -} - -template -auto VisitExpression(const Expression& expr, Visitor&& visitor) - -> decltype(visitor(expr)) { - switch (expr.type()) { - case ExpressionType::FIELD: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::SCALAR: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::IN: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::IS_VALID: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::AND: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::OR: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::NOT: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::CAST: - return visitor(internal::checked_cast(expr)); - - case ExpressionType::COMPARISON: - return visitor(internal::checked_cast(expr)); - - case 
ExpressionType::CUSTOM: - default: - break; - } - return visitor(internal::checked_cast(expr)); -} - -/// \brief Visit each subexpression of an arbitrarily nested conjunction. -/// -/// | given | visit | -/// |--------------------------------|---------------------------------------------| -/// | a and b | visit(a), visit(b) | -/// | c | visit(c) | -/// | (a and b) and ((c or d) and e) | visit(a), visit(b), visit(c or d), visit(e) | -ARROW_DS_EXPORT Status VisitConjunctionMembers( - const Expression& expr, const std::function& visitor); - -/// \brief Insert CastExpressions where necessary to make a valid expression. -ARROW_DS_EXPORT Result> InsertImplicitCasts( - const Expression& expr, const Schema& schema); - -/// \brief Returns field names referenced in the expression. -ARROW_DS_EXPORT std::vector FieldsInExpression(const Expression& expr); - -ARROW_DS_EXPORT std::vector FieldsInExpression( - const std::shared_ptr& expr); - -/// Interface for evaluation of expressions against record batches. -class ARROW_DS_EXPORT ExpressionEvaluator { - public: - virtual ~ExpressionEvaluator() = default; - - /// Evaluate expr against each row of a RecordBatch. - /// Returned Datum will be of either SCALAR or ARRAY kind. - /// A return value of ARRAY kind will have length == batch.num_rows() - /// An return value of SCALAR kind is equivalent to an array of the same type whose - /// slots contain a single repeated value. - /// - /// expr must be validated against the schema of batch before calling this method. 
- virtual Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const = 0; - - Result Evaluate(const Expression& expr, const RecordBatch& batch) const { - return Evaluate(expr, batch, default_memory_pool()); - } - - virtual Result> Filter( - const Datum& selection, const std::shared_ptr& batch, - MemoryPool* pool) const = 0; - - Result> Filter( - const Datum& selection, const std::shared_ptr& batch) const { - return Filter(selection, batch, default_memory_pool()); - } - - /// \brief Wrap an iterator of record batches with a filter expression. The resulting - /// iterator will yield record batches filtered by the given expression. - /// - /// \note The ExpressionEvaluator must outlive the returned iterator. - RecordBatchIterator FilterBatches(RecordBatchIterator unfiltered, - std::shared_ptr filter, - MemoryPool* pool = default_memory_pool()); - - /// construct an Evaluator which evaluates all expressions to null and does no - /// filtering - static std::shared_ptr Null(); -}; - -/// construct an Evaluator which uses compute kernels to evaluate expressions and -/// filter record batches in depth first order -class ARROW_DS_EXPORT TreeEvaluator : public ExpressionEvaluator { - public: - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override; - - Result> Filter(const Datum& selection, - const std::shared_ptr& batch, - MemoryPool* pool) const override; - - protected: - struct Impl; -}; - -/// \brief Assemble lists of indices of identical rows. -/// -/// \param[in] by A StructArray whose columns will be used as grouping criteria. -/// \return A StructArray mapping unique rows (in field "values", represented as a -/// StructArray with the same fields as `by`) to lists of indices where -/// that row appears (in field "groupings"). -ARROW_DS_EXPORT -Result> MakeGroupings(const StructArray& by); - -/// \brief Produce slices of an Array which correspond to the provided groupings. 
-ARROW_DS_EXPORT -Result> ApplyGroupings(const ListArray& groupings, - const Array& array); -ARROW_DS_EXPORT -Result ApplyGroupings(const ListArray& groupings, - const std::shared_ptr& batch); - -} // namespace dataset -} // namespace arrow +// FIXME remove diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 7723912eeab..ef32b8a7bb6 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -15,645 +15,4 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/dataset/filter.h" - -#include -#include -#include -#include - -#include -#include - -#include "arrow/compute/api.h" -#include "arrow/dataset/test_util.h" -#include "arrow/record_batch.h" -#include "arrow/status.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/type.h" -#include "arrow/type_fwd.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/logging.h" - -namespace arrow { -namespace dataset { - -// clang-format off -using string_literals::operator"" _; -// clang-format on - -using internal::checked_cast; -using internal::checked_pointer_cast; - -using E = TestExpression; - -class ExpressionsTest : public ::testing::Test { - public: - void AssertSimplifiesTo(const Expression& expr, const Expression& given, - const Expression& expected) { - ASSERT_OK_AND_ASSIGN(auto expr_type, expr.Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto given_type, given.Validate(*schema_)); - ASSERT_OK_AND_ASSIGN(auto expected_type, expected.Validate(*schema_)); - - EXPECT_TRUE(expr_type->Equals(expected_type)); - EXPECT_TRUE(given_type->Equals(boolean())); - - auto simplified = expr.Assume(given); - ASSERT_EQ(E{simplified}, E{expected}) - << " simplification of: " << expr.ToString() << std::endl - << " given: " << given.ToString() << std::endl; - } - - void AssertSimplifiesTo(const Expression& expr, const Expression& given, - const std::shared_ptr& expected) { - AssertSimplifiesTo(expr, 
given, *expected); - } - - std::shared_ptr ns = timestamp(TimeUnit::NANO); - std::shared_ptr schema_ = - schema({field("a", int32()), field("b", int32()), field("f", float64()), - field("s", utf8()), field("ts", ns), - field("dict_b", dictionary(int32(), int32()))}); - std::shared_ptr always = scalar(true); - std::shared_ptr never = scalar(false); -}; - -TEST_F(ExpressionsTest, StringRepresentation) { - ASSERT_EQ("a"_.ToString(), "a"); - ASSERT_EQ(("a"_ > int32_t(3)).ToString(), "(a > 3:int32)"); - ASSERT_EQ(("a"_ > int32_t(3) and "a"_ < int32_t(4)).ToString(), - "((a > 3:int32) and (a < 4:int32))"); - ASSERT_EQ(("f"_ > double(4)).ToString(), "(f > 4:double)"); - ASSERT_EQ("f"_.CastTo(float64()).ToString(), "(cast f to double)"); - ASSERT_EQ("f"_.CastLike("a"_).ToString(), "(cast f like a)"); -} - -TEST_F(ExpressionsTest, Equality) { - ASSERT_EQ(E{"a"_}, E{"a"_}); - ASSERT_NE(E{"a"_}, E{"b"_}); - - ASSERT_EQ(E{"b"_ == 3}, E{"b"_ == 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_ < 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_}); - - // ordering matters - ASSERT_EQ(E{"b"_ == 3}, E{"b"_ == 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_ < 3}); - ASSERT_NE(E{"b"_ == 3}, E{"b"_}); - - ASSERT_EQ(E("b"_ > 2 and "b"_ < 3), E("b"_ > 2 and "b"_ < 3)); - ASSERT_NE(E("b"_ > 2 and "b"_ < 3), E("b"_ < 3 and "b"_ > 2)); -} - -TEST_F(ExpressionsTest, SimplificationOfCompoundQuery) { - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 3, *never); - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ == 6, *always); - - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 6, *never); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ == 3, *always); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ > 3, "b"_ == 4); - AssertSimplifiesTo("b"_ == 3 or "b"_ == 4, "b"_ >= 3, "b"_ == 3 or "b"_ == 4); - - AssertSimplifiesTo("f"_ > 0.5 and "f"_ < 1.5, not("f"_ < 0.0 or "f"_ > 1.0), - "f"_ > 0.5); - - AssertSimplifiesTo("b"_ == 4, "a"_ == 0, "b"_ == 4); - - AssertSimplifiesTo("a"_ == 3 or "b"_ == 4, "a"_ == 0, "b"_ == 4); - 
- auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 3, "a"_ == 3); - AssertSimplifiesTo("a"_ == 3 and "b"_.In(set_123), "b"_ == 0, *never); - - AssertSimplifiesTo("a"_ == 0 or not"b"_.IsValid(), "b"_ == 3, "a"_ == 0); -} - -TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { - AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); - AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); - AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); - - auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 3, *always); - AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 5, *never); - - auto dict_set_123 = - DictArrayFromJSON(dictionary(int32(), int32()), R"([1,2,0])", R"([1,2,3])"); - ASSERT_OK_AND_ASSIGN(auto b_dict, dict_set_123->GetScalar(0)); - AssertSimplifiesTo("b_dict"_.In(dict_set_123), "a"_ == 3 and "b_dict"_ == b_dict, - *always); -} - -TEST_F(ExpressionsTest, SimplificationToNull) { - auto null = scalar(std::make_shared()); - auto null32 = scalar(std::make_shared()); - - AssertSimplifiesTo(*equal(field_ref("b"), null32), "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32), "b"_ == 3, *null); - - // Kleene logic applies here - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 3, "b"_ == 3, *never); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) and "b"_ > 2, "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) or "b"_ > 3, "b"_ == 3, *null); - AssertSimplifiesTo(*not_equal(field_ref("b"), null32) or "b"_ > 2, "b"_ == 3, *always); -} - -class FilterTest : public ::testing::Test { - public: - FilterTest() { evaluator_ = std::make_shared(); } - - Result DoFilter(const Expression& expr, - std::vector> fields, - std::string batch_json, - std::shared_ptr* expected_mask = nullptr) { - // expected filter result is 
in the "in" field - fields.push_back(field("in", boolean())); - auto batch = RecordBatchFromJSON(schema(fields), batch_json); - if (expected_mask) { - *expected_mask = checked_pointer_cast(batch->GetColumnByName("in")); - } - - ARROW_ASSIGN_OR_RAISE(auto expr_type, expr.Validate(*batch->schema())); - EXPECT_TRUE(expr_type->Equals(boolean())); - - return evaluator_->Evaluate(expr, *batch); - } - - void AssertFilter(const std::shared_ptr& expr, - std::vector> fields, - const std::string& batch_json) { - AssertFilter(*expr, std::move(fields), batch_json); - } - - void AssertFilter(const Expression& expr, std::vector> fields, - const std::string& batch_json) { - std::shared_ptr expected_mask; - - ASSERT_OK_AND_ASSIGN(Datum mask, DoFilter(expr, std::move(fields), - std::move(batch_json), &expected_mask)); - ASSERT_TRUE(mask.type()->Equals(null()) || mask.type()->Equals(boolean())); - - if (mask.is_array()) { - AssertArraysEqual(*expected_mask, *mask.make_array(), /*verbose=*/true); - return; - } - - ASSERT_TRUE(mask.is_scalar()); - auto mask_scalar = mask.scalar(); - if (!mask_scalar->is_valid) { - ASSERT_EQ(expected_mask->null_count(), expected_mask->length()); - return; - } - - TypedBufferBuilder builder; - ASSERT_OK(builder.Append(expected_mask->length(), - checked_cast(*mask_scalar).value)); - - std::shared_ptr values; - ASSERT_OK(builder.Finish(&values)); - - ASSERT_ARRAYS_EQUAL(*expected_mask, BooleanArray(expected_mask->length(), values)); - } - - std::shared_ptr evaluator_; -}; - -TEST_F(FilterTest, Trivial) { - // Note that we should expect these trivial expressions will never be evaluated against - // record batches; since they're trivial, evaluation is not necessary. 
- AssertFilter(scalar(true), {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": 1}, - {"a": 0, "b": 1.0, "in": 1} - ])"); - - AssertFilter(scalar(false), {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 0, "b": 1.0, "in": 0} - ])"); - - AssertFilter(*scalar(std::shared_ptr(new BooleanScalar)), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": null}, - {"a": 0, "b": 0.3, "in": null}, - {"a": 1, "b": 0.2, "in": null}, - {"a": 2, "b": -0.1, "in": null}, - {"a": 0, "b": 0.1, "in": null}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": null} - ])"); -} - -TEST_F(FilterTest, Basics) { - AssertFilter("a"_ == 0 and "b"_ > 0.0 and "b"_ < 1.0, - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": 0} - ])"); - - AssertFilter("a"_ != 0 and "b"_ > 0.1, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 0, "b": 1.0, "in": 0} - ])"); -} - -TEST_F(FilterTest, InExpression) { - auto hello_world = ArrayFromJSON(utf8(), R"(["hello", "world"])"); - - AssertFilter("s"_.In(hello_world), {field("s", utf8())}, R"([ - {"s": "hello", "in": 1}, - {"s": "world", "in": 1}, - {"s": "", "in": 0}, - {"s": null, "in": null}, - {"s": "foo", "in": 0}, - {"s": "hello", 
"in": 1}, - {"s": "bar", "in": 0} - ])"); -} - -TEST_F(FilterTest, IsValidExpression) { - AssertFilter("s"_.IsValid(), {field("s", utf8())}, R"([ - {"s": "hello", "in": 1}, - {"s": null, "in": 0}, - {"s": "", "in": 1}, - {"s": null, "in": 0}, - {"s": "foo", "in": 1}, - {"s": "hello", "in": 1}, - {"s": null, "in": 0} - ])"); -} - -TEST_F(FilterTest, Cast) { - ASSERT_RAISES(TypeError, ("a"_ == double(1.0)).Validate(Schema({field("a", int32())}))); - - AssertFilter("a"_.CastTo(float64()) == double(1.0), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.3, "in": 0}, - {"a": 1, "b": 0.2, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 0, "b": null, "in": 0}, - {"a": 1, "b": 1.0, "in": 1} - ])"); - - AssertFilter("a"_ == scalar(0.6)->CastLike("a"_), - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.3, "in": 1}, - {"a": 1, "b": 0.2, "in": 0}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 1}, - {"a": 0, "b": null, "in": 1}, - {"a": 1, "b": 1.0, "in": 0} - ])"); - - AssertFilter("a"_.CastLike("b"_) == "b"_, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.0, "in": 1}, - {"a": 1, "b": 1.0, "in": 1}, - {"a": 2, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 2, "b": null, "in": null}, - {"a": 1, "b": 1.0, "in": 1} - ])"); -} - -TEST_F(ExpressionsTest, ImplicitCast) { - ASSERT_OK_AND_ASSIGN(auto filter, InsertImplicitCasts("a"_ == 0.0, *schema_)); - ASSERT_EQ(E{filter}, E{"a"_ == 0}); - - auto ns = timestamp(TimeUnit::NANO); - auto date = "1990-10-23 10:23:33"; - ASSERT_OK_AND_ASSIGN(filter, InsertImplicitCasts("ts"_ == date, *schema_)); - ASSERT_EQ(E{filter}, E{"ts"_ == *MakeScalar(date)->CastTo(ns)}); - - ASSERT_OK_AND_ASSIGN(filter, - InsertImplicitCasts("ts"_ == date and "b"_ == "3", *schema_)); - ASSERT_EQ(E{filter}, E{"ts"_ == *MakeScalar(date)->CastTo(ns) 
and "b"_ == 3}); - AssertSimplifiesTo(*filter, "b"_ == 2, *never); - AssertSimplifiesTo(*filter, "b"_ == 3, "ts"_ == *MakeScalar(date)->CastTo(ns)); - - // set is double but "a"_ is int32 - auto set_double = ArrayFromJSON(float64(), R"([1, 2, 3])"); - ASSERT_OK_AND_ASSIGN(filter, InsertImplicitCasts("a"_.In(set_double), *schema_)); - auto set_int32 = ArrayFromJSON(int32(), R"([1, 2, 3])"); - ASSERT_EQ(E{filter}, E{"a"_.In(set_int32)}); - - EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, - testing::HasSubstr("Field named 'nope' not found"), - InsertImplicitCasts("nope"_ == 0.0, *schema_)); -} - -TEST_F(ExpressionsTest, ImplicitCastToDict) { - auto dict_type = dictionary(int32(), float64()); - ASSERT_OK_AND_ASSIGN(auto filter, - InsertImplicitCasts("a"_ == 1.5, Schema({field("a", dict_type)}))); - - auto encoded_scalar = std::make_shared( - DictionaryScalar::ValueType{MakeScalar(0), - ArrayFromJSON(float64(), "[1.5]")}, - dict_type); - - ASSERT_EQ(E{filter}, E{"a"_ == encoded_scalar}); - - for (int32_t i = 0; i < 5; ++i) { - auto partition_scalar = std::make_shared( - DictionaryScalar::ValueType{ - MakeScalar(i), ArrayFromJSON(float64(), "[0.0, 0.5, 1.0, 1.5, 2.0]")}, - dict_type); - ASSERT_EQ(E{filter->Assume("a"_ == partition_scalar)}, E{scalar(i == 3)}); - } - - auto set_f64 = ArrayFromJSON(float64(), "[0.0, 0.5, 1.0, 1.5, 2.0]"); - ASSERT_OK_AND_ASSIGN( - filter, InsertImplicitCasts("a"_.In(set_f64), Schema({field("a", dict_type)}))); -} - -TEST_F(FilterTest, ImplicitCast) { - ASSERT_OK_AND_ASSIGN(auto filter, - InsertImplicitCasts("a"_ >= "1", Schema({field("a", int32())}))); - - AssertFilter(*filter, {field("a", int32()), field("b", float64())}, - R"([ - {"a": 0, "b": -0.1, "in": 0}, - {"a": 0, "b": 0.0, "in": 0}, - {"a": 1, "b": 1.0, "in": 1}, - {"a": 2, "b": -0.1, "in": 1}, - {"a": 0, "b": 0.1, "in": 0}, - {"a": 2, "b": null, "in": 1}, - {"a": 1, "b": 1.0, "in": 1} - ])"); -} - -TEST_F(FilterTest, ConditionOnAbsentColumn) { - AssertFilter("a"_ == 0 and "b"_ > 0.0 
and "b"_ < 1.0 and "absent"_ == 0, - {field("a", int32()), field("b", float64())}, R"([ - {"a": 0, "b": -0.1, "in": false}, - {"a": 0, "b": 0.3, "in": null}, - {"a": 1, "b": 0.2, "in": false}, - {"a": 2, "b": -0.1, "in": false}, - {"a": 0, "b": 0.1, "in": null}, - {"a": 0, "b": null, "in": null}, - {"a": 0, "b": 1.0, "in": false} - ])"); -} - -TEST_F(FilterTest, KleeneTruthTables) { - AssertFilter("a"_ and "b"_, {field("a", boolean()), field("b", boolean())}, R"([ - {"a":null, "b":null, "in":null}, - {"a":null, "b":true, "in":null}, - {"a":null, "b":false, "in":false}, - - {"a":true, "b":true, "in":true}, - {"a":true, "b":false, "in":false}, - - {"a":false, "b":false, "in":false} - ])"); - - AssertFilter("a"_ or "b"_, {field("a", boolean()), field("b", boolean())}, R"([ - {"a":null, "b":null, "in":null}, - {"a":null, "b":true, "in":true}, - {"a":null, "b":false, "in":null}, - - {"a":true, "b":true, "in":true}, - {"a":true, "b":false, "in":true}, - - {"a":false, "b":false, "in":false} - ])"); -} - -class TakeExpression : public CustomExpression { - public: - TakeExpression(std::shared_ptr operand, std::shared_ptr dictionary) - : operand_(std::move(operand)), dictionary_(std::move(dictionary)) {} - - std::string ToString() const override { - return dictionary_->ToString() + "[" + operand_->ToString() + "]"; - } - - std::shared_ptr Copy() const override { - return std::make_shared(*this); - } - - bool Equals(const Expression& other) const override { - // in a real CustomExpression this would need to be more sophisticated - return other.type() == ExpressionType::CUSTOM && ToString() == other.ToString(); - } - - Result> Validate(const Schema& schema) const override { - ARROW_ASSIGN_OR_RAISE(auto operand_type, operand_->Validate(schema)); - if (!is_integer(operand_type->id())) { - return Status::TypeError("Take indices must be integral, not ", *operand_type); - } - return dictionary_->type(); - } - - class Evaluator : public TreeEvaluator { - public: - using 
TreeEvaluator::TreeEvaluator; - - using TreeEvaluator::Evaluate; - - Result Evaluate(const Expression& expr, const RecordBatch& batch, - MemoryPool* pool) const override { - if (expr.type() == ExpressionType::CUSTOM) { - const auto& take_expr = checked_cast(expr); - return EvaluateTake(take_expr, batch, pool); - } - return TreeEvaluator::Evaluate(expr, batch, pool); - } - - Result EvaluateTake(const TakeExpression& take_expr, const RecordBatch& batch, - MemoryPool* pool) const { - ARROW_ASSIGN_OR_RAISE(auto indices, Evaluate(*take_expr.operand_, batch, pool)); - - if (indices.kind() == Datum::SCALAR) { - ARROW_ASSIGN_OR_RAISE(auto indices_array, - MakeArrayFromScalar(*indices.scalar(), batch.num_rows(), - default_memory_pool())); - indices = Datum(indices_array->data()); - } - - DCHECK_EQ(indices.kind(), Datum::ARRAY); - compute::ExecContext ctx(pool); - ARROW_ASSIGN_OR_RAISE(Datum out, - compute::Take(take_expr.dictionary_->data(), indices, - compute::TakeOptions(), &ctx)); - return std::move(out); - } - }; - - private: - std::shared_ptr operand_; - std::shared_ptr dictionary_; -}; - -TEST_F(ExpressionsTest, TakeAssumeYieldsNothing) { - auto dict = ArrayFromJSON(float64(), "[0.0, 0.25, 0.5, 0.75, 1.0]"); - auto take_b_is_half = (TakeExpression(field_ref("b"), dict) == 0.5); - - // no special Assume logic was provided for TakeExpression so we should just ignore it - // (logically the below *could* be simplified to false but we haven't implemented that) - AssertSimplifiesTo(take_b_is_half, "b"_ == 3, take_b_is_half); - - // custom expressions will not interfere with simplification of other subexpressions and - // can be dropped if other subexpressions simplify trivially - - // ("b"_ > 5).Assume("b"_ == 3) simplifies to false regardless of take, so the and will - // be false - AssertSimplifiesTo("b"_ > 5 and take_b_is_half, "b"_ == 3, *never); - - // ("b"_ > 5).Assume("b"_ == 6) simplifies to true regardless of take, so it can be - // dropped - 
AssertSimplifiesTo("b"_ > 5 and take_b_is_half, "b"_ == 6, take_b_is_half); - - // ("b"_ > 5).Assume("b"_ == 6) simplifies to true regardless of take, so the or will be - // true - AssertSimplifiesTo(take_b_is_half or "b"_ > 5, "b"_ == 6, *always); - - // ("b"_ > 5).Assume("b"_ == 3) simplifies to true regardless of take, so it can be - // dropped - AssertSimplifiesTo(take_b_is_half or "b"_ > 5, "b"_ == 3, take_b_is_half); -} - -TEST_F(FilterTest, EvaluateTakeExpression) { - evaluator_ = std::make_shared(); - - auto dict = ArrayFromJSON(float64(), "[0.0, 0.25, 0.5, 0.75, 1.0]"); - - AssertFilter(TakeExpression(field_ref("b"), dict) == 0.5, - {field("b", int32()), field("f", float64())}, R"([ - {"b": 3, "f": -0.1, "in": 0}, - {"b": 2, "f": 0.3, "in": 1}, - {"b": 1, "f": 0.2, "in": 0}, - {"b": 2, "f": -0.1, "in": 1}, - {"b": 4, "f": 0.1, "in": 0}, - {"b": null, "f": 0.0, "in": null}, - {"b": 0, "f": 1.0, "in": 0} - ])"); -} - -void AssertFieldsInExpression(std::shared_ptr expr, - std::vector expected) { - EXPECT_THAT(FieldsInExpression(expr), testing::ContainerEq(expected)); -} - -TEST(FieldsInExpressionTest, Basic) { - AssertFieldsInExpression(scalar(true), {}); - - AssertFieldsInExpression(("a"_).Copy(), {"a"}); - AssertFieldsInExpression(("a"_ == 1).Copy(), {"a"}); - AssertFieldsInExpression(("a"_ == "b"_).Copy(), {"a", "b"}); - - AssertFieldsInExpression(("a"_ == 1 || "a"_ == 2).Copy(), {"a", "a"}); - AssertFieldsInExpression(("a"_ == 1 || "b"_ == 2).Copy(), {"a", "b"}); - AssertFieldsInExpression((not("a"_ == 1) && ("b"_ == 2 || not("c"_ < 3))).Copy(), - {"a", "b", "c"}); -} - -TEST(ExpressionSerializationTest, RoundTrips) { - std::vector exprs{ - scalar(MakeNullScalar(null())), - scalar(MakeNullScalar(int32())), - scalar(MakeNullScalar(struct_({field("i", int32()), field("s", utf8())}))), - scalar(true), - scalar(false), - scalar(1), - scalar(1.125), - scalar("stringy strings"), - "field"_, - "a"_ > 0.25, - "a"_ == 1 or "b"_ != "hello" or "b"_ == "foo bar", - 
not"alpha"_, - "valid"_ and "a"_.CastLike("b"_) >= "b"_, - "version"_.CastTo(float64()).In(ArrayFromJSON(float64(), "[0.5, 1.0, 2.0]")), - "validity"_.IsValid(), - ("x"_ >= -1.5 and "x"_ < 0.0) and ("y"_ >= 0.0 and "y"_ < 1.5) and - ("z"_ > 1.5 and "z"_ <= 3.0), - "year"_ == int16_t(1999) and "month"_ == int8_t(12) and "day"_ == int8_t(31) and - "hour"_ == int8_t(0) and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, - }; - - for (const auto& expr : exprs) { - ASSERT_OK_AND_ASSIGN(auto serialized, expr.expression->Serialize()); - ASSERT_OK_AND_ASSIGN(E roundtripped, Expression::Deserialize(*serialized)); - ASSERT_EQ(expr, roundtripped); - } -} - -void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, - const std::string& expected_json) { - FieldVector fields_with_ids = by_fields; - fields_with_ids.push_back(field("ids", list(int32()))); - auto expected = ArrayFromJSON(struct_(fields_with_ids), expected_json); - - FieldVector fields_with_id = by_fields; - fields_with_id.push_back(field("id", int32())); - auto batch = RecordBatchFromJSON(schema(fields_with_id), batch_json); - - ASSERT_OK_AND_ASSIGN(auto by, batch->RemoveColumn(batch->num_columns() - 1) - .Map([](std::shared_ptr by) { - return by->ToStructArray(); - })); - - ASSERT_OK_AND_ASSIGN(auto groupings_and_values, MakeGroupings(*by)); - - auto groupings = - checked_pointer_cast(groupings_and_values->GetFieldByName("groupings")); - - ASSERT_OK_AND_ASSIGN(std::shared_ptr grouped_ids, - ApplyGroupings(*groupings, *batch->GetColumnByName("id"))); - - ArrayVector columns = - checked_cast(*groupings_and_values->GetFieldByName("values")) - .fields(); - columns.push_back(grouped_ids); - - ASSERT_OK_AND_ASSIGN(auto actual, StructArray::Make(columns, fields_with_ids)); - - AssertArraysEqual(*expected, *actual, /*verbose=*/true); -} - -TEST(GroupTest, Basics) { - AssertGrouping({field("a", utf8()), field("b", int32())}, R"([ - {"a": "ex", "b": 0, "id": 0}, - {"a": "ex", "b": 0, "id": 1}, - {"a": 
"why", "b": 0, "id": 2}, - {"a": "ex", "b": 1, "id": 3}, - {"a": "why", "b": 0, "id": 4}, - {"a": "ex", "b": 1, "id": 5}, - {"a": "ex", "b": 0, "id": 6}, - {"a": "why", "b": 1, "id": 7} - ])", - R"([ - {"a": "ex", "b": 0, "ids": [0, 1, 6]}, - {"a": "why", "b": 0, "ids": [2, 4]}, - {"a": "ex", "b": 1, "ids": [3, 5]}, - {"a": "why", "b": 1, "ids": [7]} - ])"); -} - -} // namespace dataset -} // namespace arrow +// FIXME delete this diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 159e0ac0331..2822c4a15b6 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -18,31 +18,22 @@ #include "arrow/dataset/partition.h" #include -#include #include -#include -#include #include #include #include "arrow/array/array_base.h" #include "arrow/array/array_nested.h" -#include "arrow/array/builder_binary.h" #include "arrow/array/builder_dict.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/cast.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" -#include "arrow/dataset/scanner.h" -#include "arrow/dataset/scanner_internal.h" -#include "arrow/filesystem/filesystem.h" #include "arrow/filesystem/path_util.h" #include "arrow/scalar.h" -#include "arrow/util/iterator.h" +#include "arrow/util/int_util_internal.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" -#include "arrow/util/range.h" -#include "arrow/util/sort.h" #include "arrow/util/string_view.h" namespace arrow { @@ -60,8 +51,8 @@ std::shared_ptr Partitioning::Default() { std::string type_name() const override { return "default"; } - Result> Parse(const std::string& path) const override { - return scalar(true); + Result Parse(const std::string& path) const override { + return literal(true); } Result Format(const Expression& expr) const override { @@ -71,70 +62,36 @@ std::shared_ptr Partitioning::Default() { Result 
Partition( const std::shared_ptr& batch) const override { - return PartitionedBatches{{batch}, {scalar(true)}}; + return PartitionedBatches{{batch}, {literal(true)}}; } }; return std::make_shared(); } -Status KeyValuePartitioning::VisitKeys( - const Expression& expr, - const std::function& value)>& visitor) { - return VisitConjunctionMembers(expr, [visitor](const Expression& expr) { - if (expr.type() != ExpressionType::COMPARISON) { - return Status::OK(); - } - - const auto& cmp = checked_cast(expr); - if (cmp.op() != compute::CompareOperator::EQUAL) { - return Status::OK(); - } - - auto lhs = cmp.left_operand().get(); - auto rhs = cmp.right_operand().get(); - if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs); - - if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) { - return Status::OK(); +Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr, + RecordBatchProjector* projector) { + ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); + for (const auto& ref_value : known_values) { + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); } - return visitor(checked_cast(lhs)->name(), - checked_cast(rhs)->value()); - }); -} - -Result>> -KeyValuePartitioning::GetKeys(const Expression& expr) { - std::unordered_map> keys; - RETURN_NOT_OK( - VisitKeys(expr, [&](const std::string& name, const std::shared_ptr& value) { - keys.emplace(name, value); - return Status::OK(); - })); - return keys; -} + ARROW_ASSIGN_OR_RAISE(auto match, + ref_value.first.FindOneOrNone(*projector->schema())); -Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr, - RecordBatchProjector* projector) { - return KeyValuePartitioning::VisitKeys( - expr, [projector](const std::string& name, const std::shared_ptr& value) { - ARROW_ASSIGN_OR_RAISE(auto match, - FieldRef(name).FindOneOrNone(*projector->schema())); - if (!match) { - 
return Status::OK(); - } - return projector->SetDefaultValue(match, value); - }); + if (!match) continue; + RETURN_NOT_OK(projector->SetDefaultValue(match, ref_value.second.scalar())); + } + return Status::OK(); } -inline std::shared_ptr ConjunctionFromGroupingRow(Scalar* row) { +inline Expression ConjunctionFromGroupingRow(Scalar* row) { ScalarVector* values = &checked_cast(row)->value; - ExpressionVector equality_expressions(values->size()); + std::vector equality_expressions(values->size()); for (size_t i = 0; i < values->size(); ++i) { const std::string& name = row->type->field(static_cast(i))->name(); - equality_expressions[i] = equal(field_ref(name), scalar(std::move(values->at(i)))); + equality_expressions[i] = equal(field_ref(name), literal(std::move(values->at(i)))); } return and_(std::move(equality_expressions)); } @@ -158,7 +115,7 @@ Result KeyValuePartitioning::Partition( if (by_fields.empty()) { // no fields to group by; return the whole batch - return PartitionedBatches{{batch}, {scalar(true)}}; + return PartitionedBatches{{batch}, {literal(true)}}; } ARROW_ASSIGN_OR_RAISE(auto by, @@ -179,11 +136,10 @@ Result KeyValuePartitioning::Partition( return out; } -Result> KeyValuePartitioning::ConvertKey( - const Key& key) const { +Result KeyValuePartitioning::ConvertKey(const Key& key) const { ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(key.name).FindOneOrNone(*schema_)); if (!match) { - return scalar(true); + return literal(true); } auto field_index = match[0]; @@ -220,18 +176,15 @@ Result> KeyValuePartitioning::ConvertKey( ARROW_ASSIGN_OR_RAISE(converted, Scalar::Parse(field->type(), key.value)); } - return equal(field_ref(field->name()), scalar(std::move(converted))); + return equal(field_ref(field->name()), literal(std::move(converted))); } -Result> KeyValuePartitioning::Parse( - const std::string& path) const { - ExpressionVector expressions; +Result KeyValuePartitioning::Parse(const std::string& path) const { + std::vector expressions; for (const Key& 
key : ParseKeys(path)) { ARROW_ASSIGN_OR_RAISE(auto expr, ConvertKey(key)); - - if (expr->Equals(true)) continue; - + if (expr == literal(true)) continue; expressions.push_back(std::move(expr)); } @@ -241,20 +194,25 @@ Result> KeyValuePartitioning::Parse( Result KeyValuePartitioning::Format(const Expression& expr) const { std::vector values{static_cast(schema_->num_fields()), nullptr}; - RETURN_NOT_OK(VisitKeys(expr, [&](const std::string& name, - const std::shared_ptr& value) { - ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(name).FindOneOrNone(*schema_)); - if (match) { - const auto& field = schema_->field(match[0]); - if (!value->type->Equals(field->type())) { - return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, - ") is invalid for ", field->ToString()); - } + ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); + for (const auto& ref_value : known_values) { + if (!ref_value.second.is_scalar()) { + return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); + } - values[match[0]] = value.get(); + ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_)); + if (!match) continue; + + const auto& value = ref_value.second.scalar(); + + const auto& field = schema_->field(match[0]); + if (!value->type->Equals(field->type())) { + return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type, + ") is invalid for ", field->ToString()); } - return Status::OK(); - })); + + values[match[0]] = value.get(); + } return FormatValues(values); } @@ -573,5 +531,192 @@ Result> PartitioningOrFactory::GetOrInferSchema( return factory()->Inspect(paths); } +// Transform an array of counts to offsets which will divide a ListArray +// into an equal number of slices with corresponding lengths. 
+inline Result> CountsToOffsets( + std::shared_ptr counts) { + Int32Builder offset_builder; + RETURN_NOT_OK(offset_builder.Resize(counts->length() + 1)); + offset_builder.UnsafeAppend(0); + + for (int64_t i = 0; i < counts->length(); ++i) { + DCHECK_NE(counts->Value(i), 0); + auto next_offset = static_cast(offset_builder[i] + counts->Value(i)); + offset_builder.UnsafeAppend(next_offset); + } + + std::shared_ptr offsets; + RETURN_NOT_OK(offset_builder.Finish(&offsets)); + return offsets; +} + +// Helper for simultaneous dictionary encoding of multiple arrays. +// +// The fused dictionary is the Cartesian product of the individual dictionaries. +// For example given two arrays A, B where A has unique values ["ex", "why"] +// and B has unique values [0, 1] the fused dictionary is the set of tuples +// [["ex", 0], ["ex", 1], ["why", 0], ["why", 1]]. +// +// TODO(bkietz) this capability belongs in an Action of the hash kernels, where +// it can be used to group aggregates without materializing a grouped batch. +// For the purposes of writing we need the materialized grouped batch anyway +// since no Writers accept a selection vector. 
+class StructDictionary { + public: + struct Encoded { + std::shared_ptr indices; + std::shared_ptr dictionary; + }; + + static Result Encode(const ArrayVector& columns) { + Encoded out{nullptr, std::make_shared()}; + + for (const auto& column : columns) { + if (column->null_count() != 0) { + return Status::NotImplemented("Grouping on a field with nulls"); + } + + RETURN_NOT_OK(out.dictionary->AddOne(column, &out.indices)); + } + + return out; + } + + Result> Decode(std::shared_ptr fused_indices, + FieldVector fields) { + std::vector builders(dictionaries_.size()); + for (Int32Builder& b : builders) { + RETURN_NOT_OK(b.Resize(fused_indices->length())); + } + + std::vector codes(dictionaries_.size()); + for (int64_t i = 0; i < fused_indices->length(); ++i) { + Expand(fused_indices->Value(i), codes.data()); + + auto builder_it = builders.begin(); + for (int32_t index : codes) { + builder_it++->UnsafeAppend(index); + } + } + + ArrayVector columns(dictionaries_.size()); + for (size_t i = 0; i < dictionaries_.size(); ++i) { + std::shared_ptr indices; + RETURN_NOT_OK(builders[i].FinishInternal(&indices)); + + ARROW_ASSIGN_OR_RAISE(Datum column, compute::Take(dictionaries_[i], indices)); + columns[i] = column.make_array(); + } + + return StructArray::Make(std::move(columns), std::move(fields)); + } + + private: + Status AddOne(Datum column, std::shared_ptr* fused_indices) { + ArrayData* encoded; + if (column.type()->id() != Type::DICTIONARY) { + ARROW_ASSIGN_OR_RAISE(column, compute::DictionaryEncode(column)); + } + encoded = column.mutable_array(); + + auto indices = + std::make_shared(encoded->length, std::move(encoded->buffers[1])); + + dictionaries_.push_back(MakeArray(std::move(encoded->dictionary))); + auto dictionary_size = static_cast(dictionaries_.back()->length()); + + if (*fused_indices == nullptr) { + *fused_indices = std::move(indices); + size_ = dictionary_size; + return Status::OK(); + } + + // It's useful to think about the case where each of dictionaries_ 
has size 10. + // In this case the decimal digit in the ones place is the code in dictionaries_[0], + // the tens place corresponds to dictionaries_[1], etc. + // The incumbent indices must be shifted to the hundreds place so as not to collide. + ARROW_ASSIGN_OR_RAISE(Datum new_fused_indices, + compute::Multiply(indices, MakeScalar(size_))); + + ARROW_ASSIGN_OR_RAISE(new_fused_indices, + compute::Add(new_fused_indices, *fused_indices)); + + *fused_indices = checked_pointer_cast(new_fused_indices.make_array()); + + // XXX should probably cap this at 2**15 or so + ARROW_CHECK(!internal::MultiplyWithOverflow(size_, dictionary_size, &size_)); + return Status::OK(); + } + + // expand a fused code into component dict codes, order is in order of addition + void Expand(int32_t fused_code, int32_t* codes) { + for (size_t i = 0; i < dictionaries_.size(); ++i) { + auto dictionary_size = static_cast(dictionaries_[i]->length()); + codes[i] = fused_code % dictionary_size; + fused_code /= dictionary_size; + } + } + + int32_t size_; + ArrayVector dictionaries_; +}; + +Result> MakeGroupings(const StructArray& by) { + if (by.num_fields() == 0) { + return Status::NotImplemented("Grouping with no criteria"); + } + + ARROW_ASSIGN_OR_RAISE(auto fused, StructDictionary::Encode(by.fields())); + + ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(*fused.indices)); + ARROW_ASSIGN_OR_RAISE(Datum sorted, compute::Take(fused.indices, *sort_indices)); + fused.indices = checked_pointer_cast(sorted.make_array()); + + ARROW_ASSIGN_OR_RAISE(auto fused_counts_and_values, + compute::ValueCounts(fused.indices)); + fused.indices.reset(); + + auto unique_fused_indices = + checked_pointer_cast(fused_counts_and_values->GetFieldByName("values")); + ARROW_ASSIGN_OR_RAISE( + auto unique_rows, + fused.dictionary->Decode(std::move(unique_fused_indices), by.type()->fields())); + + auto counts = + checked_pointer_cast(fused_counts_and_values->GetFieldByName("counts")); + ARROW_ASSIGN_OR_RAISE(auto 
offsets, CountsToOffsets(std::move(counts))); + + ARROW_ASSIGN_OR_RAISE(auto grouped_sort_indices, + ListArray::FromArrays(*offsets, *sort_indices)); + + return StructArray::Make( + ArrayVector{std::move(unique_rows), std::move(grouped_sort_indices)}, + std::vector{"values", "groupings"}); +} + +Result> ApplyGroupings(const ListArray& groupings, + const Array& array) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(array, groupings.data()->child_data[0])); + + return std::make_shared(list(array.type()), groupings.length(), + groupings.value_offsets(), sorted.make_array()); +} + +Result ApplyGroupings(const ListArray& groupings, + const std::shared_ptr& batch) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(batch, groupings.data()->child_data[0])); + + const auto& sorted_batch = *sorted.record_batch(); + + RecordBatchVector out(static_cast(groupings.length())); + for (size_t i = 0; i < out.size(); ++i) { + out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); + } + + return out; +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index 165fcfb5248..8975f565b19 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -26,7 +26,7 @@ #include #include -#include "arrow/dataset/filter.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/util/optional.h" @@ -63,13 +63,13 @@ class ARROW_DS_EXPORT Partitioning { /// produce sub-batches which satisfy mutually exclusive Expressions. 
struct PartitionedBatches { RecordBatchVector batches; - ExpressionVector expressions; + std::vector expressions; }; virtual Result Partition( const std::shared_ptr& batch) const = 0; /// \brief Parse a path into a partition expression - virtual Result> Parse(const std::string& path) const = 0; + virtual Result Parse(const std::string& path) const = 0; virtual Result Format(const Expression& expr) const = 0; @@ -122,21 +122,13 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { std::string name, value; }; - static Status VisitKeys( - const Expression& expr, - const std::function& value)>& visitor); - - static Result>> GetKeys( - const Expression& expr); - static Status SetDefaultValuesFromKeys(const Expression& expr, RecordBatchProjector* projector); Result Partition( const std::shared_ptr& batch) const override; - Result> Parse(const std::string& path) const override; + Result Parse(const std::string& path) const override; Result Format(const Expression& expr) const override; @@ -153,7 +145,7 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { virtual Result FormatValues(const std::vector& values) const = 0; /// Convert a Key to a full expression. 
- Result> ConvertKey(const Key& key) const; + Result ConvertKey(const Key& key) const; ArrayVector dictionaries_; }; @@ -215,8 +207,7 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { /// \brief Implementation provided by lambda or other callable class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { public: - using ParseImpl = - std::function>(const std::string&)>; + using ParseImpl = std::function(const std::string&)>; using FormatImpl = std::function(const Expression&)>; @@ -229,7 +220,7 @@ class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { std::string type_name() const override { return name_; } - Result> Parse(const std::string& path) const override { + Result Parse(const std::string& path) const override { return parse_impl_(path); } @@ -294,5 +285,22 @@ class ARROW_DS_EXPORT PartitioningOrFactory { std::shared_ptr partitioning_; }; +/// \brief Assemble lists of indices of identical rows. +/// +/// \param[in] by A StructArray whose columns will be used as grouping criteria. +/// \return A StructArray mapping unique rows (in field "values", represented as a +/// StructArray with the same fields as `by`) to lists of indices where +/// that row appears (in field "groupings"). +ARROW_DS_EXPORT +Result> MakeGroupings(const StructArray& by); + +/// \brief Produce slices of an Array which correspond to the provided groupings. 
+ARROW_DS_EXPORT +Result> ApplyGroupings(const ListArray& groupings, + const Array& array); +ARROW_DS_EXPORT +Result ApplyGroupings(const ListArray& groupings, + const std::shared_ptr& batch); + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index f49103a585a..2260eb219da 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -21,52 +21,52 @@ #include #include -#include #include #include #include #include -#include "arrow/dataset/file_base.h" +#include "arrow/compute/api_scalar.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" -#include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" -#include "arrow/util/io_util.h" namespace arrow { using internal::checked_pointer_cast; namespace dataset { -using E = TestExpression; - class TestPartitioning : public ::testing::Test { public: void AssertParseError(const std::string& path) { ASSERT_RAISES(Invalid, partitioning_->Parse(path)); } - void AssertParse(const std::string& path, E expected) { + void AssertParse(const std::string& path, Expression expected) { ASSERT_OK_AND_ASSIGN(auto parsed, partitioning_->Parse(path)); - ASSERT_EQ(E{parsed}, expected); + ASSERT_EQ(parsed, expected); } template - void AssertFormatError(E expr) { - ASSERT_EQ(partitioning_->Format(*expr.expression).status().code(), code); + void AssertFormatError(Expression expr) { + ASSERT_EQ(partitioning_->Format(expr).status().code(), code); } - void AssertFormat(E expr, const std::string& expected) { - ASSERT_OK_AND_ASSIGN(auto formatted, partitioning_->Format(*expr.expression)); + void AssertFormat(Expression expr, const std::string& expected) { + // formatted partition expressions are bound to the schema of the dataset being + // written + ASSERT_OK_AND_ASSIGN(auto formatted, 
partitioning_->Format(expr)); ASSERT_EQ(formatted, expected); // ensure the formatted path round trips the relevant components of the partition // expression: roundtripped should be a subset of expr - ASSERT_OK_AND_ASSIGN(auto roundtripped, partitioning_->Parse(formatted)); - ASSERT_EQ(E{roundtripped->Assume(*expr.expression)}, E{scalar(true)}); + ASSERT_OK_AND_ASSIGN(Expression roundtripped, partitioning_->Parse(formatted)); + + ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*written_schema_)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(roundtripped, expr)); + ASSERT_EQ(simplified, literal(true)); } void AssertInspect(const std::vector& paths, @@ -99,38 +99,57 @@ class TestPartitioning : public ::testing::Test { std::shared_ptr partitioning_; std::shared_ptr factory_; + std::shared_ptr written_schema_; }; TEST_F(TestPartitioning, DirectoryPartitioning) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); - AssertParse("/0/hello", "alpha"_ == int32_t(0) and "beta"_ == "hello"); - AssertParse("/3", "alpha"_ == int32_t(3)); - AssertParseError("/world/0"); // reversed order - AssertParseError("/0.0/foo"); // invalid alpha - AssertParseError("/3.25"); // invalid alpha with missing beta - AssertParse("", scalar(true)); // no segments to parse + AssertParse("/0/hello", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello")))); + AssertParse("/3", equal(field_ref("alpha"), literal(3))); + AssertParseError("/world/0"); // reversed order + AssertParseError("/0.0/foo"); // invalid alpha + AssertParseError("/3.25"); // invalid alpha with missing beta + AssertParse("", literal(true)); // no segments to parse // gotcha someday: - AssertParse("/0/dat.parquet", "alpha"_ == int32_t(0) and "beta"_ == "dat.parquet"); + AssertParse("/0/dat.parquet", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("dat.parquet")))); - AssertParse("/0/foo/ignored=2341", 
"alpha"_ == int32_t(0) and "beta"_ == "foo"); + AssertParse("/0/foo/ignored=2341", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("foo")))); } TEST_F(TestPartitioning, DirectoryPartitioningFormat) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", utf8())})); - AssertFormat("alpha"_ == int32_t(0) and "beta"_ == "hello", "0/hello"); - AssertFormat("beta"_ == "hello" and "alpha"_ == int32_t(0), "0/hello"); - AssertFormat("alpha"_ == int32_t(0), "0"); - AssertFormatError("beta"_ == "hello"); - AssertFormat(scalar(true), ""); + written_schema_ = partitioning_->schema(); - AssertFormatError("alpha"_ == 0.0 and "beta"_ == "hello"); - AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == "hello", + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello"))), + "0/hello"); + AssertFormat(and_(equal(field_ref("beta"), literal("hello")), + equal(field_ref("alpha"), literal(0))), + "0/hello"); + AssertFormat(equal(field_ref("alpha"), literal(0)), "0"); + AssertFormatError(equal(field_ref("beta"), literal("hello"))); + AssertFormat(literal(true), ""); + + ASSERT_OK_AND_ASSIGN(written_schema_, + written_schema_->AddField(0, field("gamma", utf8()))); + AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), + equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal("hello"))}), "0/hello"); + + // written_schema_ is incompatible with partitioning_'s schema + written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); + AssertFormatError( + and_(equal(field_ref("alpha"), literal("0.0")), + equal(field_ref("beta"), literal("hello")))); } TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { @@ -140,7 +159,9 @@ TEST_F(TestPartitioning, DirectoryPartitioningWithTemporal) { ASSERT_OK_AND_ASSIGN(auto day, StringScalar("2020-06-08").CastTo(temporal)); AssertParse("/2020/06/2020-06-08", - "year"_ == int32_t(2020) and 
"month"_ == int8_t(6) and "day"_ == day); + and_({equal(field_ref("year"), literal(2020)), + equal(field_ref("month"), literal(6)), + equal(field_ref("day"), literal(day))})); } } @@ -200,7 +221,7 @@ TEST_F(TestPartitioning, DictionaryHasUniqueValues) { std::make_shared(index_and_dictionary, alpha->type()); auto path = "/" + expected_dictionary->GetString(i); - AssertParse(path, "alpha"_ == dictionary_scalar); + AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar))); } AssertParseError("/yosemite"); // not in inspected dictionary @@ -216,19 +237,23 @@ TEST_F(TestPartitioning, HivePartitioning) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", float32())})); - AssertParse("/alpha=0/beta=3.25", "alpha"_ == int32_t(0) and "beta"_ == 3.25f); - AssertParse("/beta=3.25/alpha=0", "beta"_ == 3.25f and "alpha"_ == int32_t(0)); - AssertParse("/alpha=0", "alpha"_ == int32_t(0)); - AssertParse("/beta=3.25", "beta"_ == 3.25f); - AssertParse("", scalar(true)); + AssertParse("/alpha=0/beta=3.25", and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); + AssertParse("/beta=3.25/alpha=0", and_(equal(field_ref("beta"), literal(3.25f)), + equal(field_ref("alpha"), literal(0)))); + AssertParse("/alpha=0", equal(field_ref("alpha"), literal(0))); + AssertParse("/beta=3.25", equal(field_ref("beta"), literal(3.25f))); + AssertParse("", literal(true)); AssertParse("/alpha=0/unexpected/beta=3.25", - "alpha"_ == int32_t(0) and "beta"_ == 3.25f); + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); AssertParse("/alpha=0/beta=3.25/ignored=2341", - "alpha"_ == int32_t(0) and "beta"_ == 3.25f); + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))); - AssertParse("/ignored=2341", scalar(true)); + AssertParse("/ignored=2341", literal(true)); AssertParseError("/alpha=0.0/beta=3.25"); // conversion of "0.0" to int32 fails } @@ -237,15 
+262,30 @@ TEST_F(TestPartitioning, HivePartitioningFormat) { partitioning_ = std::make_shared( schema({field("alpha", int32()), field("beta", float32())})); - AssertFormat("alpha"_ == int32_t(0) and "beta"_ == 3.25f, "alpha=0/beta=3.25"); - AssertFormat("beta"_ == 3.25f and "alpha"_ == int32_t(0), "alpha=0/beta=3.25"); - AssertFormat("alpha"_ == int32_t(0), "alpha=0"); - AssertFormat("beta"_ == 3.25f, "alpha/beta=3.25"); - AssertFormat(scalar(true), ""); + written_schema_ = partitioning_->schema(); - AssertFormatError("alpha"_ == "yo" and "beta"_ == 3.25f); - AssertFormat("gamma"_ == "yo" and "alpha"_ == int32_t(0) and "beta"_ == 3.25f, + AssertFormat(and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f))), + "alpha=0/beta=3.25"); + AssertFormat(and_(equal(field_ref("beta"), literal(3.25f)), + equal(field_ref("alpha"), literal(0))), + "alpha=0/beta=3.25"); + AssertFormat(equal(field_ref("alpha"), literal(0)), "alpha=0"); + AssertFormat(equal(field_ref("beta"), literal(3.25f)), "alpha/beta=3.25"); + AssertFormat(literal(true), ""); + + ASSERT_OK_AND_ASSIGN(written_schema_, + written_schema_->AddField(0, field("gamma", utf8()))); + AssertFormat(and_({equal(field_ref("gamma"), literal("yo")), + equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f))}), "alpha=0/beta=3.25"); + + // written_schema_ is incompatible with partitioning_'s schema + written_schema_ = schema({field("alpha", utf8()), field("beta", utf8())}); + AssertFormatError( + and_(equal(field_ref("alpha"), literal("0.0")), + equal(field_ref("beta"), literal("hello")))); } TEST_F(TestPartitioning, DiscoverHiveSchema) { @@ -312,7 +352,7 @@ TEST_F(TestPartitioning, HiveDictionaryHasUniqueValues) { std::make_shared(index_and_dictionary, alpha->type()); auto path = "/alpha=" + expected_dictionary->GetString(i); - AssertParse(path, "alpha"_ == dictionary_scalar); + AssertParse(path, equal(field_ref("alpha"), literal(dictionary_scalar))); } 
AssertParseError("/alpha=yosemite"); // not in inspected dictionary @@ -326,10 +366,12 @@ TEST_F(TestPartitioning, EtlThenHive) { FieldVector alphabeta_fields{field("alpha", int32()), field("beta", float32())}; HivePartitioning alphabeta_part(schema(alphabeta_fields)); - partitioning_ = std::make_shared( + auto schm = schema({field("year", int16()), field("month", int8()), field("day", int8()), - field("hour", int8()), field("alpha", int32()), field("beta", float32())}), - [&](const std::string& path) -> Result> { + field("hour", int8()), field("alpha", int32()), field("beta", float32())}); + + partitioning_ = std::make_shared( + schm, [&](const std::string& path) -> Result { auto segments = fs::internal::SplitAbstractPath(path); if (segments.size() < etl_fields.size() + alphabeta_fields.size()) { return Status::Invalid("path ", path, " can't be parsed"); @@ -349,12 +391,15 @@ TEST_F(TestPartitioning, EtlThenHive) { }); AssertParse("/1999/12/31/00/alpha=0/beta=3.25", - "year"_ == int16_t(1999) and "month"_ == int8_t(12) and - "day"_ == int8_t(31) and "hour"_ == int8_t(0) and - ("alpha"_ == int32_t(0) and "beta"_ == 3.25f)); + and_({equal(field_ref("year"), literal(1999)), + equal(field_ref("month"), literal(12)), + equal(field_ref("day"), literal(31)), + equal(field_ref("hour"), literal(0)), + and_(equal(field_ref("alpha"), literal(0)), + equal(field_ref("beta"), literal(3.25f)))})); AssertParseError("/20X6/03/21/05/alpha=0/beta=3.25"); -} // namespace dataset +} TEST_F(TestPartitioning, Set) { auto ints = [](std::vector ints) { @@ -363,12 +408,13 @@ TEST_F(TestPartitioning, Set) { return out; }; + auto schm = schema({field("x", int32())}); + // An adhoc partitioning which parses segments like "/x in [1 4 5]" - // into ("x"_ == 1 or "x"_ == 4 or "x"_ == 5) + // into (field_ref("x") == 1 or field_ref("x") == 4 or field_ref("x") == 5) partitioning_ = std::make_shared( - schema({field("x", int32())}), - [&](const std::string& path) -> Result> { - ExpressionVector 
subexpressions; + schm, [&](const std::string& path) -> Result { + std::vector subexpressions; for (auto segment : fs::internal::SplitAbstractPath(path)) { std::smatch matches; @@ -384,26 +430,30 @@ TEST_F(TestPartitioning, Set) { set.push_back(checked_cast(*s).value); } - subexpressions.push_back(field_ref(matches[1])->In(ints(set)).Copy()); + subexpressions.push_back(call("is_in", {field_ref(std::string(matches[1]))}, + compute::SetLookupOptions{ints(set)})); } return and_(std::move(subexpressions)); }); - AssertParse("/x in [1]", "x"_.In(ints({1}))); - AssertParse("/x in [1 4 5]", "x"_.In(ints({1, 4, 5}))); - AssertParse("/x in []", "x"_.In(ints({}))); + auto x_in = [&](std::vector set) { + return call("is_in", {field_ref("x")}, compute::SetLookupOptions{ints(set)}); + }; + AssertParse("/x in [1]", x_in({1})); + AssertParse("/x in [1 4 5]", x_in({1, 4, 5})); + AssertParse("/x in []", x_in({})); } // An adhoc partitioning which parses segments like "/x=[-3.25, 0.0)" -// into ("x"_ >= -3.25 and "x" < 0.0) +// into (field_ref("x") >= -3.25 and "x" < 0.0) class RangePartitioning : public Partitioning { public: explicit RangePartitioning(std::shared_ptr s) : Partitioning(std::move(s)) {} std::string type_name() const override { return "range"; } - Result> Parse(const std::string& path) const override { - ExpressionVector ranges; + Result Parse(const std::string& path) const override { + std::vector ranges; for (auto segment : fs::internal::SplitAbstractPath(path)) { auto key = HivePartitioning::ParseKey(segment); @@ -423,9 +473,10 @@ class RangePartitioning : public Partitioning { ARROW_ASSIGN_OR_RAISE(auto min, Scalar::Parse(type, min_repr)); ARROW_ASSIGN_OR_RAISE(auto max, Scalar::Parse(type, max_repr)); - ranges.push_back(and_(min_cmp(field_ref(key->name), scalar(min)), - max_cmp(field_ref(key->name), scalar(max)))); + ranges.push_back(and_(min_cmp(field_ref(key->name), literal(min)), + max_cmp(field_ref(key->name), literal(max)))); } + return and_(ranges); } @@ 
-458,8 +509,12 @@ TEST_F(TestPartitioning, Range) { schema({field("x", float64()), field("y", float64()), field("z", float64())})); AssertParse("/x=[-1.5 0.0)/y=[0.0 1.5)/z=(1.5 3.0]", - ("x"_ >= -1.5 and "x"_ < 0.0) and ("y"_ >= 0.0 and "y"_ < 1.5) and - ("z"_ > 1.5 and "z"_ <= 3.0)); + and_({and_(greater_equal(field_ref("x"), literal(-1.5)), + less(field_ref("x"), literal(0.0))), + and_(greater_equal(field_ref("y"), literal(0.0)), + less(field_ref("y"), literal(1.5))), + and_(greater(field_ref("z"), literal(1.5)), + less_equal(field_ref("z"), literal(3.0)))})); } TEST(TestStripPrefixAndFilename, Basic) { @@ -478,5 +533,57 @@ TEST(TestStripPrefixAndFilename, Basic) { "year=2019/month=12/day=01")); } +void AssertGrouping(const FieldVector& by_fields, const std::string& batch_json, + const std::string& expected_json) { + FieldVector fields_with_ids = by_fields; + fields_with_ids.push_back(field("ids", list(int32()))); + auto expected = ArrayFromJSON(struct_(fields_with_ids), expected_json); + + FieldVector fields_with_id = by_fields; + fields_with_id.push_back(field("id", int32())); + auto batch = RecordBatchFromJSON(schema(fields_with_id), batch_json); + + ASSERT_OK_AND_ASSIGN(auto by, batch->RemoveColumn(batch->num_columns() - 1) + .Map([](std::shared_ptr by) { + return by->ToStructArray(); + })); + + ASSERT_OK_AND_ASSIGN(auto groupings_and_values, MakeGroupings(*by)); + + auto groupings = + checked_pointer_cast(groupings_and_values->GetFieldByName("groupings")); + + ASSERT_OK_AND_ASSIGN(std::shared_ptr grouped_ids, + ApplyGroupings(*groupings, *batch->GetColumnByName("id"))); + + ArrayVector columns = + checked_cast(*groupings_and_values->GetFieldByName("values")) + .fields(); + columns.push_back(grouped_ids); + + ASSERT_OK_AND_ASSIGN(auto actual, StructArray::Make(columns, fields_with_ids)); + + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +TEST(GroupTest, Basics) { + AssertGrouping({field("a", utf8()), field("b", int32())}, R"([ + {"a": "ex", 
"b": 0, "id": 0}, + {"a": "ex", "b": 0, "id": 1}, + {"a": "why", "b": 0, "id": 2}, + {"a": "ex", "b": 1, "id": 3}, + {"a": "why", "b": 0, "id": 4}, + {"a": "ex", "b": 1, "id": 5}, + {"a": "ex", "b": 0, "id": 6}, + {"a": "why", "b": 1, "id": 7} + ])", + R"([ + {"a": "ex", "b": 0, "ids": [0, 1, 6]}, + {"a": "why", "b": 0, "ids": [2, 4]}, + {"a": "ex", "b": 1, "ids": [3, 5]}, + {"a": "why", "b": 1, "ids": [7]} + ])"); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 019416041aa..0c501c9f5b3 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -23,10 +23,10 @@ #include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/table.h" #include "arrow/util/iterator.h" +#include "arrow/util/logging.h" #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" @@ -34,14 +34,12 @@ namespace arrow { namespace dataset { ScanOptions::ScanOptions(std::shared_ptr schema) - : evaluator(ExpressionEvaluator::Null()), - projector(RecordBatchProjector(std::move(schema))) {} + : projector(RecordBatchProjector(std::move(schema))) {} std::shared_ptr ScanOptions::ReplaceSchema( std::shared_ptr schema) const { auto copy = ScanOptions::Make(std::move(schema)); - copy->filter = filter; - copy->evaluator = evaluator; + copy->filter2 = filter2; copy->batch_size = batch_size; return copy; } @@ -53,8 +51,9 @@ std::vector ScanOptions::MaterializedFields() const { fields.push_back(f->name()); } - for (auto&& name : FieldsInExpression(filter)) { - fields.push_back(std::move(name)); + for (const FieldRef& ref : FieldsInExpression(filter2)) { + DCHECK(ref.name()); + fields.push_back(*ref.name()); } return fields; @@ -64,7 +63,7 @@ Result InMemoryScanTask::Execute() { return MakeVectorIterator(record_batches_); } -FragmentIterator Scanner::GetFragments() { 
+Result Scanner::GetFragments() { if (fragment_ != nullptr) { return MakeVectorIterator(FragmentVector{fragment_}); } @@ -72,14 +71,15 @@ FragmentIterator Scanner::GetFragments() { // Transform Datasets in a flat Iterator. This // iterator is lazily constructed, i.e. Dataset::GetFragments is // not invoked until a Fragment is requested. - return GetFragmentsFromDatasets({dataset_}, scan_options_->filter); + return GetFragmentsFromDatasets({dataset_}, scan_options_->filter2); } Result Scanner::Scan() { // Transforms Iterator into a unified // Iterator. The first Iterator::Next invocation is going to do // all the work of unwinding the chained iterators. - return GetScanTaskIterator(GetFragments(), scan_options_, scan_context_); + ARROW_ASSIGN_OR_RAISE(auto fragment_it, GetFragments()); + return GetScanTaskIterator(std::move(fragment_it), scan_options_, scan_context_); } Result ScanTaskIteratorFromRecordBatch( @@ -95,15 +95,24 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr dataset, : dataset_(std::move(dataset)), fragment_(nullptr), scan_options_(ScanOptions::Make(dataset_->schema())), - scan_context_(std::move(scan_context)) {} + scan_context_(std::move(scan_context)) { + DCHECK_OK(Filter(literal(true))); +} ScannerBuilder::ScannerBuilder(std::shared_ptr schema, std::shared_ptr fragment, std::shared_ptr scan_context) : dataset_(nullptr), fragment_(std::move(fragment)), - scan_options_(ScanOptions::Make(schema)), - scan_context_(std::move(scan_context)) {} + fragment_schema_(schema), + scan_options_(ScanOptions::Make(std::move(schema))), + scan_context_(std::move(scan_context)) { + DCHECK_OK(Filter(literal(true))); +} + +const std::shared_ptr& ScannerBuilder::schema() const { + return fragment_ ? 
fragment_schema_ : dataset_->schema(); +} Status ScannerBuilder::Project(std::vector columns) { RETURN_NOT_OK(schema()->CanReferenceFieldsByNames(columns)); @@ -112,15 +121,14 @@ Status ScannerBuilder::Project(std::vector columns) { return Status::OK(); } -Status ScannerBuilder::Filter(std::shared_ptr filter) { - RETURN_NOT_OK(schema()->CanReferenceFieldsByNames(FieldsInExpression(*filter))); - RETURN_NOT_OK(filter->Validate(*schema())); - scan_options_->filter = std::move(filter); +Status ScannerBuilder::Filter(const Expression& filter) { + for (const auto& ref : FieldsInExpression(filter)) { + RETURN_NOT_OK(ref.FindOne(*schema())); + } + ARROW_ASSIGN_OR_RAISE(scan_options_->filter2, filter.Bind(*schema())); return Status::OK(); } -Status ScannerBuilder::Filter(const Expression& filter) { return Filter(filter.Copy()); } - Status ScannerBuilder::UseThreads(bool use_threads) { scan_context_->use_threads = use_threads; return Status::OK(); @@ -143,10 +151,6 @@ Result> ScannerBuilder::Finish() const { scan_options = std::make_shared(*scan_options_); } - if (!scan_options->filter->Equals(true)) { - scan_options->evaluator = std::make_shared(); - } - if (dataset_ == nullptr) { return std::make_shared(fragment_, std::move(scan_options), scan_context_); } diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h index 950d416f615..5902f759ec3 100644 --- a/cpp/src/arrow/dataset/scanner.h +++ b/cpp/src/arrow/dataset/scanner.h @@ -25,6 +25,7 @@ #include #include "arrow/dataset/dataset.h" +#include "arrow/dataset/expression.h" #include "arrow/dataset/projector.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" @@ -62,10 +63,7 @@ class ARROW_DS_EXPORT ScanOptions { std::shared_ptr ReplaceSchema(std::shared_ptr schema) const; // Filter - std::shared_ptr filter = scalar(true); - - // Evaluator for Filter - std::shared_ptr evaluator; + Expression filter2 = literal(true); // Schema to which record batches will be reconciled const 
std::shared_ptr& schema() const { return projector.schema(); } @@ -172,7 +170,7 @@ class ARROW_DS_EXPORT Scanner { Result> ToTable(); /// \brief GetFragments returns an iterator over all Fragments in this scan. - FragmentIterator GetFragments(); + Result GetFragments(); const std::shared_ptr& schema() const { return scan_options_->schema(); } @@ -222,7 +220,6 @@ class ARROW_DS_EXPORT ScannerBuilder { /// /// \return Failure if any referenced columns does not exist in the dataset's /// Schema. - Status Filter(std::shared_ptr filter); Status Filter(const Expression& filter); /// \brief Indicate if the Scanner should make use of the available @@ -240,11 +237,12 @@ class ARROW_DS_EXPORT ScannerBuilder { /// \brief Return the constructed now-immutable Scanner object Result> Finish() const; - std::shared_ptr schema() const { return scan_options_->schema(); } + const std::shared_ptr& schema() const; private: std::shared_ptr dataset_; std::shared_ptr fragment_; + std::shared_ptr fragment_schema_; std::shared_ptr scan_options_; std::shared_ptr scan_context_; bool has_projection_ = false; diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h index 94df94470fa..cd8fffd2a71 100644 --- a/cpp/src/arrow/dataset/scanner_internal.h +++ b/cpp/src/arrow/dataset/scanner_internal.h @@ -20,22 +20,36 @@ #include #include +#include "arrow/array/array_nested.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" #include "arrow/dataset/dataset_internal.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/partition.h" #include "arrow/dataset/scanner.h" namespace arrow { namespace dataset { -inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, - const ExpressionEvaluator& evaluator, - const Expression& filter, MemoryPool* pool) { +inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it, Expression filter, + MemoryPool* pool) { return MakeMaybeMapIterator( - [&filter, &evaluator, 
pool](std::shared_ptr in) { - return evaluator.Evaluate(filter, *in, pool).Map([&](Datum selection) { - return evaluator.Filter(selection, in); - }); + [=](std::shared_ptr in) -> Result> { + compute::ExecContext exec_context{pool}; + ARROW_ASSIGN_OR_RAISE(Datum mask, + ExecuteScalarExpression(filter, Datum(in), &exec_context)); + + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + return std::move(in); + } + return in->Slice(0, 0); + } + + ARROW_ASSIGN_OR_RAISE( + Datum filtered, + compute::Filter(in, mask, compute::FilterOptions::Defaults(), &exec_context)); + return filtered.record_batch(); }, std::move(it)); } @@ -56,39 +70,40 @@ inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, class FilterAndProjectScanTask : public ScanTask { public: - explicit FilterAndProjectScanTask(std::shared_ptr task, - std::shared_ptr partition) + explicit FilterAndProjectScanTask(std::shared_ptr task, Expression partition) : ScanTask(task->options(), task->context()), task_(std::move(task)), partition_(std::move(partition)), - filter_(options()->filter->Assume(partition_)), + filter_(options()->filter2), projector_(options()->projector) {} Result Execute() override { ARROW_ASSIGN_OR_RAISE(auto it, task_->Execute()); - auto filter_it = - FilterRecordBatch(std::move(it), *options_->evaluator, *filter_, context_->pool); + ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, + SimplifyWithGuarantee(filter_, partition_)); + + RecordBatchIterator filter_it = + FilterRecordBatch(std::move(it), simplified_filter, context_->pool); + + RETURN_NOT_OK( + KeyValuePartitioning::SetDefaultValuesFromKeys(partition_, &projector_)); - if (partition_) { - RETURN_NOT_OK( - KeyValuePartitioning::SetDefaultValuesFromKeys(*partition_, &projector_)); - } return ProjectRecordBatch(std::move(filter_it), &projector_, context_->pool); } private: std::shared_ptr task_; - std::shared_ptr partition_; - std::shared_ptr 
filter_; + Expression partition_; + Expression filter_; RecordBatchProjector projector_; }; /// \brief GetScanTaskIterator transforms an Iterator in a /// flattened Iterator. -inline ScanTaskIterator GetScanTaskIterator(FragmentIterator fragments, - std::shared_ptr options, - std::shared_ptr context) { +inline Result GetScanTaskIterator( + FragmentIterator fragments, std::shared_ptr options, + std::shared_ptr context) { // Fragment -> ScanTaskIterator auto fn = [options, context](std::shared_ptr fragment) -> Result { @@ -112,42 +127,5 @@ inline ScanTaskIterator GetScanTaskIterator(FragmentIterator fragments, return MakeFlattenIterator(std::move(maybe_scantask_it)); } -struct FragmentRecordBatchReader : RecordBatchReader { - public: - std::shared_ptr schema() const override { return options_->schema(); } - - Status ReadNext(std::shared_ptr* batch) override { - return iterator_.Next().Value(batch); - } - - static Result> Make( - std::shared_ptr fragment, std::shared_ptr schema, - std::shared_ptr context) { - // ensure schema is cached in fragment - auto options = ScanOptions::Make(std::move(schema)); - RETURN_NOT_OK(KeyValuePartitioning::SetDefaultValuesFromKeys( - *fragment->partition_expression(), &options->projector)); - - auto pool = context->pool; - ARROW_ASSIGN_OR_RAISE(auto scan_tasks, fragment->Scan(options, std::move(context))); - - auto reader = std::make_shared(); - reader->options_ = std::move(options); - reader->fragment_ = std::move(fragment); - reader->iterator_ = ProjectRecordBatch( - MakeFlattenIterator(MakeMaybeMapIterator( - [](std::shared_ptr task) { return task->Execute(); }, - std::move(scan_tasks))), - &reader->options_->projector, pool); - - return reader; - } - - private: - std::shared_ptr options_; - std::shared_ptr fragment_; - RecordBatchIterator iterator_; -}; - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 36966760346..f8f959e3a28 100644 --- 
a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -55,7 +55,7 @@ class TestScanner : public DatasetFixtureMixin { // structures of the scanner, i.e. Scanner[Dataset[ScanTask[RecordBatch]]] AssertScannerEquals(expected.get(), &scanner); } -}; // namespace dataset +}; constexpr int64_t TestScanner::kNumberChildDatasets; constexpr int64_t TestScanner::kNumberBatches; @@ -88,8 +88,7 @@ TEST_F(TestScanner, FilteredScan) { value += 1.0; })); - options_->filter = ("f64"_ > 0.0).Copy(); - options_->evaluator = std::make_shared(); + SetFilter(greater(field_ref("f64"), literal(0.0))); auto batch = RecordBatch::Make(schema_, f64->length(), {f64}); @@ -148,7 +147,7 @@ TEST_F(TestScanner, ToTable) { } class TestScannerBuilder : public ::testing::Test { - void SetUp() { + void SetUp() override { DatasetVector sources; schema_ = schema({ @@ -163,7 +162,7 @@ class TestScannerBuilder : public ::testing::Test { } protected: - std::shared_ptr ctx_; + std::shared_ptr ctx_ = std::make_shared(); std::shared_ptr schema_; std::shared_ptr dataset_; }; @@ -184,14 +183,18 @@ TEST_F(TestScannerBuilder, TestProject) { TEST_F(TestScannerBuilder, TestFilter) { ScannerBuilder builder(dataset_, ctx_); - ASSERT_OK(builder.Filter(scalar(true))); - ASSERT_OK(builder.Filter("i64"_ == int64_t(10))); - ASSERT_OK(builder.Filter("i64"_ == int64_t(10) || "b"_ == true)); + ASSERT_OK(builder.Filter(literal(true))); + ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal(10)))); + ASSERT_OK(builder.Filter(or_(equal(field_ref("i64"), literal(10)), + equal(field_ref("b"), literal(true))))); + + ASSERT_OK(builder.Filter(equal(field_ref("i64"), literal(10)))); + + ASSERT_RAISES(Invalid, builder.Filter(equal(field_ref("not_a_column"), literal(true)))); - ASSERT_RAISES(TypeError, builder.Filter("i64"_ == int32_t(10))); - ASSERT_RAISES(Invalid, builder.Filter("not_a_column"_ == true)); ASSERT_RAISES(Invalid, - builder.Filter("i64"_ == int64_t(10) || "not_a_column"_ == true)); 
+ builder.Filter(or_(equal(field_ref("i64"), literal(10)), + equal(field_ref("not_a_column"), literal(true))))); } using testing::ElementsAre; @@ -204,7 +207,7 @@ TEST(ScanOptions, TestMaterializedFields) { auto opts = ScanOptions::Make(schema({})); EXPECT_THAT(opts->MaterializedFields(), IsEmpty()); - opts->filter = ("i32"_ == 10).Copy(); + opts->filter2 = equal(field_ref("i32"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); opts = ScanOptions::Make(schema({i32, i64})); @@ -213,10 +216,10 @@ TEST(ScanOptions, TestMaterializedFields) { opts = opts->ReplaceSchema(schema({i32})); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32")); - opts->filter = ("i32"_ == 10).Copy(); + opts->filter2 = equal(field_ref("i32"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32")); - opts->filter = ("i64"_ == 10).Copy(); + opts->filter2 = equal(field_ref("i64"), literal(10)); EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64")); } diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index f504305a996..1c7c471d3ca 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,6 @@ #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/discovery.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/path_util.h" @@ -49,6 +49,20 @@ namespace arrow { namespace dataset { +const std::shared_ptr kBoringSchema = schema({ + field("i32", int32()), + field("i32_req", int32(), /*nullable=*/false), + field("i64", int64()), + field("date64", date64()), + field("f32", float32()), + field("f32_req", float32(), /*nullable=*/false), + field("f64", float64()), + field("bool", boolean()), + field("str", utf8()), + field("dict_str", 
dictionary(int32(), utf8())), + field("ts_ns", timestamp(TimeUnit::NANO)), +}); + using fs::internal::GetAbstractPathExtension; using internal::checked_cast; using internal::checked_pointer_cast; @@ -65,7 +79,7 @@ template class GeneratedRecordBatch : public RecordBatchReader { public: GeneratedRecordBatch(std::shared_ptr schema, Gen gen) - : schema_(schema), gen_(gen) {} + : schema_(std::move(schema)), gen_(gen) {} std::shared_ptr schema() const override { return schema_; } @@ -100,8 +114,6 @@ void EnsureRecordBatchReaderDrained(RecordBatchReader* reader) { class DatasetFixtureMixin : public ::testing::Test { public: - DatasetFixtureMixin() : ctx_(std::make_shared()) {} - /// \brief Ensure that record batches found in reader are equals to the /// record batches yielded by the data fragment. void AssertScanTaskEquals(RecordBatchReader* expected, ScanTask* task, @@ -140,7 +152,8 @@ class DatasetFixtureMixin : public ::testing::Test { /// record batches yielded by the data fragments of a dataset. 
void AssertDatasetFragmentsEqual(RecordBatchReader* expected, Dataset* dataset, bool ensure_drained = true) { - auto it = dataset->GetFragments(options_->filter); + ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter2.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(auto it, dataset->GetFragments(predicate)); ARROW_EXPECT_OK(it.Visit([&](std::shared_ptr fragment) -> Status { AssertFragmentEquals(expected, fragment.get(), false); @@ -185,11 +198,16 @@ class DatasetFixtureMixin : public ::testing::Test { void SetSchema(std::vector> fields) { schema_ = schema(std::move(fields)); options_ = ScanOptions::Make(schema_); + SetFilter(literal(true)); + } + + void SetFilter(Expression filter) { + ASSERT_OK_AND_ASSIGN(options_->filter2, filter.Bind(*schema_)); } std::shared_ptr schema_; std::shared_ptr options_; - std::shared_ptr ctx_; + std::shared_ptr ctx_ = std::make_shared(); }; /// \brief A dummy FileFormat implementation @@ -320,15 +338,14 @@ struct MakeFileSystemDatasetMixin { } void MakeDataset(const std::vector& infos, - std::shared_ptr root_partition = scalar(true), - ExpressionVector partitions = {}) { + Expression root_partition = literal(true), + std::vector partitions = {}, + std::shared_ptr s = kBoringSchema) { auto n_fragments = infos.size(); if (partitions.empty()) { - partitions.resize(n_fragments, scalar(true)); + partitions.resize(n_fragments, literal(true)); } - auto s = schema({}); - MakeFileSystem(infos); auto format = std::make_shared(s); @@ -339,21 +356,17 @@ struct MakeFileSystemDatasetMixin { continue; } + ASSERT_OK_AND_ASSIGN(partitions[i], partitions[i].Bind(*s)); ASSERT_OK_AND_ASSIGN(auto fragment, format->MakeFragment({info, fs_}, partitions[i])); fragments.push_back(std::move(fragment)); } + ASSERT_OK_AND_ASSIGN(root_partition, root_partition.Bind(*s)); ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format, fs_, std::move(fragments))); } - void MakeDatasetFromPathlist(const std::string& pathlist, - std::shared_ptr 
root_partition = scalar(true), - ExpressionVector partitions = {}) { - MakeDataset(ParsePathList(pathlist), root_partition, partitions); - } - std::shared_ptr fs_; std::shared_ptr dataset_; std::shared_ptr options_ = ScanOptions::Make(schema({})); @@ -386,49 +399,24 @@ void AssertFragmentsAreFromPath(FragmentIterator it, std::vector ex testing::UnorderedElementsAreArray(expected)); } -// A frozen shared_ptr with behavior expected by GTest -struct TestExpression : util::EqualityComparable, - util::ToStringOstreamable { - // NOLINTNEXTLINE runtime/explicit - TestExpression(std::shared_ptr e) : expression(std::move(e)) {} - - // NOLINTNEXTLINE runtime/explicit - TestExpression(const Expression& e) : expression(e.Copy()) {} - - std::shared_ptr expression; - - using util::EqualityComparable::operator==; - bool Equals(const TestExpression& other) const { - return expression->Equals(other.expression); - } - - std::string ToString() const { return expression->ToString(); } - - friend bool operator==(const std::shared_ptr& lhs, - const TestExpression& rhs) { - return TestExpression(lhs) == rhs; - } - - friend void PrintTo(const TestExpression& expr, std::ostream* os) { - *os << expr.ToString(); - } -}; - -static std::vector PartitionExpressionsOf( - const FragmentVector& fragments) { - std::vector partition_expressions; +static std::vector PartitionExpressionsOf(const FragmentVector& fragments) { + std::vector partition_expressions; std::transform(fragments.begin(), fragments.end(), std::back_inserter(partition_expressions), [](const std::shared_ptr& fragment) { - return TestExpression(fragment->partition_expression()); + return fragment->partition_expression(); }); return partition_expressions; } -void AssertFragmentsHavePartitionExpressions(FragmentIterator it, - ExpressionVector expected) { +void AssertFragmentsHavePartitionExpressions(std::shared_ptr dataset, + std::vector expected) { + ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments()); + for (auto& expr : 
expected) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*dataset->schema())); + } // Ordering is not guaranteed. - EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(it))), + EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(fragment_it))), testing::UnorderedElementsAreArray(expected)); } @@ -455,7 +443,7 @@ struct ArithmeticDatasetFixture { /// 2}, static std::string JSONRecordFor(int64_t n) { std::stringstream ss; - int32_t n_i32 = static_cast(n); + auto n_i32 = static_cast(n); ss << "{"; ss << "\"i64\": " << n << ", "; @@ -740,7 +728,8 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { } EXPECT_THAT(actual_paths, testing::UnorderedElementsAreArray(expected_paths)); - for (auto maybe_fragment : written_->GetFragments()) { + ASSERT_OK_AND_ASSIGN(auto written_fragments_it, written_->GetFragments()); + for (auto maybe_fragment : written_fragments_it) { ASSERT_OK_AND_ASSIGN(auto fragment, maybe_fragment); ASSERT_OK_AND_ASSIGN(auto actual_physical_schema, fragment->ReadPhysicalSchema()); diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 0ff77de0102..66fed352d0f 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -69,12 +69,6 @@ class ParquetFileWriter; class ParquetFileWriteOptions; class Expression; -using ExpressionVector = std::vector>; -class ExpressionEvaluator; - -/// forward declared to facilitate scalar(true) as a default for Expression parameters -ARROW_DS_EXPORT -const std::shared_ptr& scalar(bool); class Partitioning; class PartitioningFactory; diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index 5feed556207..786110996dc 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -68,6 +68,9 @@ Datum::Datum(int64_t value) : value(std::make_shared(value)) {} Datum::Datum(uint64_t value) : value(std::make_shared(value)) {} Datum::Datum(float value) : value(std::make_shared(value)) {} Datum::Datum(double value) : 
value(std::make_shared(value)) {} +Datum::Datum(std::string value) + : value(std::make_shared(std::move(value))) {} +Datum::Datum(const char* value) : value(std::make_shared(value)) {} Datum::Datum(const ChunkedArray& value) : value(std::make_shared(value.chunks(), value.type())) {} @@ -86,12 +89,24 @@ std::shared_ptr Datum::make_array() const { std::shared_ptr Datum::type() const { if (this->kind() == Datum::ARRAY) { return util::get>(this->value)->type; - } else if (this->kind() == Datum::CHUNKED_ARRAY) { + } + if (this->kind() == Datum::CHUNKED_ARRAY) { return util::get>(this->value)->type(); - } else if (this->kind() == Datum::SCALAR) { + } + if (this->kind() == Datum::SCALAR) { return util::get>(this->value)->type; } - return NULLPTR; + return nullptr; +} + +std::shared_ptr Datum::schema() const { + if (this->kind() == Datum::RECORD_BATCH) { + return util::get>(this->value)->schema(); + } + if (this->kind() == Datum::TABLE) { + return util::get>(this->value)->schema(); + } + return nullptr; } int64_t Datum::length() const { @@ -196,6 +211,8 @@ static std::string FormatValueDescr(const ValueDescr& descr) { std::string ValueDescr::ToString() const { return FormatValueDescr(*this); } +void PrintTo(const ValueDescr& descr, std::ostream* os) { *os << descr.ToString(); } + std::string Datum::ToString() const { switch (this->kind()) { case Datum::NONE: @@ -238,4 +255,17 @@ ValueDescr::Shape GetBroadcastShape(const std::vector& args) { return ValueDescr::SCALAR; } +void PrintTo(const Datum& datum, std::ostream* os) { + switch (datum.kind()) { + case Datum::SCALAR: + *os << datum.scalar()->ToString(); + break; + case Datum::ARRAY: + *os << datum.make_array()->ToString(); + break; + default: + *os << datum.ToString(); + } +} + } // namespace arrow diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 09dc870f687..fb783ea5261 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -81,12 +81,16 @@ struct ARROW_EXPORT ValueDescr { } bool 
operator==(const ValueDescr& other) const { - return this->shape == other.shape && this->type->Equals(*other.type); + if (shape != other.shape) return false; + if (type == other.type) return true; + return type && type->Equals(other.type); } bool operator!=(const ValueDescr& other) const { return !(*this == other); } std::string ToString() const; + + ARROW_EXPORT friend void PrintTo(const ValueDescr&, std::ostream*); }; /// \brief For use with scalar functions, returns the broadcasted Value::Shape @@ -114,6 +118,11 @@ struct ARROW_EXPORT Datum { /// \brief Empty datum, to be populated elsewhere Datum() = default; + Datum(const Datum& other) = default; + Datum& operator=(const Datum& other) = default; + Datum(Datum&& other) = default; + Datum& operator=(Datum&& other) = default; + Datum(std::shared_ptr value) // NOLINT implicit conversion : value(std::move(value)) {} @@ -153,21 +162,8 @@ struct ARROW_EXPORT Datum { explicit Datum(uint64_t value); explicit Datum(float value); explicit Datum(double value); - - Datum(const Datum& other) noexcept { this->value = other.value; } - - Datum& operator=(const Datum& other) noexcept { - value = other.value; - return *this; - } - - // Define move constructor and move assignment, for better performance - Datum(Datum&& other) noexcept : value(std::move(other.value)) {} - - Datum& operator=(Datum&& other) noexcept { - value = std::move(other.value); - return *this; - } + explicit Datum(std::string value); + explicit Datum(const char* value); Datum::Kind kind() const { switch (this->value.index()) { @@ -218,6 +214,11 @@ struct ARROW_EXPORT Datum { return util::get>(this->value); } + template + std::shared_ptr array_as() const { + return internal::checked_pointer_cast(this->make_array()); + } + template const ExactType& scalar_as() const { return internal::checked_cast(*this->scalar()); @@ -251,6 +252,11 @@ struct ARROW_EXPORT Datum { /// \return nullptr if no type std::shared_ptr type() const; + /// \brief The schema of the 
variant, if any + /// + /// \return nullptr if no schema + std::shared_ptr schema() const; + /// \brief The value length of the variant, if any /// /// \return kUnknownLength if no type @@ -267,6 +273,8 @@ struct ARROW_EXPORT Datum { bool operator!=(const Datum& other) const { return !Equals(other); } std::string ToString() const; + + ARROW_EXPORT friend void PrintTo(const Datum&, std::ostream*); }; } // namespace arrow diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h index eac08919286..09dfd59c8d2 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -401,6 +401,28 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return std::forward(m)(ValueUnsafe()); } + /// Cast the internally stored value to produce a new result or propagate the stored + /// error. + template ::value>::type> + Result As() && { + if (!ok()) { + return status(); + } + return U(MoveValueUnsafe()); + } + + /// Cast the internally stored value to produce a new result or propagate the stored + /// error. 
+ template ::value>::type> + Result As() const& { + if (!ok()) { + return status(); + } + return U(ValueUnsafe()); + } + const T& ValueUnsafe() const& { return *internal::launder(reinterpret_cast(&data_)); } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 5e6fec82cb3..ad889b3eb24 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -30,12 +30,10 @@ #include #include "arrow/array.h" -#include "arrow/chunked_array.h" #include "arrow/compare.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" -#include "arrow/table.h" #include "arrow/util/checked_cast.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" @@ -885,20 +883,21 @@ size_t FieldPath::hash() const { } std::string FieldPath::ToString() const { + if (this->indices().empty()) { + return "FieldPath(empty)"; + } + std::string repr = "FieldPath("; for (auto index : this->indices()) { repr += std::to_string(index) + " "; } - repr.resize(repr.size() - 1); - repr += ")"; + repr.back() = ')'; return repr; } struct FieldPathGetImpl { static const DataType& GetType(const ArrayData& data) { return *data.type; } - static const DataType& GetType(const ChunkedArray& array) { return *array.type(); } - static void Summarize(const FieldVector& fields, std::stringstream* ss) { *ss << "{ "; for (const auto& field : fields) { @@ -993,32 +992,6 @@ struct FieldPathGetImpl { path, &child_data, [](const std::shared_ptr& data) { return &data->child_data; }); } - - static Result> Get( - const FieldPath* path, const ChunkedArrayVector& columns_arg) { - ChunkedArrayVector columns = columns_arg; - - return FieldPathGetImpl::Get( - path, &columns, [&](const std::shared_ptr& a) { - columns.clear(); - - for (int i = 0; i < a->type()->num_fields(); ++i) { - ArrayVector child_chunks; - - for (const auto& chunk : a->chunks()) { - auto child_chunk = MakeArray(chunk->data()->child_data[i]); - child_chunks.push_back(std::move(child_chunk)); - } - - auto child_column = 
std::make_shared( - std::move(child_chunks), a->type()->field(i)->type()); - - columns.emplace_back(std::move(child_column)); - } - - return &columns; - }); - } }; Result> FieldPath::Get(const Schema& schema) const { @@ -1039,11 +1012,16 @@ Result> FieldPath::Get(const FieldVector& fields) const { Result> FieldPath::Get(const RecordBatch& batch) const { ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data())); - return MakeArray(data); + return MakeArray(std::move(data)); +} + +Result> FieldPath::Get(const Array& array) const { + ARROW_ASSIGN_OR_RAISE(auto data, Get(*array.data())); + return MakeArray(std::move(data)); } -Result> FieldPath::Get(const Table& table) const { - return FieldPathGetImpl::Get(this, table.columns()); +Result> FieldPath::Get(const ArrayData& data) const { + return FieldPathGetImpl::Get(this, data.child_data); } FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) { @@ -1291,11 +1269,11 @@ std::vector FieldRef::FindAll(const FieldVector& fields) const { return util::visit(Visitor{fields}, impl_); } -std::vector FieldRef::FindAll(const Array& array) const { - return FindAll(*array.type()); +std::vector FieldRef::FindAll(const ArrayData& array) const { + return FindAll(*array.type); } -std::vector FieldRef::FindAll(const ChunkedArray& array) const { +std::vector FieldRef::FindAll(const Array& array) const { return FindAll(*array.type()); } @@ -1303,10 +1281,6 @@ std::vector FieldRef::FindAll(const RecordBatch& batch) const { return FindAll(*batch.schema()); } -std::vector FieldRef::FindAll(const Table& table) const { - return FindAll(*table.schema()); -} - void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); } // ---------------------------------------------------------------------- @@ -1333,7 +1307,7 @@ Schema::Schema(std::vector> fields, Schema::Schema(const Schema& schema) : detail::Fingerprintable(), impl_(new Impl(*schema.impl_)) {} -Schema::~Schema() {} +Schema::~Schema() = 
default; int Schema::num_fields() const { return static_cast(impl_->fields_.size()); } @@ -1476,7 +1450,7 @@ std::shared_ptr Schema::WithMetadata( return std::make_shared(impl_->fields_, metadata); } -std::shared_ptr Schema::metadata() const { +const std::shared_ptr& Schema::metadata() const { return impl_->metadata_; } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index ed504c3afe5..f0fa04f40be 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1402,6 +1402,9 @@ class ARROW_EXPORT FieldPath { std::string ToString() const; size_t hash() const; + struct Hash { + size_t operator()(const FieldPath& path) const { return path.hash(); } + }; explicit operator bool() const { return !indices_.empty(); } bool operator!() const { return indices_.empty(); } @@ -1421,11 +1424,10 @@ class ARROW_EXPORT FieldPath { /// \brief Retrieve the referenced column from a RecordBatch or Table Result> Get(const RecordBatch& batch) const; - Result> Get(const Table& table) const; - /// \brief Retrieve the referenced child Array from an Array or ChunkedArray + /// \brief Retrieve the referenced child from an Array, ArrayData, or ChunkedArray Result> Get(const Array& array) const; - Result> Get(const ChunkedArray& array) const; + Result> Get(const ArrayData& data) const; private: std::vector indices_; @@ -1517,6 +1519,12 @@ class ARROW_EXPORT FieldRef { std::string ToString() const; size_t hash() const; + struct Hash { + size_t operator()(const FieldRef& ref) const { return ref.hash(); } + }; + + explicit operator bool() const { return Equals(FieldPath{}); } + bool operator!() const { return !Equals(FieldPath{}); } bool IsFieldPath() const { return util::holds_alternative(impl_); } bool IsName() const { return util::holds_alternative(impl_); } @@ -1526,6 +1534,13 @@ class ARROW_EXPORT FieldRef { return true; } + const FieldPath* field_path() const { + return IsFieldPath() ? &util::get(impl_) : NULLPTR; + } + const std::string* name() const { + return IsName() ? 
&util::get(impl_) : NULLPTR; + } + /// \brief Retrieve FieldPath of every child field which matches this FieldRef. std::vector FindAll(const Schema& schema) const; std::vector FindAll(const Field& field) const; @@ -1533,6 +1548,7 @@ class ARROW_EXPORT FieldRef { std::vector FindAll(const FieldVector& fields) const; /// \brief Convenience function which applies FindAll to arg's type or schema. + std::vector FindAll(const ArrayData& array) const; std::vector FindAll(const Array& array) const; std::vector FindAll(const ChunkedArray& array) const; std::vector FindAll(const RecordBatch& batch) const; @@ -1609,7 +1625,7 @@ class ARROW_EXPORT FieldRef { if (match) { return match.Get(root).ValueOrDie(); } - return NULLPTR; + return GetType(NULLPTR); } private: @@ -1669,7 +1685,7 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable, /// \brief The custom key-value metadata, if any /// /// \return metadata may be null - std::shared_ptr metadata() const; + const std::shared_ptr& metadata() const; /// \brief Render a string representation of the schema suitable for debugging /// \param[in] show_metadata when true, if KeyValueMetadata is non-empty, diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index cf90953527e..f1000d1fe7f 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -71,6 +71,8 @@ class RecordBatch; class RecordBatchReader; class Table; +struct Datum; + using ChunkedArrayVector = std::vector>; using RecordBatchVector = std::vector>; using RecordBatchIterator = Iterator>; diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h index 4713976d485..89f39ab8917 100644 --- a/cpp/src/arrow/util/variant.h +++ b/cpp/src/arrow/util/variant.h @@ -389,6 +389,22 @@ U& get(Variant& v) { return *v.template get(); } +/// \brief Get a const pointer to the value held by the variant +/// +/// If the type given as template argument doesn't match, a nullptr is returned. 
+template +const U* get_if(const Variant* v) { + return v->template get(); +} + +/// \brief Get a pointer to the value held by the variant +/// +/// If the type given as template argument doesn't match, a nullptr is returned. +template +U* get_if(Variant* v) { + return v->template get(); +} + namespace detail { template diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index bbe0485fa69..710e3c9c2c4 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -86,24 +86,22 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=None): cdef class Expression(_Weakrefable): cdef: - shared_ptr[CExpression] wrapped - CExpression* expr + CExpression expr def __init__(self): _forbid_instantiation(self.__class__) - cdef void init(self, const shared_ptr[CExpression]& sp): - self.wrapped = sp - self.expr = sp.get() + cdef void init(self, const CExpression& sp): + self.expr = sp @staticmethod - cdef wrap(const shared_ptr[CExpression]& sp): + cdef wrap(const CExpression& sp): cdef Expression self = Expression.__new__(Expression) self.init(sp) return self - cdef inline shared_ptr[CExpression] unwrap(self): - return self.wrapped + cdef inline CExpression unwrap(self): + return self.expr def equals(self, Expression other): return self.expr.Equals(other.unwrap()) @@ -118,104 +116,79 @@ cdef class Expression(_Weakrefable): @staticmethod def _deserialize(Buffer buffer not None): - c_buffer = pyarrow_unwrap_buffer(buffer) - c_expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) - return Expression.wrap(move(c_expr)) + return Expression.wrap(GetResultValue(CDeserializeExpression( + deref(pyarrow_unwrap_buffer(buffer))))) def __reduce__(self): - buffer = pyarrow_wrap_buffer(GetResultValue(self.expr.Serialize())) + buffer = pyarrow_wrap_buffer(GetResultValue( + CSerializeExpression(self.expr))) return Expression._deserialize, (buffer,) - def validate(self, Schema schema not None): - """Validate this expression for 
execution against a schema. - - This will check that all reference fields are present (fields not in - the schema will be replaced with null) and all subexpressions are - executable. Returns the type to which this expression will evaluate. - - Parameters - ---------- - schema : Schema - Schema to execute the expression on. + @staticmethod + cdef Expression _expr_or_scalar(object expr): + if isinstance(expr, Expression): + return ( expr) + return ( Expression._scalar(expr)) - Returns - ------- - type : DataType - """ + @staticmethod + cdef Expression _call(str function_name, list arguments, + shared_ptr[CFunctionOptions] options=( + nullptr)): cdef: - shared_ptr[CSchema] sp_schema - CResult[shared_ptr[CDataType]] result - sp_schema = pyarrow_unwrap_schema(schema) - result = self.expr.Validate(deref(sp_schema)) - return pyarrow_wrap_data_type(GetResultValue(result)) + vector[CExpression] c_arguments - def assume(self, Expression given): - """Simplify to an equivalent Expression given assumed constraints.""" - return Expression.wrap(self.expr.Assume(given.unwrap())) - - def __invert__(self): - return Expression.wrap(CMakeNotExpression(self.unwrap())) + for argument in arguments: + c_arguments.push_back(( argument).expr) - @staticmethod - cdef shared_ptr[CExpression] _expr_or_scalar(object expr) except *: - if isinstance(expr, Expression): - return ( expr).unwrap() - return ( Expression._scalar(expr)).unwrap() + return Expression.wrap(CMakeCallExpression(tobytes(function_name), + move(c_arguments), options)) def __richcmp__(self, other, int op): - cdef: - shared_ptr[CExpression] c_expr - shared_ptr[CExpression] c_left - shared_ptr[CExpression] c_right - - c_left = self.unwrap() - c_right = Expression._expr_or_scalar(other) - - if op == Py_EQ: - c_expr = CMakeEqualExpression(move(c_left), move(c_right)) - elif op == Py_NE: - c_expr = CMakeNotEqualExpression(move(c_left), move(c_right)) - elif op == Py_GT: - c_expr = CMakeGreaterExpression(move(c_left), move(c_right)) - 
elif op == Py_GE: - c_expr = CMakeGreaterEqualExpression(move(c_left), move(c_right)) - elif op == Py_LT: - c_expr = CMakeLessExpression(move(c_left), move(c_right)) - elif op == Py_LE: - c_expr = CMakeLessEqualExpression(move(c_left), move(c_right)) - - return Expression.wrap(c_expr) + other = Expression._expr_or_scalar(other) + return Expression._call({ + Py_EQ: "equal", + Py_NE: "not_equal", + Py_GT: "greater", + Py_GE: "greater_equal", + Py_LT: "less", + Py_LE: "less_equal", + }[op], [self, other]) + + def __invert__(self): + return Expression._call("invert", [self]) def __and__(Expression self, other): - c_other = Expression._expr_or_scalar(other) - return Expression.wrap(CMakeAndExpression(self.wrapped, - move(c_other))) + other = Expression._expr_or_scalar(other) + return Expression._call("and_kleene", [self, other]) def __or__(Expression self, other): - c_other = Expression._expr_or_scalar(other) - return Expression.wrap(CMakeOrExpression(self.wrapped, - move(c_other))) + other = Expression._expr_or_scalar(other) + return Expression._call("or_kleene", [self, other]) def is_valid(self): """Checks whether the expression is not-null (valid)""" - return Expression.wrap(self.expr.IsValid().Copy()) + return Expression._call("is_valid", [self]) def cast(self, type, bint safe=True): """Explicitly change the expression's data type""" - cdef CCastOptions options - if safe: - options = CCastOptions.Safe() - else: - options = CCastOptions.Unsafe() - c_type = pyarrow_unwrap_data_type(ensure_type(type)) - return Expression.wrap(self.expr.CastTo(c_type, options).Copy()) + cdef shared_ptr[CCastOptions] c_options + c_options.reset(new CCastOptions(safe)) + c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type)) + return Expression._call("cast", [self], + c_options) def isin(self, values): """Checks whether the expression is contained in values""" + cdef: + shared_ptr[CFunctionOptions] c_options + CDatum c_values + if not isinstance(values, pa.Array): values = 
pa.array(values) - c_values = pyarrow_unwrap_array(values) - return Expression.wrap(self.expr.In(c_values).Copy()) + + c_values = CDatum(pyarrow_unwrap_array(values)) + c_options.reset(new CSetLookupOptions(c_values, True)) + return Expression._call("is_in", [self], c_options) @staticmethod def _field(str name not None): @@ -231,11 +204,7 @@ cdef class Expression(_Weakrefable): else: scalar = pa.scalar(value) - return Expression.wrap( - shared_ptr[CExpression]( - new CScalarExpression(move(scalar.unwrap())) - ) - ) + return Expression.wrap(CMakeScalarExpression(scalar.unwrap())) _deserialize = Expression._deserialize @@ -288,12 +257,7 @@ cdef class Dataset(_Weakrefable): An Expression which evaluates to true for all data viewed by this Dataset. """ - cdef shared_ptr[CExpression] expr - expr = self.dataset.partition_expression() - if expr.get() == nullptr: - return None - else: - return Expression.wrap(expr) + return Expression.wrap(self.dataset.partition_expression()) def replace_schema(self, Schema schema not None): """ @@ -322,14 +286,15 @@ cdef class Dataset(_Weakrefable): fragments : iterator of Fragment """ cdef: - shared_ptr[CExpression] c_filter + CExpression c_filter CFragmentIterator c_iterator if filter is None: - c_fragments = self.dataset.GetFragments() + c_fragments = move(GetResultValue(self.dataset.GetFragments())) else: - c_filter = _insert_implicit_casts(filter, self.schema) - c_fragments = self.dataset.GetFragments(c_filter) + c_filter = _bind(filter, self.schema) + c_fragments = move(GetResultValue( + self.dataset.GetFragments(c_filter))) for maybe_fragment in c_fragments: yield Fragment.wrap(GetResultValue(move(maybe_fragment))) @@ -593,19 +558,14 @@ cdef class FileSystemDataset(Dataset): return FileFormat.wrap(self.filesystem_dataset.format()) -cdef shared_ptr[CExpression] _insert_implicit_casts(Expression filter, - Schema schema) except *: +cdef CExpression _bind(Expression filter, Schema schema) except *: assert schema is not None if filter 
is None: return _true.unwrap() - return GetResultValue( - CInsertImplicitCasts( - deref(filter.unwrap().get()), - deref(pyarrow_unwrap_schema(schema).get()) - ) - ) + return GetResultValue(filter.unwrap().Bind( + deref(pyarrow_unwrap_schema(schema).get()))) cdef class FileWriteOptions(_Weakrefable): @@ -1035,11 +995,11 @@ cdef class ParquetFileFragment(FileFragment): """ cdef: vector[shared_ptr[CFragment]] c_fragments - shared_ptr[CExpression] c_filter + CExpression c_filter shared_ptr[CFragment] c_fragment schema = schema or self.physical_schema - c_filter = _insert_implicit_casts(filter, schema) + c_filter = _bind(filter, schema) with nogil: c_fragments = move(GetResultValue( self.parquet_file_fragment.SplitByRowGroup(move(c_filter)))) @@ -1072,7 +1032,7 @@ cdef class ParquetFileFragment(FileFragment): ParquetFileFragment """ cdef: - shared_ptr[CExpression] c_filter + CExpression c_filter vector[int] c_row_group_ids shared_ptr[CFragment] c_fragment @@ -1083,7 +1043,7 @@ cdef class ParquetFileFragment(FileFragment): if filter is not None: schema = schema or self.physical_schema - c_filter = _insert_implicit_casts(filter, schema) + c_filter = _bind(filter, schema) with nogil: c_fragment = move(GetResultValue( self.parquet_file_fragment.SubsetWithFilter( @@ -1408,7 +1368,7 @@ cdef class Partitioning(_Weakrefable): return self.wrapped def parse(self, path): - cdef CResult[shared_ptr[CExpression]] result + cdef CResult[CExpression] result result = self.partitioning.Parse(tobytes(path)) return Expression.wrap(GetResultValue(result)) @@ -1643,11 +1603,7 @@ cdef class DatasetFactory(_Weakrefable): @property def root_partition(self): - cdef shared_ptr[CExpression] expr = self.factory.root_partition() - if expr.get() == nullptr: - return None - else: - return Expression.wrap(expr) + return Expression.wrap(self.factory.root_partition()) @root_partition.setter def root_partition(self, Expression expr): @@ -2121,14 +2077,14 @@ cdef void _populate_builder(const 
shared_ptr[CScannerBuilder]& ptr, int batch_size=_DEFAULT_BATCH_SIZE) except *: cdef: CScannerBuilder *builder - builder = ptr.get() - if columns is not None: - check_status(builder.Project([tobytes(c) for c in columns])) - check_status(builder.Filter(_insert_implicit_casts( + check_status(builder.Filter(_bind( filter, pyarrow_wrap_schema(builder.schema())))) + if columns is not None: + check_status(builder.Project([tobytes(c) for c in columns])) + check_status(builder.BatchSize(batch_size)) @@ -2277,7 +2233,8 @@ cdef class Scanner(_Weakrefable): def get_fragments(self): """Returns an iterator over the fragments in this scan. """ - cdef CFragmentIterator c_fragments = self.scanner.GetFragments() + cdef CFragmentIterator c_fragments = move(GetResultValue( + self.scanner.GetFragments())) for maybe_fragment in c_fragments: yield Fragment.wrap(GetResultValue(move(maybe_fragment))) @@ -2296,13 +2253,16 @@ def _get_partition_keys(Expression partition_expression): is converted to {'part': 'a', 'year': 2016} """ cdef: - shared_ptr[CExpression] expr = partition_expression.unwrap() - pair[c_string, shared_ptr[CScalar]] name_val - - return { - frombytes(name_val.first): pyarrow_wrap_scalar(name_val.second).as_py() - for name_val in GetResultValue(CGetPartitionKeys(deref(expr.get()))) - } + CExpression expr = partition_expression.unwrap() + pair[CFieldRef, CDatum] ref_val + + out = {} + for ref_val in GetResultValue(CExtractKnownFieldValues(expr)): + assert ref_val.first.name() != nullptr + assert ref_val.second.kind() == DatumType_SCALAR + val = pyarrow_wrap_scalar(ref_val.second.scalar()) + out[frombytes(deref(ref_val.first.name()))] = val.as_py() + return out def _filesystemdataset_write( diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index ddfd8057db2..8650a38345b 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -62,7 +62,7 @@ def _get_arg_names(func): arg_names = ["left", "right"] else: raise NotImplementedError( - "unsupported 
arity: {}".format(func.arity)) + f"unsupported arity: {func.arity} (function: {func.name})") return arg_names @@ -116,7 +116,7 @@ def _decorate_compute_function(wrapper, exposed_name, func, option_class): doc_pieces.append("""\ options : pyarrow.compute.{0}, optional Parameters altering compute function semantics - **kwargs: optional + **kwargs : optional Parameters for {0} constructor. Either `options` or `**kwargs` can be passed, but not both at the same time. """.format(option_class.__name__)) @@ -160,14 +160,14 @@ def _handle_options(name, option_class, options, kwargs): _wrapper_template = dedent("""\ def make_wrapper(func, option_class): - def {func_name}({args_sig}, *, memory_pool=None): + def {func_name}({args_sig}{kwonly}, memory_pool=None): return func.call([{args_sig}], None, memory_pool) return {func_name} """) _wrapper_options_template = dedent("""\ def make_wrapper(func, option_class): - def {func_name}({args_sig}, *, options=None, memory_pool=None, + def {func_name}({args_sig}{kwonly}, options=None, memory_pool=None, **kwargs): options = _handle_options({func_name!r}, option_class, options, kwargs) @@ -180,6 +180,7 @@ def _wrap_function(name, func): option_class = _get_options_class(func) arg_names = _get_arg_names(func) args_sig = ', '.join(arg_names) + kwonly = '' if arg_names[-1].startswith('*') else ', *' # Generate templated wrapper, so that the signature matches # the documented argument names. 
@@ -188,7 +189,8 @@ def _wrap_function(name, func): template = _wrapper_options_template else: template = _wrapper_template - exec(template.format(func_name=name, args_sig=args_sig), globals(), ns) + exec(template.format(func_name=name, args_sig=args_sig, kwonly=kwonly), + globals(), ns) wrapper = ns['make_wrapper'](func, option_class) return _decorate_compute_function(wrapper, name, func, option_class) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 45a39061919..fcdf8ed4179 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -405,6 +405,10 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CFieldRef() CFieldRef(c_string name) CFieldRef(int index) + const c_string* name() const + + cdef cppclass CFieldRefHash" arrow::FieldRef::Hash": + pass cdef cppclass CStructType" arrow::StructType"(CDataType): CStructType(const vector[shared_ptr[CField]]& fields) @@ -1826,13 +1830,14 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CDatum(const shared_ptr[CRecordBatch]& value) CDatum(const shared_ptr[CTable]& value) - DatumType kind() + DatumType kind() const + c_string ToString() const - shared_ptr[CArrayData] array() - shared_ptr[CChunkedArray] chunked_array() - shared_ptr[CRecordBatch] record_batch() - shared_ptr[CTable] table() - shared_ptr[CScalar] scalar() + const shared_ptr[CArrayData]& array() const + const shared_ptr[CChunkedArray]& chunked_array() const + const shared_ptr[CRecordBatch]& record_batch() const + const shared_ptr[CTable]& table() const + const shared_ptr[CScalar]& scalar() const cdef cppclass CSetLookupOptions \ "arrow::compute::SetLookupOptions"(CFunctionOptions): diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index a81042920d5..98e6c20bf23 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -36,63 +36,24 @@ cdef 
extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CExpression "arrow::dataset::Expression": c_bool Equals(const CExpression& other) const - c_bool Equals(const shared_ptr[CExpression]& other) const - CResult[shared_ptr[CDataType]] Validate(const CSchema& schema) const - shared_ptr[CExpression] Assume(const CExpression& given) const - shared_ptr[CExpression] Assume( - const shared_ptr[CExpression]& given) const c_string ToString() const - shared_ptr[CExpression] Copy() const + CResult[CExpression] Bind(const CSchema&) - const CExpression& In(shared_ptr[CArray]) const - const CExpression& IsValid() const - const CExpression& CastTo(shared_ptr[CDataType], CCastOptions) const - const CExpression& CastLike(shared_ptr[CExpression], - CCastOptions) const + cdef CExpression CMakeScalarExpression \ + "arrow::dataset::literal"(shared_ptr[CScalar] value) - @staticmethod - CResult[shared_ptr[CExpression]] Deserialize(const CBuffer& buffer) - CResult[shared_ptr[CBuffer]] Serialize() const - - ctypedef vector[shared_ptr[CExpression]] CExpressionVector \ - "arrow::dataset::ExpressionVector" + cdef CExpression CMakeFieldExpression \ + "arrow::dataset::field_ref"(c_string name) - cdef cppclass CScalarExpression \ - "arrow::dataset::ScalarExpression"(CExpression): - CScalarExpression(const shared_ptr[CScalar]& value) + cdef CExpression CMakeCallExpression \ + "arrow::dataset::call"(c_string function, + vector[CExpression] arguments, + shared_ptr[CFunctionOptions] options) - cdef shared_ptr[CExpression] CMakeFieldExpression \ - "arrow::dataset::field_ref"(c_string name) - cdef shared_ptr[CExpression] CMakeNotExpression \ - "arrow::dataset::not_"(shared_ptr[CExpression] operand) - cdef shared_ptr[CExpression] CMakeAndExpression \ - "arrow::dataset::and_"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeOrExpression \ - "arrow::dataset::or_"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef 
shared_ptr[CExpression] CMakeEqualExpression \ - "arrow::dataset::equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeNotEqualExpression \ - "arrow::dataset::not_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeGreaterExpression \ - "arrow::dataset::greater"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeGreaterEqualExpression \ - "arrow::dataset::greater_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeLessExpression \ - "arrow::dataset::less"(shared_ptr[CExpression], - shared_ptr[CExpression]) - cdef shared_ptr[CExpression] CMakeLessEqualExpression \ - "arrow::dataset::less_equal"(shared_ptr[CExpression], - shared_ptr[CExpression]) - - cdef CResult[shared_ptr[CExpression]] CInsertImplicitCasts \ - "arrow::dataset::InsertImplicitCasts"( - const CExpression &, const CSchema&) + cdef CResult[shared_ptr[CBuffer]] CSerializeExpression \ + "arrow::dataset::Serialize"(const CExpression&) + cdef CResult[CExpression] CDeserializeExpression \ + "arrow::dataset::Deserialize"(const CBuffer&) cdef cppclass CRecordBatchProjector "arrow::dataset::RecordBatchProjector": pass @@ -119,7 +80,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: shared_ptr[CScanOptions] options, shared_ptr[CScanContext] context) c_bool splittable() const c_string type_name() const - const shared_ptr[CExpression]& partition_expression() const + const CExpression& partition_expression() const ctypedef vector[shared_ptr[CFragment]] CFragmentVector \ "arrow::dataset::FragmentVector" @@ -130,7 +91,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CInMemoryFragment "arrow::dataset::InMemoryFragment"( CFragment): CInMemoryFragment(vector[shared_ptr[CRecordBatch]] record_batches, - shared_ptr[CExpression] partition_expression) + CExpression partition_expression) cdef 
cppclass CScanner "arrow::dataset::Scanner": CScanner(shared_ptr[CDataset], shared_ptr[CScanOptions], @@ -139,7 +100,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: shared_ptr[CScanContext]) CResult[CScanTaskIterator] Scan() CResult[shared_ptr[CTable]] ToTable() - CFragmentIterator GetFragments() + CResult[CFragmentIterator] GetFragments() const shared_ptr[CScanOptions]& options() cdef cppclass CScannerBuilder "arrow::dataset::ScannerBuilder": @@ -148,8 +109,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment], shared_ptr[CScanContext] scan_context) CStatus Project(const vector[c_string]& columns) - CStatus Filter(const CExpression& filter) - CStatus Filter(shared_ptr[CExpression] filter) + CStatus Filter(CExpression filter) CStatus UseThreads(c_bool use_threads) CStatus BatchSize(int64_t batch_size) CResult[shared_ptr[CScanner]] Finish() @@ -160,9 +120,9 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CDataset "arrow::dataset::Dataset": const shared_ptr[CSchema] & schema() - CFragmentIterator GetFragments() - CFragmentIterator GetFragments(shared_ptr[CExpression] predicate) - const shared_ptr[CExpression] & partition_expression() + CResult[CFragmentIterator] GetFragments() + CResult[CFragmentIterator] GetFragments(CExpression predicate) + const CExpression & partition_expression() c_string type_name() CResult[shared_ptr[CDataset]] ReplaceSchema(shared_ptr[CSchema]) @@ -193,8 +153,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[shared_ptr[CDataset]] FinishWithSchema "Finish"( const shared_ptr[CSchema]& schema) CResult[shared_ptr[CDataset]] Finish() - const shared_ptr[CExpression]& root_partition() - CStatus SetRootPartition(shared_ptr[CExpression] partition) + const CExpression& root_partition() + CStatus SetRootPartition(CExpression partition) cdef cppclass 
CUnionDatasetFactory "arrow::dataset::UnionDatasetFactory": @staticmethod @@ -221,7 +181,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CResult[shared_ptr[CSchema]] Inspect(const CFileSource&) const CResult[shared_ptr[CFileFragment]] MakeFragment( CFileSource source, - shared_ptr[CExpression] partition_expression, + CExpression partition_expression, shared_ptr[CSchema] physical_schema) shared_ptr[CFileWriteOptions] DefaultWriteOptions() @@ -240,9 +200,9 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: const vector[int]& row_groups() const shared_ptr[CFileMetaData] metadata() const CResult[vector[shared_ptr[CFragment]]] SplitByRowGroup( - shared_ptr[CExpression] predicate) + CExpression predicate) CResult[shared_ptr[CFragment]] SubsetWithFilter "Subset"( - shared_ptr[CExpression] predicate) + CExpression predicate) CResult[shared_ptr[CFragment]] SubsetWithIds "Subset"( vector[int] row_group_ids) CStatus EnsureCompleteMetadata() @@ -260,7 +220,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: @staticmethod CResult[shared_ptr[CDataset]] Make( shared_ptr[CSchema] schema, - shared_ptr[CExpression] source_partition, + CExpression source_partition, shared_ptr[CFileFormat] format, shared_ptr[CFileSystem] filesystem, vector[shared_ptr[CFileFragment]] fragments) @@ -287,7 +247,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: CParquetFileFormatReaderOptions reader_options CResult[shared_ptr[CFileFragment]] MakeFragment( CFileSource source, - shared_ptr[CExpression] partition_expression, + CExpression partition_expression, shared_ptr[CSchema] physical_schema, vector[int] row_groups) @@ -305,7 +265,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: cdef cppclass CPartitioning "arrow::dataset::Partitioning": c_string type_name() const - CResult[shared_ptr[CExpression]] Parse(const c_string & path) const + CResult[CExpression] Parse(const 
c_string & path) const const shared_ptr[CSchema] & schema() cdef cppclass CPartitioningFactoryOptions \ @@ -346,9 +306,9 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: const CExpression& partition_expression, CRecordBatchProjector* projector) - cdef CResult[unordered_map[c_string, shared_ptr[CScalar]]] \ - CGetPartitionKeys "arrow::dataset::KeyValuePartitioning::GetKeys"( - const CExpression& partition_expression) + cdef CResult[unordered_map[CFieldRef, CDatum, CFieldRefHash]] \ + CExtractKnownFieldValues "arrow::dataset::ExtractKnownFieldValues"( + const CExpression& partition_expression) cdef cppclass CFileSystemFactoryOptions \ "arrow::dataset::FileSystemFactoryOptions": diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 1afa0ef931c..b74dfcdec7b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -83,7 +83,11 @@ def test_exported_functions(): functions = exported_functions assert len(functions) >= 10 for func in functions: - args = [object()] * func.__arrow_compute_function__['arity'] + arity = func.__arrow_compute_function__['arity'] + if arity is Ellipsis: + args = [object()] * 3 + else: + args = [object()] * arity with pytest.raises(TypeError, match="Got unexpected argument type " " for compute function"): @@ -172,7 +176,8 @@ def test_function_attributes(): kernels = func.kernels assert func.num_kernels == len(kernels) assert all(isinstance(ker, pc.Kernel) for ker in kernels) - assert func.arity >= 1 # no varargs functions for now + if func.arity is not Ellipsis: + assert func.arity >= 1 repr(func) for ker in kernels: repr(ker) @@ -402,7 +407,7 @@ def test_generated_docstrings(): If not passed, will allocate memory from the default memory pool. options : pyarrow.compute.MinMaxOptions, optional Parameters altering compute function semantics - **kwargs: optional + **kwargs : optional Parameters for MinMaxOptions constructor. 
Either `options` or `**kwargs` can be passed, but not both at the same time. """) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 3bda5692128..0ab9d95398d 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -408,19 +408,6 @@ def test_expression_serialization(): f = ds.scalar({'a': 1}) g = ds.scalar(pa.scalar(1)) - condition = ds.field('i64') > 5 - schema = pa.schema([ - pa.field('i64', pa.int64()), - pa.field('f64', pa.float64()) - ]) - assert condition.validate(schema) == pa.bool_() - - assert condition.assume(ds.field('i64') == 5).equals( - ds.scalar(False)) - - assert condition.assume(ds.field('i64') == 7).equals( - ds.scalar(True)) - all_exprs = [a, b, c, d, e, f, g, a == b, a > b, a & b, a | b, ~c, d.is_valid(), a.cast(pa.int32(), safe=False), a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), @@ -781,7 +768,9 @@ def assert_yields_projected(fragment, row_slice, # Fragments don't contain the partition's columns if not provided to the # `to_table(schema=...)` method. 
- with pytest.raises(ValueError, match="Field named 'part' not found"): + pattern = (r'No match for FieldRef.Name\(part\) in ' + + fragment.physical_schema.to_string(False, False, False)) + with pytest.raises(ValueError, match=pattern): new_fragment = parquet_format.make_fragment( fragment.path, fragment.filesystem, partition_expression=fragment.partition_expression) diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py index 9422594caff..6fa001cc758 100644 --- a/python/pyarrow/tests/test_parquet.py +++ b/python/pyarrow/tests/test_parquet.py @@ -2144,7 +2144,7 @@ def test_filters_invalid_column(tempdir, use_legacy_dataset): _generate_partition_directories(fs, base_path, partition_spec, df) - msg = "Field named 'non_existent_column' not found" + msg = r"No match for FieldRef.Name\(non_existent_column\)" with pytest.raises(ValueError, match=msg): pq.ParquetDataset(base_path, filesystem=fs, filters=[('non_existent_column', '<', 3), ], diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 5e519080cee..7407e7c23de 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -736,52 +736,12 @@ FixedSizeListType__list_size <- function(type){ .Call(`_arrow_FixedSizeListType__list_size` , type) } -dataset___expr__field_ref <- function(name){ - .Call(`_arrow_dataset___expr__field_ref` , name) -} - -dataset___expr__equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__equal` , lhs, rhs) -} - -dataset___expr__not_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__not_equal` , lhs, rhs) -} - -dataset___expr__greater <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__greater` , lhs, rhs) -} - -dataset___expr__greater_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__greater_equal` , lhs, rhs) -} - -dataset___expr__less <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__less` , lhs, rhs) -} - -dataset___expr__less_equal <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__less_equal` , lhs, rhs) -} - 
-dataset___expr__in <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__in` , lhs, rhs) -} - -dataset___expr__and <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__and` , lhs, rhs) -} - -dataset___expr__or <- function(lhs, rhs){ - .Call(`_arrow_dataset___expr__or` , lhs, rhs) +dataset___expr__call <- function(func_name, argument_list, options){ + .Call(`_arrow_dataset___expr__call` , func_name, argument_list, options) } -dataset___expr__not <- function(lhs){ - .Call(`_arrow_dataset___expr__not` , lhs) -} - -dataset___expr__is_valid <- function(lhs){ - .Call(`_arrow_dataset___expr__is_valid` , lhs) +dataset___expr__field_ref <- function(name){ + .Call(`_arrow_dataset___expr__field_ref` , name) } dataset___expr__scalar <- function(x){ @@ -1448,10 +1408,6 @@ Scalar__ToString <- function(s){ .Call(`_arrow_Scalar__ToString` , s) } -Scalar__CastTo <- function(s, t){ - .Call(`_arrow_Scalar__CastTo` , s, t) -} - StructScalar__field <- function(s, i){ .Call(`_arrow_StructScalar__field` , s, i) } diff --git a/r/R/dplyr.R b/r/R/dplyr.R index f3dd078c952..b4b6fc4dab1 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -248,7 +248,7 @@ filter_mask <- function(.data) { if (query_on_dataset(.data)) { comp_func <- function(operator) { force(operator) - function(e1, e2) make_expression(operator, e1, e2) + function(e1, e2) build_dataset_expression(operator, e1, e2) } var_binder <- function(x) Expression$field_ref(x) } else { diff --git a/r/R/expression.R b/r/R/expression.R index d5623fb7786..9a5e575183d 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -57,21 +57,53 @@ build_array_expression <- function(.Generic, e1, e2, ...) 
{ if (.Generic %in% names(.unary_function_map)) { expr <- array_expression(.unary_function_map[[.Generic]], e1) } else { - e1 <- .wrap_arrow(e1, .Generic, e2$type) - e2 <- .wrap_arrow(e2, .Generic, e1$type) + e1 <- .wrap_arrow(e1, .Generic) + e2 <- .wrap_arrow(e2, .Generic) + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (.Generic == "/") { + # TODO: omg so many ways it's wrong to assume these types + e1 <- cast_array_expression(e1, float64()) + e2 <- cast_array_expression(e2, float64()) + } else if (.Generic == "%/%") { + # In R, integer division works like floor(float division) + out <- build_array_expression("/", e1, e2) + return(cast_array_expression(out, int32(), allow_float_truncate = TRUE)) + } else if (.Generic == "%%") { + # {e1 - e2 * ( e1 %/% e2 )} + # ^^^ form doesn't work because Ops.Array evaluates eagerly, + # but we can build that up + quotient <- build_array_expression("%/%", e1, e2) + # this cast is to ensure that the result of this and e1 are the same + # (autocasting only applies to scalars) + base <- cast_array_expression(quotient * e2, e1$type) + return(build_array_expression("-", e1, base)) + } + expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) } expr } -.wrap_arrow <- function(arg, fun, type) { +cast_array_expression <- function(x, to_type, safe = TRUE, ...) { + opts <- list( + to_type = to_type, + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + array_expression("cast", x, options = modifyList(opts, list(...))) +} + +.wrap_arrow <- function(arg, fun) { if (!inherits(arg, c("ArrowObject", "array_expression"))) { # TODO: Array$create if lengths are equal? # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. 
float) if (fun == "%in%") { - arg <- Array$create(arg, type = type) + arg <- Array$create(arg) } else { - arg <- Scalar$create(arg, type = type) + arg <- Scalar$create(arg) } } arg @@ -91,6 +123,15 @@ build_array_expression <- function(.Generic, e1, e2, ...) { "<=" = "less_equal", "&" = "and_kleene", "|" = "or_kleene", + "+" = "add_checked", + "-" = "subtract_checked", + "*" = "multiply_checked", + "/" = "divide_checked", + "%/%" = "divide_checked", + # we don't actually use divide_checked with `%%`, rather it is rewritten to + # use %/% above. + "%%" = "divide_checked", + # TODO: "^" (ARROW-11070) "%in%" = "is_in_meta_binary" ) @@ -104,6 +145,16 @@ eval_array_expression <- function(x) { a } }) + if (length(x$args) == 2L) { + # Insert implicit casts + if (inherits(x$args[[1]], "Scalar")) { + x$args[[1]] <- x$args[[1]]$cast(x$args[[2]]$type) + } else if (inherits(x$args[[2]], "Scalar")) { + x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) + } else if (x$fun == "is_in_meta_binary" && inherits(x$args[[2]], "Array")) { + x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) + } + } call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } @@ -153,104 +204,86 @@ print.array_expression <- function(x, ...) { #' `Expression$field_ref(name)` is used to construct an `Expression` which #' evaluates to the named column in the `Dataset` against which it is evaluated. #' -#' `Expression$compare(OP, e1, e2)` takes two `Expression` operands, constructing -#' an `Expression` which will evaluate these operands then compare them with the -#' relation specified by OP (e.g. "==", "!=", ">", etc.) For example, to filter -#' down to rows where the column named "alpha" is less than 5: -#' `Expression$compare("<", Expression$field_ref("alpha"), Expression$scalar(5))` -#' -#' `Expression$and(e1, e2)`, `Expression$or(e1, e2)`, and `Expression$not(e1)` -#' construct an `Expression` combining their arguments with Boolean operators. 
-#' -#' `Expression$is_valid(x)` is essentially (an inversion of) `is.na()` for `Expression`s. -#' -#' `Expression$in_(x, set)` evaluates x and returns whether or not it is a member of the set. +#' `Expression$create(function_name, ..., options)` builds a function-call +#' `Expression` containing one or more `Expression`s. #' @name Expression #' @rdname Expression #' @export Expression <- R6Class("Expression", inherit = ArrowObject, public = list( - ToString = function() dataset___expr__ToString(self) + ToString = function() dataset___expr__ToString(self), + cast = function(to_type, safe = TRUE, ...) { + opts <- list( + to_type = to_type, + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + Expression$create("cast", self, options = modifyList(opts, list(...))) + } ) ) - +Expression$create <- function(function_name, + ..., + args = list(...), + options = empty_named_list()) { + assert_that(is.string(function_name)) + dataset___expr__call(function_name, args, options) +} Expression$field_ref <- function(name) { - assert_is(name, "character") - assert_that(length(name) == 1) + assert_that(is.string(name)) dataset___expr__field_ref(name) } Expression$scalar <- function(x) { dataset___expr__scalar(Scalar$create(x)) } -Expression$compare <- function(OP, e1, e2) { - comp_func <- comparison_function_map[[OP]] - if (is.null(comp_func)) { - stop(OP, " is not a supported comparison function", call. 
= FALSE) - } - comp_func(e1, e2) -} -comparison_function_map <- list( - "==" = dataset___expr__equal, - "!=" = dataset___expr__not_equal, - ">" = dataset___expr__greater, - ">=" = dataset___expr__greater_equal, - "<" = dataset___expr__less, - "<=" = dataset___expr__less_equal -) -Expression$in_ <- function(x, set) { - dataset___expr__in(x, Array$create(set)) -} -Expression$and <- function(e1, e2) { - dataset___expr__and(e1, e2) -} -Expression$or <- function(e1, e2) { - dataset___expr__or(e1, e2) -} -Expression$not <- function(e1) { - dataset___expr__not(e1) -} -Expression$is_valid <- function(e1) { - dataset___expr__is_valid(e1) +build_dataset_expression <- function(.Generic, e1, e2, ...) { + if (.Generic %in% names(.unary_function_map)) { + expr <- Expression$create(.unary_function_map[[.Generic]], e1) + } else if (.Generic == "%in%") { + # Special-case %in%, which is different from the Array function name + expr <- Expression$create("is_in", e1, + options = list( + value_set = Array$create(e2), + skip_nulls = TRUE + ) + ) + } else { + if (!inherits(e1, "Expression")) { + e1 <- Expression$scalar(e1) + } + if (!inherits(e2, "Expression")) { + e2 <- Expression$scalar(e2) + } + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (.Generic == "/") { + # TODO: omg so many ways it's wrong to assume these types + e1 <- e1$cast(float64()) + e2 <- e2$cast(float64()) + } else if (.Generic == "%/%") { + # In R, integer division works like floor(float division) + out <- build_dataset_expression("/", e1, e2) + return(out$cast(int32(), allow_float_truncate = TRUE)) + } else if (.Generic == "%%") { + return(e1 - e2 * ( e1 %/% e2 )) + } + + expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...) 
+ } + expr } #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { - return(Expression$not(e1)) - } - make_expression(.Generic, e1, e2) -} - -make_expression <- function(operator, e1, e2) { - if (operator == "%in%") { - # In doesn't take Scalar, it takes Array - return(Expression$in_(e1, e2)) - } - - # Handle unary functions before touching e2 - if (operator == "is.na") { - return(is.na(e1)) - } - if (operator == "!") { - return(Expression$not(e1)) - } - - # Check for non-expressions and convert to Expressions - if (!inherits(e1, "Expression")) { - e1 <- Expression$scalar(e1) - } - if (!inherits(e2, "Expression")) { - e2 <- Expression$scalar(e2) - } - if (operator == "&") { - Expression$and(e1, e2) - } else if (operator == "|") { - Expression$or(e1, e2) + build_dataset_expression(.Generic, e1) } else { - Expression$compare(operator, e1, e2) + build_dataset_expression(.Generic, e1, e2) } } #' @export -is.na.Expression <- function(x) !Expression$is_valid(x) +is.na.Expression <- function(x) Expression$create("is_null", x) diff --git a/r/R/scalar.R b/r/R/scalar.R index 12f29990e0a..774fe571145 100644 --- a/r/R/scalar.R +++ b/r/R/scalar.R @@ -32,8 +32,14 @@ Scalar <- R6Class("Scalar", # TODO: document the methods public = list( ToString = function() Scalar__ToString(self), - cast = function(target_type) { - Scalar__CastTo(self, as_type(target_type)) + cast = function(target_type, safe = TRUE, ...) { + opts <- list( + to_type = as_type(target_type), + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + call_function("cast", self, options = modifyList(opts, list(...))) }, as_vector = function() Scalar__as_vector(self) ), diff --git a/r/man/Expression.Rd b/r/man/Expression.Rd index 14570eba892..58a6a44c0c0 100644 --- a/r/man/Expression.Rd +++ b/r/man/Expression.Rd @@ -13,16 +13,6 @@ the provided scalar (length-1) R value. 
\code{Expression$field_ref(name)} is used to construct an \code{Expression} which evaluates to the named column in the \code{Dataset} against which it is evaluated. -\code{Expression$compare(OP, e1, e2)} takes two \code{Expression} operands, constructing -an \code{Expression} which will evaluate these operands then compare them with the -relation specified by OP (e.g. "==", "!=", ">", etc.) For example, to filter -down to rows where the column named "alpha" is less than 5: -\code{Expression$compare("<", Expression$field_ref("alpha"), Expression$scalar(5))} - -\code{Expression$and(e1, e2)}, \code{Expression$or(e1, e2)}, and \code{Expression$not(e1)} -construct an \code{Expression} combining their arguments with Boolean operators. - -\code{Expression$is_valid(x)} is essentially (an inversion of) \code{is.na()} for \code{Expression}s. - -\code{Expression$in_(x, set)} evaluates x and returns whether or not it is a member of the set. +\code{Expression$create(function_name, ..., options)} builds a function-call +\code{Expression} containing one or more \code{Expression}s. } diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 975ff72f7f1..0a73b8681c4 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2846,190 +2846,33 @@ extern "C" SEXP _arrow_FixedSizeListType__list_size(SEXP type_sexp){ // expression.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__field_ref(std::string name); -extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ -BEGIN_CPP11 - arrow::r::Input::type name(name_sexp); - return cpp11::as_sexp(dataset___expr__field_ref(name)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ - Rf_error("Cannot call dataset___expr__field_ref(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__equal(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__not_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__not_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__not_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__not_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__not_equal(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__greater(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__greater(SEXP lhs_sexp, SEXP rhs_sexp){ +std::shared_ptr dataset___expr__call(std::string func_name, cpp11::list argument_list, cpp11::list options); +extern "C" SEXP _arrow_dataset___expr__call(SEXP func_name_sexp, SEXP argument_list_sexp, SEXP options_sexp){ BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__greater(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__greater(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__greater(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__greater_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__greater_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__greater_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__greater_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__greater_equal(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__less(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__less(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__less(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__less(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__less(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__less_equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__less_equal(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__less_equal(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__less_equal(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__less_equal(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__in(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__in(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__in(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__in(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__in(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__and(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__and(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__and(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__and(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__and(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__or(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_dataset___expr__or(SEXP lhs_sexp, SEXP rhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(dataset___expr__or(lhs, rhs)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_dataset___expr__or(SEXP lhs_sexp, SEXP rhs_sexp){ - Rf_error("Cannot call dataset___expr__or(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - -// expression.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__not(const std::shared_ptr& lhs); -extern "C" SEXP _arrow_dataset___expr__not(SEXP lhs_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - return cpp11::as_sexp(dataset___expr__not(lhs)); + arrow::r::Input::type func_name(func_name_sexp); + arrow::r::Input::type argument_list(argument_list_sexp); + arrow::r::Input::type options(options_sexp); + return cpp11::as_sexp(dataset___expr__call(func_name, argument_list, options)); END_CPP11 } #else -extern "C" SEXP _arrow_dataset___expr__not(SEXP lhs_sexp){ - Rf_error("Cannot call dataset___expr__not(). Please use arrow::install_arrow() to install required runtime libraries. 
"); +extern "C" SEXP _arrow_dataset___expr__call(SEXP func_name_sexp, SEXP argument_list_sexp, SEXP options_sexp){ + Rf_error("Cannot call dataset___expr__call(). Please use arrow::install_arrow() to install required runtime libraries. "); } #endif // expression.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__is_valid(const std::shared_ptr& lhs); -extern "C" SEXP _arrow_dataset___expr__is_valid(SEXP lhs_sexp){ +std::shared_ptr dataset___expr__field_ref(std::string name); +extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ BEGIN_CPP11 - arrow::r::Input&>::type lhs(lhs_sexp); - return cpp11::as_sexp(dataset___expr__is_valid(lhs)); + arrow::r::Input::type name(name_sexp); + return cpp11::as_sexp(dataset___expr__field_ref(name)); END_CPP11 } #else -extern "C" SEXP _arrow_dataset___expr__is_valid(SEXP lhs_sexp){ - Rf_error("Cannot call dataset___expr__is_valid(). Please use arrow::install_arrow() to install required runtime libraries. "); +extern "C" SEXP _arrow_dataset___expr__field_ref(SEXP name_sexp){ + Rf_error("Cannot call dataset___expr__field_ref(). Please use arrow::install_arrow() to install required runtime libraries. "); } #endif @@ -5685,22 +5528,6 @@ extern "C" SEXP _arrow_Scalar__ToString(SEXP s_sexp){ } #endif -// scalar.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, const std::shared_ptr& t); -extern "C" SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type s(s_sexp); - arrow::r::Input&>::type t(t_sexp); - return cpp11::as_sexp(Scalar__CastTo(s, t)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ - Rf_error("Cannot call Scalar__CastTo(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - // scalar.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr StructScalar__field(const std::shared_ptr& s, int i); @@ -6590,18 +6417,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_FixedSizeListType__value_field", (DL_FUNC) &_arrow_FixedSizeListType__value_field, 1}, { "_arrow_FixedSizeListType__value_type", (DL_FUNC) &_arrow_FixedSizeListType__value_type, 1}, { "_arrow_FixedSizeListType__list_size", (DL_FUNC) &_arrow_FixedSizeListType__list_size, 1}, + { "_arrow_dataset___expr__call", (DL_FUNC) &_arrow_dataset___expr__call, 3}, { "_arrow_dataset___expr__field_ref", (DL_FUNC) &_arrow_dataset___expr__field_ref, 1}, - { "_arrow_dataset___expr__equal", (DL_FUNC) &_arrow_dataset___expr__equal, 2}, - { "_arrow_dataset___expr__not_equal", (DL_FUNC) &_arrow_dataset___expr__not_equal, 2}, - { "_arrow_dataset___expr__greater", (DL_FUNC) &_arrow_dataset___expr__greater, 2}, - { "_arrow_dataset___expr__greater_equal", (DL_FUNC) &_arrow_dataset___expr__greater_equal, 2}, - { "_arrow_dataset___expr__less", (DL_FUNC) &_arrow_dataset___expr__less, 2}, - { "_arrow_dataset___expr__less_equal", (DL_FUNC) &_arrow_dataset___expr__less_equal, 2}, - { "_arrow_dataset___expr__in", (DL_FUNC) &_arrow_dataset___expr__in, 2}, - { "_arrow_dataset___expr__and", (DL_FUNC) &_arrow_dataset___expr__and, 2}, - { "_arrow_dataset___expr__or", (DL_FUNC) &_arrow_dataset___expr__or, 2}, - { "_arrow_dataset___expr__not", (DL_FUNC) &_arrow_dataset___expr__not, 1}, - { "_arrow_dataset___expr__is_valid", (DL_FUNC) &_arrow_dataset___expr__is_valid, 1}, { "_arrow_dataset___expr__scalar", (DL_FUNC) &_arrow_dataset___expr__scalar, 1}, { "_arrow_dataset___expr__ToString", (DL_FUNC) &_arrow_dataset___expr__ToString, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, @@ -6768,7 +6585,6 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 4}, { 
"_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, { "_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, - { "_arrow_Scalar__CastTo", (DL_FUNC) &_arrow_Scalar__CastTo, 2}, { "_arrow_StructScalar__field", (DL_FUNC) &_arrow_StructScalar__field, 2}, { "_arrow_StructScalar__GetFieldByName", (DL_FUNC) &_arrow_StructScalar__GetFieldByName, 2}, { "_arrow_Scalar__as_vector", (DL_FUNC) &_arrow_Scalar__as_vector, 1}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index a456ec4711b..4497f5b59a3 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -126,7 +126,6 @@ arrow::Datum as_cpp(SEXP x) { // This assumes that R objects have already been converted to Arrow objects; // that seems right but should we do the wrapping here too/instead? cpp11::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x))); - return arrow::Datum(); } } // namespace cpp11 @@ -151,9 +150,7 @@ SEXP from_datum(arrow::Datum datum) { break; } - auto str = datum.ToString(); - cpp11::stop("from_datum: Not implemented for Datum %s", str.c_str()); - return R_NilValue; + cpp11::stop("from_datum: Not implemented for Datum %s", datum.ToString().c_str()); } std::shared_ptr make_compute_options( @@ -182,6 +179,39 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "is_in" || func_name == "index_in") { + using Options = arrow::compute::SetLookupOptions; + return std::make_shared(cpp11::as_cpp(options["value_set"]), + cpp11::as_cpp(options["skip_nulls"])); + } + + // hacky attempt to pass through to_type and other options + if (func_name == "cast") { + using Options = arrow::compute::CastOptions; + auto out = std::make_shared(true); + SEXP to_type = options["to_type"]; + if (!Rf_isNull(to_type) && cpp11::as_cpp>(to_type)) { + out->to_type = cpp11::as_cpp>(to_type); + } + + SEXP allow_float_truncate = options["allow_float_truncate"]; + if (!Rf_isNull(allow_float_truncate) && cpp11::as_cpp(allow_float_truncate)) { + out->allow_float_truncate = 
cpp11::as_cpp(allow_float_truncate); + } + + SEXP allow_time_truncate = options["allow_time_truncate"]; + if (!Rf_isNull(allow_time_truncate) && cpp11::as_cpp(allow_time_truncate)) { + out->allow_time_truncate = cpp11::as_cpp(allow_time_truncate); + } + + SEXP allow_int_overflow = options["allow_int_overflow"]; + if (!Rf_isNull(allow_int_overflow) && cpp11::as_cpp(allow_int_overflow)) { + out->allow_int_overflow = cpp11::as_cpp(allow_int_overflow); + } + + return out; + } + return nullptr; } @@ -191,7 +221,7 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list auto datum_args = arrow::r::from_r_list(args); auto out = ValueOrStop( arrow::compute::CallFunction(func_name, datum_args, opts.get(), gc_context())); - return from_datum(out); + return from_datum(std::move(out)); } #endif diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp index 2ad59677eb0..8d8aa9cff6d 100644 --- a/r/src/dataset.cpp +++ b/r/src/dataset.cpp @@ -312,10 +312,7 @@ void dataset___ScannerBuilder__Project(const std::shared_ptr // [[arrow::export]] void dataset___ScannerBuilder__Filter(const std::shared_ptr& sb, const std::shared_ptr& expr) { - // Expressions converted from R's expressions are typed with R's native type, - // i.e. double, int64_t and bool. 
- auto cast_filter = ValueOrStop(InsertImplicitCasts(*expr, *sb->schema())); - StopIfNotOk(sb->Filter(cast_filter)); + StopIfNotOk(sb->Filter(*expr)); } // [[arrow::export]] diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 575dddc7da7..ddb1e72c309 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -19,93 +19,38 @@ #if defined(ARROW_R_WITH_ARROW) +#include #include namespace ds = ::arrow::dataset; -// [[arrow::export]] -std::shared_ptr dataset___expr__field_ref(std::string name) { - return ds::field_ref(std::move(name)); -} +std::shared_ptr make_compute_options( + std::string func_name, cpp11::list options); // [[arrow::export]] -std::shared_ptr dataset___expr__equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return ds::equal(lhs, rhs); -} +std::shared_ptr dataset___expr__call(std::string func_name, + cpp11::list argument_list, + cpp11::list options) { + std::vector arguments; + for (SEXP argument : argument_list) { + auto argument_ptr = cpp11::as_cpp>(argument); + arguments.push_back(*argument_ptr); + } -// [[arrow::export]] -std::shared_ptr dataset___expr__not_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return ds::not_equal(lhs, rhs); -} + auto options_ptr = make_compute_options(func_name, options); -// [[arrow::export]] -std::shared_ptr dataset___expr__greater( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return ds::greater(lhs, rhs); + return std::make_shared( + ds::call(std::move(func_name), std::move(arguments), std::move(options_ptr))); } // [[arrow::export]] -std::shared_ptr dataset___expr__greater_equal( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return ds::greater_equal(lhs, rhs); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__less( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return ds::less(lhs, rhs); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__less_equal( - const std::shared_ptr& lhs, - const 
std::shared_ptr& rhs) { - return ds::less_equal(lhs, rhs); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__in( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return lhs->In(rhs).Copy(); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__and( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return ds::and_(lhs, rhs); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__or( - const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return ds::or_(lhs, rhs); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__not( - const std::shared_ptr& lhs) { - return ds::not_(lhs); -} - -// [[arrow::export]] -std::shared_ptr dataset___expr__is_valid( - const std::shared_ptr& lhs) { - return lhs->IsValid().Copy(); +std::shared_ptr dataset___expr__field_ref(std::string name) { + return std::make_shared(ds::field_ref(std::move(name))); } // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { - return ds::scalar(x); + return std::make_shared(ds::literal(std::move(x))); } // [[arrow::export]] diff --git a/r/src/scalar.cpp b/r/src/scalar.cpp index 2c2d291b5bf..c0cc396b02d 100644 --- a/r/src/scalar.cpp +++ b/r/src/scalar.cpp @@ -47,12 +47,6 @@ std::string Scalar__ToString(const std::shared_ptr& s) { return s->ToString(); } -// [[arrow::export]] -std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, - const std::shared_ptr& t) { - return ValueOrStop(s->CastTo(t)); -} - // [[arrow::export]] std::shared_ptr StructScalar__field( const std::shared_ptr& s, int i) { diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R new file mode 100644 index 00000000000..ffde12c4d9b --- /dev/null +++ b/r/tests/testthat/test-compute-arith.R @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +test_that("Addition", { + a <- Array$create(c(1:4, NA_integer_)) + expect_type_equal(a, int32()) + expect_type_equal(a + 4, int32()) + expect_equal(a + 4, Array$create(c(5:8, NA_integer_))) + expect_identical(as.vector(a + 4), c(5:8, NA_integer_)) + expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) + expect_vector(a + 4L, c(5:8, NA_integer_)) + expect_equal(a + NA_integer_, Array$create(rep(NA_integer_, 5))) + + # overflow errors — this is slightly different from R's `NA` coercion when + # overflowing, but better than the alternative of silently restarting + casted <- a$cast(int8()) + expect_error(casted + 127) + expect_error(casted + 200) + + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") + expect_type_equal(a + 4.1, float64()) + expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) +}) + +test_that("Subtraction", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a - 3, Array$create(c(-2:1, NA_integer_))) +}) + +test_that("Multiplication", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a * 2, Array$create(c(1:4 * 2L, NA_integer_))) +}) + +test_that("Division", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a / 2, Array$create(c(1:4 / 2, NA_real_))) + expect_equal(a %/% 2, 
Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) + expect_equal(a / 2 / 2, Array$create(c(1:4 / 2 / 2, NA_real_))) + expect_equal(a %/% 2 %/% 2, Array$create(c(0L, 0L, 0L, 1L, NA_integer_))) + + b <- a$cast(float64()) + expect_equal(b / 2, Array$create(c(1:4 / 2, NA_real_))) + expect_equal(b %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) + + # the behavior of %/% matches R's (i.e. the integer of the quotient, not + # simply dividing two integers) + expect_equal(b / 2.2, Array$create(c(1:4 / 2.2, NA_real_))) + # c(1:4) %/% 2.2 != c(1:4) %/% as.integer(2.2) + # c(1:4) %/% 2.2 == c(0L, 0L, 1L, 1L) + # c(1:4) %/% as.integer(2.2) == c(0L, 1L, 1L, 2L) + expect_equal(b %/% 2.2, Array$create(c(0L, 0L, 1L, 1L, NA_integer_))) + + expect_equal(a %% 2, Array$create(c(1L, 0L, 1L, 0L, NA_integer_))) + + expect_equal(b %% 2, Array$create(c(1:4 %% 2, NA_real_))) +}) + +test_that("Dates casting", { + a <- Array$create(c(Sys.Date() + 1:4, NA_integer_)) + + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-11078") + expect_equal(a + 2, Array$create(c((Sys.Date() + 1:4 ) + 2), NA_integer_)) +}) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 73d654eb5a1..62b437e439c 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -494,12 +494,119 @@ test_that("filter() on date32 columns", { ) }) +test_that("filter() with expressions", { + ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) + expect_is(ds$format, "ParquetFileFormat") + expect_is(ds$filesystem, "LocalFileSystem") + expect_is(ds, "Dataset") + expect_equivalent( + ds %>% + select(chr, dbl) %>% + filter(dbl * 2 > 14 & dbl - 50 < 3L) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl")], + df2[1:2, c("chr", "dbl")] + ) + ) + + # check division's special casing. 
+ expect_equivalent( + ds %>% + select(chr, dbl) %>% + filter(dbl / 2 > 3.5 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl")], + df2[1:2, c("chr", "dbl")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %/% 2L > 3 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %/% 2 > 3 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2L > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2L > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + skip("Implicit casts aren't being inserted everywhere they need to be (ARROW-11080)") + # Error: NotImplemented: Function multiply_checked has no kernel matching input types (scalar[double], array[int32]) + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2 > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + skip("Implicit casts are only inserted for scalars (ARROW-11080)") + # Error: NotImplemented: Function add_checked has no kernel matching input types (array[double], array[int32]) + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(dbl + int > 15 & dbl < 53L) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) +}) + test_that("filter scalar validation doesn't crash 
(ARROW-7772)", { expect_error( ds %>% filter(int == "fff", part == 1) %>% collect(), - "error parsing 'fff' as scalar of type int32" + "Failed to parse string: 'fff' as a scalar of type int32" ) }) @@ -654,7 +761,7 @@ test_that("Dataset and query print methods", { "lgl: bool", "integer: int32", "", - "* Filter: (int == 6:double)", + "* Filter: (int == 6)", "* Grouped by lgl", "See $.data for the source Arrow object", sep = "\n" diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 2145adcc0ee..19c0665c807 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -27,6 +27,8 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star expr <- rlang::enquo(expr) expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) + skip_msg <- NULL + if (is.null(skip_record_batch)) { via_batch <- rlang::eval_tidy( expr, @@ -34,7 +36,7 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star ) expect_equivalent(via_batch, expected, ...) } else { - skip(skip_record_batch) + skip_msg <- c(skip_msg, skip_record_batch) } if (is.null(skip_table)) { @@ -44,7 +46,11 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star ) expect_equivalent(via_table, expected, ...) 
} else { - skip(skip_table) + skip_msg <- c(skip_msg, skip_table) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) } } @@ -133,6 +139,74 @@ test_that("filtering with expression", { ) }) +test_that("filtering with arithmetic", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl %/% 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression + autocasting", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + test_that("More complex select/filter", { expect_dplyr_equal( input %>% @@ -167,7 +241,7 @@ test_that("Print method", { int: int32 chr: string -* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5L)) +* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5)) See $.data for the source Arrow object', fixed = TRUE ) @@ -240,6 +314,14 @@ test_that("summarize", { summarize(min_int = min(int)), tbl ) + + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + summarize(min_int = min(int) / 2), + tbl + ) }) test_that("mutate", { diff --git a/r/tests/testthat/test-expression.R 
b/r/tests/testthat/test-expression.R index 1bf08595758..3c100812ff1 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -29,7 +29,7 @@ test_that("array_expression print method", { expect_output( print(build_array_expression(">", Array$create(1:5), 4)), # Not ideal but it is informative - "greater(, 4L)", + "greater(, 4)", fixed = TRUE ) }) @@ -60,9 +60,27 @@ test_that("C++ expressions", { expect_is(!(f < 4), "Expression") expect_output( print(f > 4), - 'Expression\n(f > 4:double)', + 'Expression\n(f > 4)', fixed = TRUE ) # Interprets that as a list type expect_is(f == c(1L, 2L), "Expression") }) + +test_that("Can create an expression", { + a <- Array$create(as.numeric(1:5)) + expr <- array_expression("cast", a, options = list(to_type = int32())) + expect_is(expr, "array_expression") + expect_equal(eval_array_expression(expr), Array$create(1:5)) + + b <- Array$create(0.5:4.5) + bad_expr <- array_expression("cast", b, options = list(to_type = int32())) + expect_is(bad_expr, "array_expression") + expect_error( + eval_array_expression(bad_expr), + "Invalid: Float value .* was truncated converting" + ) + expr <- array_expression("cast", b, options = list(to_type = int32(), allow_float_truncate = TRUE)) + expect_is(expr, "array_expression") + expect_equal(eval_array_expression(expr), Array$create(0:4)) +})