Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_cast_temporal.cc
compute/kernels/scalar_compare.cc
compute/kernels/scalar_set_lookup.cc
compute/kernels/scalar_string_ascii.cc
compute/kernels/scalar_string.cc
compute/kernels/vector_filter.cc
compute/kernels/vector_hash.cc
compute/kernels/vector_sort.cc
Expand Down
13 changes: 13 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#pragma once

#include <memory>
#include <string>
#include <utility>

#include "arrow/compute/exec.h" // IWYU pragma: keep
Expand Down Expand Up @@ -200,5 +202,16 @@ ARROW_EXPORT
Result<Datum> Match(const Datum& values, const Datum& value_set,
ExecContext* ctx = NULLPTR);

// ----------------------------------------------------------------------
// Temporal functions

struct ARROW_EXPORT StrptimeOptions : public FunctionOptions {
explicit StrptimeOptions(std::string format, TimeUnit::type unit)
: format(format), unit(unit) {}

std::string format;
TimeUnit::type unit;
};

} // namespace compute
} // namespace arrow
13 changes: 7 additions & 6 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,15 @@ OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) {

Result<ValueDescr> OutputType::Resolve(KernelContext* ctx,
const std::vector<ValueDescr>& args) const {
ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args);
if (kind_ == OutputType::FIXED) {
ValueDescr::Shape out_shape = shape_;
if (out_shape == ValueDescr::ANY) {
out_shape = GetBroadcastShape(args);
}
return ValueDescr(type_, out_shape);
return ValueDescr(type_, shape_ == ValueDescr::ANY ? broadcasted_shape : shape_);
} else {
return resolver_(ctx, args);
ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args));
if (resolved_descr.shape == ValueDescr::ANY) {
resolved_descr.shape = broadcasted_shape;
}
return resolved_descr;
}
}

Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,17 @@ TEST(OutputType, Resolve) {
return ValueDescr(args[0]);
});
ASSERT_RAISES(Invalid, ty3.Resolve(nullptr, {}));

// Type resolver that returns ValueDescr::ANY and needs type promotion
OutputType ty4(
[](KernelContext* ctx, const std::vector<ValueDescr>& args) -> Result<ValueDescr> {
return int32();
});

ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Array(int8())}));
ASSERT_EQ(ValueDescr::Array(int32()), descr);
ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Scalar(int8())}));
ASSERT_EQ(ValueDescr::Scalar(int32()), descr);
}

TEST(OutputType, ResolveDescr) {
Expand Down
8 changes: 5 additions & 3 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ add_arrow_compute_test(scalar_test
scalar_cast_test.cc
scalar_compare_test.cc
scalar_set_lookup_test.cc
scalar_string_test.cc)
scalar_string_test.cc
test_util.cc)

add_arrow_benchmark(scalar_compare_benchmark PREFIX "arrow-compute")

Expand All @@ -37,7 +38,8 @@ add_arrow_compute_test(vector_test
vector_filter_test.cc
vector_hash_test.cc
vector_take_test.cc
vector_sort_test.cc)
vector_sort_test.cc
test_util.cc)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
Expand All @@ -50,5 +52,5 @@ add_arrow_benchmark(vector_take_benchmark PREFIX "arrow-compute")

# Aggregates

add_arrow_compute_test(aggregate_test)
add_arrow_compute_test(aggregate_test SOURCES aggregate_test.cc test_util.cc)
add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute")
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "arrow/array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/table.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ namespace compute {

#endif // ARROW_EXTRA_ERROR_CONTEXT

template <typename OptionsType>
struct OptionsWrapper : public KernelState {
OptionsWrapper(const OptionsType& options) : options(options) {}
OptionsType options;
};

template <typename OptionsType>
std::unique_ptr<KernelState> InitWrapOptions(KernelContext*, const KernelInitArgs& args) {
return std::unique_ptr<KernelState>(
new OptionsWrapper<OptionsType>(*static_cast<const OptionsType*>(args.options)));
}

// ----------------------------------------------------------------------
// Iteration / value access utilities

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

#include "arrow/array.h"
#include "arrow/compute/api.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/scalar_boolean_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <gtest/gtest.h>

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/table.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/gtest_util.h"
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#include "arrow/compute/api_vector.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"

namespace arrow {
namespace compute {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/benchmark_util.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/scalar_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include "arrow/array.h"
#include "arrow/compute/api.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/api.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/table.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
#include <cctype>
#include <string>

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/compute/kernels/scalar_string_internal.h"
#include "arrow/util/value_parsing.h"

namespace arrow {
namespace compute {
Expand Down Expand Up @@ -57,9 +59,57 @@ void AddAsciiLength(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// ----------------------------------------------------------------------
// strptime string parsing

using StrptimeWrapper = OptionsWrapper<StrptimeOptions>;

struct ParseStrptime {
explicit ParseStrptime(const StrptimeOptions& options)
: parser(TimestampParser::MakeStrptime(options.format)), unit(options.unit) {}

template <typename... Ignored>
int64_t Call(KernelContext* ctx, util::string_view val) const {
int64_t result = 0;
if (!(*parser)(val.data(), val.size(), unit, &result)) {
ctx->SetStatus(Status::Invalid("Failed to parse string ", val));
}
return result;
}

std::shared_ptr<TimestampParser> parser;
TimeUnit::type unit;
};

template <typename InputType>
void StrptimeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& options = checked_cast<const StrptimeWrapper*>(ctx->state())->options;
codegen::ScalarUnaryNotNullStateful<TimestampType, InputType, ParseStrptime> kernel =
ParseStrptime(options);
return kernel.Exec(ctx, batch, out);
}

Result<ValueDescr> StrptimeResolve(KernelContext* ctx, const std::vector<ValueDescr>&) {
const auto& options = checked_cast<const StrptimeWrapper*>(ctx->state())->options;
return ::arrow::timestamp(options.unit);
}

void AddStrptime(FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>("strptime", Arity::Unary());
DCHECK_OK(func->AddKernel({utf8()}, OutputType(StrptimeResolve),
StrptimeExec<StringType>, InitWrapOptions<StrptimeOptions>));
DCHECK_OK(func->AddKernel({large_utf8()}, OutputType(StrptimeResolve),
StrptimeExec<LargeStringType>,
InitWrapOptions<StrptimeOptions>));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// ----------------------------------------------------------------------

void RegisterScalarStringAscii(FunctionRegistry* registry) {
MakeUnaryStringToString<AsciiUpper>("ascii_upper", registry);
AddAsciiLength(registry);
AddStrptime(registry);
}

} // namespace internal
Expand Down
25 changes: 12 additions & 13 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <gtest/gtest.h>

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/gtest_util.h"

namespace arrow {
Expand All @@ -33,19 +34,10 @@ class TestStringKernels : public ::testing::Test {
using OffsetType = typename TypeTraits<TestType>::OffsetType;

void CheckUnary(std::string func_name, std::string json_input,
std::shared_ptr<DataType> out_ty, std::string json_expected) {
auto input = ArrayFromJSON(string_type(), json_input);
auto expected = ArrayFromJSON(out_ty, json_expected);
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input}));
AssertArraysEqual(*expected, *out.make_array(), /*verbose=*/true);

// Check all the scalars
for (int64_t i = 0; i < input->length(); ++i) {
ASSERT_OK_AND_ASSIGN(auto val, input->GetScalar(i));
ASSERT_OK_AND_ASSIGN(auto ex_val, expected->GetScalar(i));
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {val}));
AssertScalarsEqual(*ex_val, *out.scalar(), /*verbose=*/true);
}
std::shared_ptr<DataType> out_ty, std::string json_expected,
const FunctionOptions* options = nullptr) {
CheckScalarUnary(func_name, string_type(), json_input, out_ty, json_expected,
options);
}

std::shared_ptr<DataType> string_type() {
Expand All @@ -69,5 +61,12 @@ TYPED_TEST(TestStringKernels, AsciiUpper) {
"[\"AAA&\", null, \"\", \"B\"]");
}

TYPED_TEST(TestStringKernels, Strptime) {
std::string input1 = R"(["5/1/2020", null, "12/11/1900"])";
std::string output1 = R"(["2020-05-01", null, "1900-12-11"])";
StrptimeOptions options("%m/%d/%Y", TimeUnit::MICRO);
this->CheckUnary("strptime", input1, timestamp(TimeUnit::MICRO), output1, &options);
}

} // namespace compute
} // namespace arrow
51 changes: 51 additions & 0 deletions cpp/src/arrow/compute/kernels/test_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/kernels/test_util.h"

#include <cstdint>
#include <memory>
#include <string>

#include "arrow/array.h"
#include "arrow/compute/exec.h"
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/testing/gtest_util.h"

namespace arrow {
namespace compute {

void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
std::string json_input, std::shared_ptr<DataType> out_ty,
std::string json_expected, const FunctionOptions* options) {
auto input = ArrayFromJSON(in_ty, json_input);
auto expected = ArrayFromJSON(out_ty, json_expected);
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input}, options));
AssertArraysEqual(*expected, *out.make_array(), /*verbose=*/true);

// Check all the scalars
for (int64_t i = 0; i < input->length(); ++i) {
ASSERT_OK_AND_ASSIGN(auto val, input->GetScalar(i));
ASSERT_OK_AND_ASSIGN(auto ex_val, expected->GetScalar(i));
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {val}, options));
AssertScalarsEqual(*ex_val, *out.scalar(), /*verbose=*/true);
}
}

} // namespace compute
} // namespace arrow
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
// IWYU pragma: begin_exports

#include <memory>
#include <string>
#include <vector>

#include <gmock/gmock.h>
Expand Down Expand Up @@ -86,6 +87,11 @@ struct DatumEqual<Type, enable_if_integer<Type>> {
}
};

void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
std::string json_input, std::shared_ptr<DataType> out_ty,
std::string json_expected,
const FunctionOptions* options = nullptr);

using TestingStringTypes =
::testing::Types<StringType, LargeStringType, BinaryType, LargeBinaryType>;

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/vector_filter_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

#include "arrow/compute/api_vector.h"
#include "arrow/compute/benchmark_util.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/vector_filter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <vector>

#include "arrow/compute/api.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/table.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/gtest_util.h"
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/vector_hash_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include "arrow/util/decimal.h"

#include "arrow/compute/api.h"
#include "arrow/compute/test_util.h"
#include "arrow/compute/kernels/test_util.h"

#include "arrow/ipc/json_simple.h"

Expand Down
Loading