From d662edded6126ef6c47f738c96d1383024eb2ada Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Sun, 31 May 2020 22:39:00 -0500 Subject: [PATCH 1/3] Implement strptime scalar string to timestamp kernel Add a unit test for the OutputType::Resolve changes Create kernels/test_util.cc --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/compute/api_scalar.h | 13 +++++ cpp/src/arrow/compute/kernel.cc | 13 ++--- cpp/src/arrow/compute/kernel_test.cc | 11 ++++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 8 +-- .../arrow/compute/kernels/aggregate_test.cc | 3 +- .../arrow/compute/kernels/codegen_internal.h | 12 +++++ .../compute/kernels/scalar_arithmetic_test.cc | 2 +- .../compute/kernels/scalar_boolean_test.cc | 3 +- .../arrow/compute/kernels/scalar_cast_test.cc | 2 +- .../kernels/scalar_compare_benchmark.cc | 2 +- .../compute/kernels/scalar_compare_test.cc | 2 +- .../compute/kernels/scalar_set_lookup_test.cc | 3 +- ...calar_string_ascii.cc => scalar_string.cc} | 50 +++++++++++++++++++ .../compute/kernels/scalar_string_test.cc | 25 +++++----- .../arrow/compute/{ => kernels}/test_util.h | 6 +++ .../kernels/vector_filter_benchmark.cc | 2 +- .../compute/kernels/vector_filter_test.cc | 3 +- .../arrow/compute/kernels/vector_hash_test.cc | 2 +- .../kernels/vector_partition_benchmark.cc | 2 +- .../compute/kernels/vector_sort_benchmark.cc | 2 +- .../arrow/compute/kernels/vector_sort_test.cc | 2 +- .../compute/kernels/vector_take_benchmark.cc | 2 +- .../arrow/compute/kernels/vector_take_test.cc | 3 +- 24 files changed, 133 insertions(+), 42 deletions(-) rename cpp/src/arrow/compute/kernels/{scalar_string_ascii.cc => scalar_string.cc} (55%) rename cpp/src/arrow/compute/{ => kernels}/test_util.h (91%) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 6436c0c9910..eef925010c8 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -346,7 +346,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_cast_temporal.cc compute/kernels/scalar_compare.cc compute/kernels/scalar_set_lookup.cc - compute/kernels/scalar_string_ascii.cc + compute/kernels/scalar_string.cc compute/kernels/vector_filter.cc compute/kernels/vector_hash.cc compute/kernels/vector_sort.cc diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index ae9e284b2c4..66c29f495a0 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -20,6 +20,8 @@ #pragma once +#include +#include #include #include "arrow/compute/exec.h" // IWYU pragma: keep @@ -200,5 +202,16 @@ ARROW_EXPORT Result Match(const Datum& values, const Datum& value_set, ExecContext* ctx = NULLPTR); +// ---------------------------------------------------------------------- +// Temporal functions + +struct ARROW_EXPORT StrptimeOptions : public FunctionOptions { + explicit StrptimeOptions(std::string format, TimeUnit::type unit) + : format(format), unit(unit) {} + + std::string format; + TimeUnit::type unit; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 20100e42104..1a6f1bd1031 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -254,14 +254,15 @@ OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) { Result OutputType::Resolve(KernelContext* ctx, const std::vector& args) const { + ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args); if (kind_ == OutputType::FIXED) { - ValueDescr::Shape out_shape = shape_; - if (out_shape == ValueDescr::ANY) { - out_shape = GetBroadcastShape(args); - } - return ValueDescr(type_, out_shape); + return ValueDescr(type_, shape_ == ValueDescr::ANY ? broadcasted_shape : shape_); } else { - return resolver_(ctx, args); + ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args)); + if (resolved_descr.shape == ValueDescr::ANY) { + resolved_descr.shape = broadcasted_shape; + } + return resolved_descr; } } diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index bd5571b2fb5..fbdaf124d81 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -316,6 +316,17 @@ TEST(OutputType, Resolve) { return ValueDescr(args[0]); }); ASSERT_RAISES(Invalid, ty3.Resolve(nullptr, {})); + + // Type resolver that returns ValueDescr::ANY and needs type promotion + OutputType ty4( + [](KernelContext* ctx, const std::vector& args) -> Result { + return int32(); + }); + + ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Array(int8())})); + ASSERT_EQ(ValueDescr::Array(int32()), descr); + ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Scalar(int8())})); + ASSERT_EQ(ValueDescr::Scalar(int32()), descr); } TEST(OutputType, ResolveDescr) { diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 74493a85e18..9eb23716a06 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -25,7 +25,8 @@ add_arrow_compute_test(scalar_test scalar_cast_test.cc scalar_compare_test.cc scalar_set_lookup_test.cc - scalar_string_test.cc) + scalar_string_test.cc + test_util.cc) add_arrow_benchmark(scalar_compare_benchmark PREFIX "arrow-compute") @@ -37,7 +38,8 @@ add_arrow_compute_test(vector_test vector_filter_test.cc vector_hash_test.cc vector_take_test.cc - vector_sort_test.cc) + vector_sort_test.cc + test_util.cc) add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute") @@ -50,5 +52,5 @@ add_arrow_benchmark(vector_take_benchmark PREFIX "arrow-compute") # Aggregates -add_arrow_compute_test(aggregate_test) +add_arrow_compute_test(aggregate_test SOURCES aggregate_test.cc test_util.cc) add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 5e5881d47d1..c908f4777aa 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -25,8 +25,7 @@ #include "arrow/array.h" #include "arrow/compute/api_aggregate.h" #include "arrow/compute/kernels/aggregate_internal.h" -#include "arrow/compute/test_util.h" -#include "arrow/table.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 6232c533e3f..77774caa787 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -73,6 +73,18 @@ namespace compute { #endif // ARROW_EXTRA_ERROR_CONTEXT +template +struct OptionsWrapper : public KernelState { + OptionsWrapper(const OptionsType& options) : options(options) {} + OptionsType options; +}; + +template +std::unique_ptr InitWrapOptions(KernelContext*, const KernelInitArgs& args) { + return std::unique_ptr( + new OptionsWrapper(*static_cast(args.options))); +} + // ---------------------------------------------------------------------- // Iteration / value access utilities diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index b94a9a94d8e..017c9f5f034 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -25,7 +25,7 @@ #include "arrow/array.h" #include "arrow/compute/api.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" diff --git a/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc b/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc index a850026357c..e96b2ddcc26 100644 --- a/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc @@ -23,8 +23,7 @@ #include #include "arrow/compute/api_scalar.h" -#include "arrow/compute/test_util.h" -#include "arrow/table.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 89ca3667684..87a59848426 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -43,7 +43,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" #include "arrow/compute/kernel.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc index 90b6c276df6..136223f5b6f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc @@ -21,7 +21,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/benchmark_util.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 72b51dbb581..df4306d94b0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -26,7 +26,7 @@ #include "arrow/array.h" #include "arrow/compute/api.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc index dd8241dad46..137191b875c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc @@ -32,8 +32,9 @@ #include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/api.h" -#include "arrow/compute/test_util.h" #include "arrow/result.h" +#include "arrow/memory_pool.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/status.h" #include "arrow/table.h" #include "arrow/testing/gtest_compat.h" diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc similarity index 55% rename from cpp/src/arrow/compute/kernels/scalar_string_ascii.cc rename to cpp/src/arrow/compute/kernels/scalar_string.cc index 19eaf84016f..efb54e79e98 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -19,8 +19,10 @@ #include #include +#include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/scalar_string_internal.h" +#include "arrow/util/value_parsing.h" namespace arrow { namespace compute { @@ -57,9 +59,57 @@ void AddAsciiLength(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(func))); } +// ---------------------------------------------------------------------- +// strptime string parsing + +using StrptimeWrapper = OptionsWrapper; + +struct ParseStrptime { + explicit ParseStrptime(const StrptimeOptions& options) + : parser(TimestampParser::MakeStrptime(options.format)), unit(options.unit) {} + + template + int64_t Call(KernelContext* ctx, util::string_view val) const { + int64_t result = 0; + if (!(*parser)(val.data(), val.size(), unit, &result)) { + ctx->SetStatus(Status::Invalid("Failed to parse string ", val)); + } + return result; + } + + std::shared_ptr parser; + TimeUnit::type unit; +}; + +template +void StrptimeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& options = checked_cast(ctx->state())->options; + codegen::ScalarUnaryNotNullStateful kernel = + ParseStrptime(options); + return kernel.Exec(ctx, batch, out); +} + +Result StrptimeResolve(KernelContext* ctx, const std::vector&) { + const auto& options = checked_cast(ctx->state())->options; + return ::arrow::timestamp(options.unit); +} + +void AddStrptime(FunctionRegistry* registry) { + auto func = std::make_shared("strptime", Arity::Unary()); + DCHECK_OK(func->AddKernel({utf8()}, OutputType(StrptimeResolve), + StrptimeExec, InitWrapOptions)); + DCHECK_OK(func->AddKernel({large_utf8()}, OutputType(StrptimeResolve), + StrptimeExec, + InitWrapOptions)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +// ---------------------------------------------------------------------- + void RegisterScalarStringAscii(FunctionRegistry* registry) { MakeUnaryStringToString("ascii_upper", registry); AddAsciiLength(registry); + AddStrptime(registry); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index fba9a21e786..5a4c2c0e219 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -20,6 +20,7 @@ #include #include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" namespace arrow { @@ -33,19 +34,10 @@ class TestStringKernels : public ::testing::Test { using OffsetType = typename TypeTraits::OffsetType; void CheckUnary(std::string func_name, std::string json_input, - std::shared_ptr out_ty, std::string json_expected) { - auto input = ArrayFromJSON(string_type(), json_input); - auto expected = ArrayFromJSON(out_ty, json_expected); - ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input})); - AssertArraysEqual(*expected, *out.make_array(), /*verbose=*/true); - - // Check all the scalars - for (int64_t i = 0; i < input->length(); ++i) { - ASSERT_OK_AND_ASSIGN(auto val, input->GetScalar(i)); - ASSERT_OK_AND_ASSIGN(auto ex_val, expected->GetScalar(i)); - ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {val})); - AssertScalarsEqual(*ex_val, *out.scalar(), /*verbose=*/true); - } + std::shared_ptr out_ty, std::string json_expected, + const FunctionOptions* options = nullptr) { + CheckScalarUnary(func_name, string_type(), json_input, out_ty, json_expected, + options); } std::shared_ptr string_type() { @@ -69,5 +61,12 @@ TYPED_TEST(TestStringKernels, AsciiUpper) { "[\"AAA&\", null, \"\", \"B\"]"); } +TYPED_TEST(TestStringKernels, Strptime) { + std::string input1 = R"(["5/1/2020", null, "12/11/1900"])"; + std::string output1 = R"(["2020-05-01", null, "1900-12-11"])"; + StrptimeOptions options("%m/%d/%Y", TimeUnit::MICRO); + this->CheckUnary("strptime", input1, timestamp(TimeUnit::MICRO), output1, &options); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h similarity index 91% rename from cpp/src/arrow/compute/test_util.h rename to cpp/src/arrow/compute/kernels/test_util.h index c7623c107d6..88c3c3f4485 100644 --- a/cpp/src/arrow/compute/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -20,6 +20,7 @@ // IWYU pragma: begin_exports #include +#include #include #include @@ -86,6 +87,11 @@ struct DatumEqual> { } }; +void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, + std::string json_input, std::shared_ptr out_ty, + std::string json_expected, + const FunctionOptions* options = nullptr); + using TestingStringTypes = ::testing::Types; diff --git a/cpp/src/arrow/compute/kernels/vector_filter_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_filter_benchmark.cc index 78c6b4afa23..85bb58d6b96 100644 --- a/cpp/src/arrow/compute/kernels/vector_filter_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_filter_benchmark.cc @@ -19,7 +19,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/benchmark_util.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/vector_filter_test.cc b/cpp/src/arrow/compute/kernels/vector_filter_test.cc index 6a7c237d02f..8f0d5bd99f1 100644 --- a/cpp/src/arrow/compute/kernels/vector_filter_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_filter_test.cc @@ -21,8 +21,7 @@ #include #include "arrow/compute/api.h" -#include "arrow/compute/test_util.h" -#include "arrow/table.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/vector_hash_test.cc b/cpp/src/arrow/compute/kernels/vector_hash_test.cc index aab914056f9..c7b584fea19 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash_test.cc @@ -41,7 +41,7 @@ #include "arrow/util/decimal.h" #include "arrow/compute/api.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/ipc/json_simple.h" diff --git a/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc index 03533f4d1df..e76b27146f2 100644 --- a/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc @@ -19,7 +19,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/benchmark_util.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc index ee8a3119c21..344de258ccd 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc @@ -19,7 +19,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/benchmark_util.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 95328c285ab..74963cae9fd 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -21,7 +21,7 @@ #include #include "arrow/compute/api_vector.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/vector_take_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_take_benchmark.cc index 00b0a7bbd3a..184eed31e8f 100644 --- a/cpp/src/arrow/compute/kernels/vector_take_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_take_benchmark.cc @@ -20,7 +20,7 @@ #include "arrow/compute/api.h" #include "arrow/compute/benchmark_util.h" -#include "arrow/compute/test_util.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/vector_take_test.cc b/cpp/src/arrow/compute/kernels/vector_take_test.cc index f207058a549..9024c940fb0 100644 --- a/cpp/src/arrow/compute/kernels/vector_take_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_take_test.cc @@ -22,8 +22,7 @@ #include #include "arrow/compute/api.h" -#include "arrow/compute/test_util.h" -#include "arrow/table.h" +#include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" From 8d071d53cc7f3ab3d7810e590f1ff2cf5af5b64d Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 2 Jun 2020 10:05:06 -0500 Subject: [PATCH 2/3] Post-rebase fixes --- cpp/src/arrow/compute/kernels/aggregate_test.cc | 1 + cpp/src/arrow/compute/kernels/scalar_boolean_test.cc | 1 + cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc | 4 ++-- cpp/src/arrow/compute/kernels/vector_filter_test.cc | 1 + cpp/src/arrow/compute/kernels/vector_take_test.cc | 1 + 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index c908f4777aa..9f92da9dd69 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -26,6 +26,7 @@ #include "arrow/compute/api_aggregate.h" #include "arrow/compute/kernels/aggregate_internal.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" diff --git a/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc b/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc index e96b2ddcc26..cb3fdb06014 100644 --- a/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_boolean_test.cc @@ -24,6 +24,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc index 137191b875c..3350d29a9e7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc @@ -32,9 +32,9 @@ #include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/api.h" -#include "arrow/result.h" -#include "arrow/memory_pool.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/table.h" #include "arrow/testing/gtest_compat.h" diff --git a/cpp/src/arrow/compute/kernels/vector_filter_test.cc b/cpp/src/arrow/compute/kernels/vector_filter_test.cc index 8f0d5bd99f1..a835417dd0f 100644 --- a/cpp/src/arrow/compute/kernels/vector_filter_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_filter_test.cc @@ -22,6 +22,7 @@ #include "arrow/compute/api.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/kernels/vector_take_test.cc b/cpp/src/arrow/compute/kernels/vector_take_test.cc index 9024c940fb0..1c3a19851c9 100644 --- a/cpp/src/arrow/compute/kernels/vector_take_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_take_test.cc @@ -23,6 +23,7 @@ #include "arrow/compute/api.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" From 0cfeefa994088ae484504133720b68973e123240 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 2 Jun 2020 10:05:56 -0500 Subject: [PATCH 3/3] Check in missing file --- cpp/src/arrow/compute/kernels/test_util.cc | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/test_util.cc diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc new file mode 100644 index 00000000000..49b8bcec7b2 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/test_util.h" + +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/compute/exec.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { +namespace compute { + +void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, + std::string json_input, std::shared_ptr out_ty, + std::string json_expected, const FunctionOptions* options) { + auto input = ArrayFromJSON(in_ty, json_input); + auto expected = ArrayFromJSON(out_ty, json_expected); + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input}, options)); + AssertArraysEqual(*expected, *out.make_array(), /*verbose=*/true); + + // Check all the scalars + for (int64_t i = 0; i < input->length(); ++i) { + ASSERT_OK_AND_ASSIGN(auto val, input->GetScalar(i)); + ASSERT_OK_AND_ASSIGN(auto ex_val, expected->GetScalar(i)); + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {val}, options)); + AssertScalarsEqual(*ex_val, *out.scalar(), /*verbose=*/true); + } +} + +} // namespace compute +} // namespace arrow