diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 30c4c737081..031ae5d39c5 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -340,6 +340,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_cast_temporal.cc compute/kernels/scalar_compare.cc compute/kernels/scalar_set_lookup.cc + compute/kernels/scalar_string_ascii.cc compute/kernels/vector_filter.cc compute/kernels/vector_hash.cc compute/kernels/vector_sort.cc diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 351a42f4918..7c19a6fdc23 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -850,7 +850,7 @@ class ScalarEqualsVisitor { template typename std::enable_if::value, Status>::type Visit(const T& left) { - const auto& right = checked_cast(right_); + const auto& right = checked_cast(right_); result_ = internal::SharedPtrEquals(left.value, right.value); return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 361e24b7523..74493a85e18 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -24,7 +24,8 @@ add_arrow_compute_test(scalar_test scalar_boolean_test.cc scalar_cast_test.cc scalar_compare_test.cc - scalar_set_lookup_test.cc) + scalar_set_lookup_test.cc + scalar_string_test.cc) add_arrow_benchmark(scalar_compare_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 2771b6a89f8..5db4c92471e 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -102,6 +102,11 @@ const std::vector>& BaseBinaryTypes() { return g_base_binary_types; } +const std::vector>& StringTypes() { + static DataTypeVector types = {utf8(), large_utf8()}; + return types; +} + const std::vector>& SignedIntTypes() { std::call_once(codegen_static_initialized, 
InitStaticData); return g_signed_int_types; diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index bf504a320fc..512f2a063b3 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -122,20 +122,55 @@ struct UnboxScalar> { }; template -struct GetValueType; +struct GetViewType; template -struct GetValueType> { +struct GetViewType> { using T = typename Type::c_type; }; template -struct GetValueType< +struct GetViewType< Type, enable_if_t::value || is_decimal_type::value || is_fixed_size_binary_type::value>> { using T = util::string_view; }; +template +struct GetOutputType; + +template +struct GetOutputType> { + using T = typename Type::c_type; +}; + +template +struct GetOutputType< + Type, enable_if_t::value>> { + using T = std::string; +}; + +template +struct BoxScalar; + +template +struct BoxScalar> { + using T = typename GetOutputType::T; + using ScalarType = typename TypeTraits::ScalarType; + static std::shared_ptr Box(T val, const std::shared_ptr& type) { + return std::make_shared(val, type); + } +}; + +template +struct BoxScalar> { + using T = typename GetOutputType::T; + using ScalarType = typename TypeTraits::ScalarType; + static std::shared_ptr Box(T val, const std::shared_ptr&) { + return std::make_shared(val); + } +}; + // ---------------------------------------------------------------------- // Reusable type resolvers @@ -154,6 +189,7 @@ void BinaryExecFlipped(KernelContext* ctx, ArrayKernelExec exec, // functions const std::vector>& BaseBinaryTypes(); +const std::vector>& StringTypes(); const std::vector>& SignedIntTypes(); const std::vector>& UnsignedIntTypes(); const std::vector>& IntTypes(); @@ -327,10 +363,8 @@ struct OutputAdapter> { // }; template struct ScalarUnary { - using OutScalar = typename TypeTraits::ScalarType; - - using OUT = typename GetValueType::T; - using ARG0 = typename GetValueType::T; + using OUT = typename 
GetOutputType::T; + using ARG0 = typename GetViewType::T; static void Array(KernelContext* ctx, const ExecBatch& batch, Datum* out) { ArrayIterator arg0(*batch[0].array()); @@ -342,8 +376,9 @@ struct ScalarUnary { static void Scalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (batch[0].scalar()->is_valid) { ARG0 arg0 = UnboxScalar::Unbox(batch[0]); - out->value = std::make_shared(Op::template Call(ctx, arg0), - out->type()); + out->value = BoxScalar::Box( + Op::template Call(ctx, arg0), + out->type()); } else { out->value = MakeNullScalar(batch[0].type()); } @@ -363,9 +398,8 @@ struct ScalarUnary { template struct ScalarUnaryNotNullStateful { using ThisType = ScalarUnaryNotNullStateful; - using OutScalar = typename TypeTraits::ScalarType; - using OUT = typename GetValueType::T; - using ARG0 = typename GetValueType::T; + using OUT = typename GetOutputType::T; + using ARG0 = typename GetViewType::T; Op op; ScalarUnaryNotNullStateful(Op op) : op(std::move(op)) {} @@ -394,6 +428,30 @@ struct ScalarUnaryNotNullStateful { } }; + template + struct ArrayExec> { + static void Exec(const ThisType& functor, KernelContext* ctx, const ExecBatch& batch, + Datum* out) { + typename TypeTraits::BuilderType builder; + Status s = VisitArrayDataInline( + *batch[0].array(), [&](util::optional v) -> Status { + if (v.has_value()) { + return builder.Append(functor.op.Call(ctx, *v)); + } else { + return builder.AppendNull(); + } + }); + if (!s.ok()) { + ctx->SetStatus(s); + return; + } else { + std::shared_ptr result; + ctx->SetStatus(builder.FinishInternal(&result)); + out->value = std::move(result); + } + } + }; + template struct ArrayExec::value>> { static void Exec(const ThisType& functor, KernelContext* ctx, const ExecBatch& batch, @@ -416,7 +474,7 @@ struct ScalarUnaryNotNullStateful { void Scalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (batch[0].scalar()->is_valid) { ARG0 arg0 = UnboxScalar::Unbox(batch[0]); - out->value = std::make_shared( + 
out->value = BoxScalar::Box( this->op.template Call(ctx, arg0), out->type()); } else { @@ -438,6 +496,9 @@ struct ScalarUnaryNotNullStateful { // operator requires some initialization use ScalarUnaryNotNullStateful template struct ScalarUnaryNotNull { + using OUT = typename GetOutputType::T; + using ARG0 = typename GetViewType::T; + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // Seed kernel with dummy state ScalarUnaryNotNullStateful kernel({}); @@ -464,11 +525,9 @@ struct ScalarUnaryNotNull { template struct ScalarBinary { - using OutScalarType = typename TypeTraits::ScalarType; - - using OUT = typename GetValueType::T; - using ARG0 = typename GetValueType::T; - using ARG1 = typename GetValueType::T; + using OUT = typename GetOutputType::T; + using ARG0 = typename GetViewType::T; + using ARG1 = typename GetViewType::T; template static void ArrayArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { @@ -492,7 +551,8 @@ struct ScalarBinary { static void ScalarScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) { auto arg0 = UnboxScalar::Unbox(batch[0]); auto arg1 = UnboxScalar::Unbox(batch[1]); - out->value = std::make_shared(ChosenOp::template Call(ctx, arg0, arg1)); + out->value = BoxScalar::Box(ChosenOp::template Call(ctx, arg0, arg1), + out->type()); } static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index 956d9e5c6b9..502fba2980e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -42,7 +42,7 @@ struct SetLookupState : public KernelState { : lookup_table(pool, 0), lookup_null_count(0) {} Status Init(const SetLookupOptions& options) { - using T = typename GetValueType::T; + using T = typename GetViewType::T; auto insert_value = [&](util::optional v) { if (v.has_value()) { int32_t unused_memo_index; @@ 
-147,7 +147,7 @@ struct MatchVisitor { template enable_if_supports_set_lookup Visit(const Type&) { - using T = typename GetValueType::T; + using T = typename GetViewType::T; const auto& state = checked_cast&>(*ctx->state()); @@ -222,7 +222,7 @@ struct IsInVisitor { template enable_if_supports_set_lookup Visit(const Type&) { - using T = typename GetValueType::T; + using T = typename GetViewType::T; const auto& state = checked_cast&>(*ctx->state()); ArrayData* output = out->mutable_array(); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc new file mode 100644 index 00000000000..19eaf84016f --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include + +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_string_internal.h" + +namespace arrow { +namespace compute { +namespace internal { + +// TODO: optional ascii validation + +struct AsciiLength { + template + static OUT Call(KernelContext*, ARG0 val) { + return static_cast(val.size()); + } +}; + +struct AsciiUpper { + // XXX: the Scalar codegen path passes template arguments that are unused + template + static std::string Call(KernelContext*, const util::string_view& val) { + // Locale-independent ASCII upper-casing: only bytes 'a'..'z' are changed + std::string result = val.to_string(); + for (char& c : result) c = (c >= 'a' && c <= 'z') ? static_cast<char>(c - 32) : c; + return result; + } +}; + +void AddAsciiLength(FunctionRegistry* registry) { + auto func = std::make_shared("ascii_length", Arity::Unary()); + ArrayKernelExec exec_offset_32 = + codegen::ScalarUnaryNotNull::Exec; + ArrayKernelExec exec_offset_64 = + codegen::ScalarUnaryNotNull::Exec; + DCHECK_OK(func->AddKernel({utf8()}, int32(), exec_offset_32)); + DCHECK_OK(func->AddKernel({large_utf8()}, int64(), exec_offset_64)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +void RegisterScalarStringAscii(FunctionRegistry* registry) { + MakeUnaryStringToString("ascii_upper", registry); + AddAsciiLength(registry); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_string_internal.h b/cpp/src/arrow/compute/kernels/scalar_string_internal.h new file mode 100644 index 00000000000..dc71a044273 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_string_internal.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/array/builder_binary.h" +#include "arrow/compute/kernels/common.h" + +namespace arrow { +namespace compute { +namespace internal { + +// Apply a scalar function to each string and yield same output type +// Registers `name` with one utf8 -> utf8 and one large_utf8 -> large_utf8 kernel +template +void MakeUnaryStringToString(std::string name, FunctionRegistry* registry) { + auto func = std::make_shared(std::move(name), Arity::Unary()); + ArrayKernelExec exec_offset_32 = + codegen::ScalarUnaryNotNull::Exec; + ArrayKernelExec exec_offset_64 = + codegen::ScalarUnaryNotNull::Exec; + DCHECK_OK(func->AddKernel({utf8()}, utf8(), exec_offset_32)); + DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), exec_offset_64)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc new file mode 100644 index 00000000000..fba9a21e786 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "arrow/compute/api_scalar.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { +namespace compute { + +typedef ::testing::Types StringTypes; + +template +class TestStringKernels : public ::testing::Test { + protected: + using OffsetType = typename TypeTraits::OffsetType; + + void CheckUnary(std::string func_name, std::string json_input, + std::shared_ptr out_ty, std::string json_expected) { + auto input = ArrayFromJSON(string_type(), json_input); + auto expected = ArrayFromJSON(out_ty, json_expected); + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input})); + AssertArraysEqual(*expected, *out.make_array(), /*verbose=*/true); + + // Call the function again on each element as a Scalar and compare with expected + for (int64_t i = 0; i < input->length(); ++i) { + ASSERT_OK_AND_ASSIGN(auto val, input->GetScalar(i)); + ASSERT_OK_AND_ASSIGN(auto ex_val, expected->GetScalar(i)); + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {val})); + AssertScalarsEqual(*ex_val, *out.scalar(), /*verbose=*/true); + } + } + + std::shared_ptr string_type() { + return TypeTraits::type_singleton(); + } + + std::shared_ptr offset_type() { + return TypeTraits::type_singleton(); + } +}; + +TYPED_TEST_SUITE(TestStringKernels, StringTypes); + +TYPED_TEST(TestStringKernels, AsciiLength) { + this->CheckUnary("ascii_length", "[\"aaa\", null, \"\", 
\"b\"]", this->offset_type(), + "[3, null, 0, 1]"); +} + 
+TYPED_TEST(TestStringKernels, AsciiUpper) { + this->CheckUnary("ascii_upper", "[\"aAa&\", null, \"\", \"b\"]", this->string_type(), + "[\"AAA&\", null, \"\", \"B\"]"); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index c7b6099908a..0d3f3e60574 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -102,6 +102,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarBoolean(registry.get()); RegisterScalarComparison(registry.get()); RegisterScalarSetLookup(registry.get()); + RegisterScalarStringAscii(registry.get()); // Aggregate functions RegisterScalarAggregateBasic(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 75e53c793fe..801a5f5cd99 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -29,6 +29,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry); void RegisterScalarBoolean(FunctionRegistry* registry); void RegisterScalarComparison(FunctionRegistry* registry); void RegisterScalarSetLookup(FunctionRegistry* registry); +void RegisterScalarStringAscii(FunctionRegistry* registry); // Vector functions void RegisterVectorFilter(FunctionRegistry* registry); diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index c3e281111ca..cc32d8aa549 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -127,6 +127,9 @@ size_t Scalar::Hash::hash(const Scalar& scalar) { return ScalarHashImpl(scalar). 
StringScalar::StringScalar(std::string s) : StringScalar(Buffer::FromString(std::move(s))) {} +LargeStringScalar::LargeStringScalar(std::string s) + : LargeStringScalar(Buffer::FromString(std::move(s))) {} + FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, std::shared_ptr type) : BinaryScalar(std::move(value), std::move(type)) { @@ -212,6 +215,9 @@ std::shared_ptr MakeNullScalar(std::shared_ptr type) { } std::string Scalar::ToString() const { + if (!this->is_valid) { + return "null"; + } auto maybe_repr = CastTo(utf8()); if (maybe_repr.ok()) { return checked_cast(*maybe_repr.ValueOrDie()).value->ToString(); diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index ac4e3dba053..5caf04d86d8 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -224,6 +224,8 @@ struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { explicit LargeStringScalar(std::shared_ptr value) : LargeStringScalar(std::move(value), large_utf8()) {} + explicit LargeStringScalar(std::string s); + LargeStringScalar() : LargeStringScalar(large_utf8()) {} }; diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 894e3fb60c6..a91fc834a64 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -82,6 +82,22 @@ void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose) } } +void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose) { + std::stringstream diff; + // ARROW-8956, ScalarEquals returns false when both are null + if (!expected.is_valid && !actual.is_valid) { + // We consider both being null to be equal in this function + return; + } + if (!expected.Equals(actual)) { + if (verbose) { + diff << "Expected:\n" << expected.ToString(); + diff << "\nActual:\n" << actual.ToString(); + } + FAIL() << diff.str(); + } +} + void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual, bool check_metadata) { 
AssertTsEqual(expected, actual, check_metadata); diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 846e30dda62..d84db73fdd4 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -167,6 +167,9 @@ struct Datum; // If verbose is true, then the arrays will be pretty printed ARROW_EXPORT void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose = false); +// Two null scalars are considered equal; if verbose is true, mismatches are printed +ARROW_EXPORT void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, + bool verbose = false); ARROW_EXPORT void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual, bool check_metadata = false); diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index f5e32ba1343..f61b690ac9d 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -267,6 +267,7 @@ struct TypeTraits { using ArrayType = StringArray; using BuilderType = StringBuilder; using ScalarType = StringScalar; + using OffsetType = Int32Type; constexpr static bool is_parameter_free = true; static inline std::shared_ptr type_singleton() { return utf8(); } }; @@ -276,6 +277,7 @@ struct TypeTraits { using ArrayType = LargeStringArray; using BuilderType = LargeStringBuilder; using ScalarType = LargeStringScalar; + using OffsetType = Int64Type; constexpr static bool is_parameter_free = true; static inline std::shared_ptr type_singleton() { return large_utf8(); } }; diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index eee193875cc..a2a5e91500e 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -25,6 +25,16 @@ ) +def _simple_unary_function(name): + def func(arg): + return call_function(name, [arg]) + return func + + +ascii_length = _simple_unary_function('ascii_length') +ascii_upper = _simple_unary_function('ascii_upper') + + def sum(array): """ Sum the values in a 
numerical (chunked) array.