From 882e86718324e3f82529a21f7b56ea20044b096d Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 3 Sep 2021 15:11:06 -0400 Subject: [PATCH 01/11] ARROW-13390: [C++] Implement ToString for union scalars --- cpp/src/arrow/scalar.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index adfc50182cb..d054a313373 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -920,6 +920,16 @@ Status CastImpl(const StructScalar& from, StringScalar* to) { return Status::OK(); } +Status CastImpl(const UnionScalar& from, StringScalar* to) { + const auto& union_ty = checked_cast(*from.type); + std::stringstream ss; + ss << '(' << static_cast(from.type_code) << ": " + << union_ty.field(union_ty.child_ids()[from.type_code])->ToString() << " = " + << from.value->ToString() << ')'; + to->value = Buffer::FromString(ss.str()); + return Status::OK(); +} + struct CastImplVisitor { Status NotImplemented() { return Status::NotImplemented("cast to ", *to_type_, " from ", *from_.type); @@ -953,8 +963,6 @@ struct FromTypeVisitor : CastImplVisitor { } Status Visit(const NullType&) { return NotImplemented(); } - Status Visit(const SparseUnionType&) { return NotImplemented(); } - Status Visit(const DenseUnionType&) { return NotImplemented(); } Status Visit(const DictionaryType&) { return NotImplemented(); } Status Visit(const ExtensionType&) { return NotImplemented(); } }; @@ -983,8 +991,6 @@ struct ToTypeVisitor : CastImplVisitor { return Int32Scalar(0).CastTo(dict_type.index_type()).Value(&out.index); } - Status Visit(const SparseUnionType&) { return NotImplemented(); } - Status Visit(const DenseUnionType&) { return NotImplemented(); } Status Visit(const ExtensionType&) { return NotImplemented(); } }; From 46912ee22336609c109455b3f9bbfcd7949b6bb3 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 3 Sep 2021 15:11:28 -0400 Subject: [PATCH 02/11] ARROW-13390: [C++] Implement Coalesce for remaining types --- .../arrow/compute/kernels/scalar_if_else.cc | 243 +++++++++++-- .../compute/kernels/scalar_if_else_test.cc | 325 +++++++++++++++++- 2 files changed, 540 insertions(+), 28 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 35bb6248f23..d7ec9c0b84b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -2018,6 +2018,65 @@ Status ExecBinaryCoalesce(KernelContext* ctx, Datum left, Datum right, int64_t l return Status::OK(); } +template +static Status ExecVarWidthCoalesceImpl(KernelContext* ctx, const ExecBatch& batch, + Datum* out, + std::function reserve_data, + AppendScalar append_scalar) { + // Special case: grab any leading non-null scalar or array arguments + for (const auto& datum : batch.values) { + if (datum.is_scalar()) { + if (!datum.scalar()->is_valid) continue; + ARROW_ASSIGN_OR_RAISE( + *out, MakeArrayFromScalar(*datum.scalar(), batch.length, ctx->memory_pool())); + return Status::OK(); + } else if (datum.is_array() && !datum.array()->MayHaveNulls()) { + *out = datum; + return Status::OK(); + } + break; + } + ArrayData* output = out->mutable_array(); + std::unique_ptr raw_builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(raw_builder->Reserve(batch.length)); + RETURN_NOT_OK(reserve_data(raw_builder.get())); + + for (int64_t i = 0; i < batch.length; i++) { + bool set = false; + for (const auto& datum : batch.values) { + if (datum.is_scalar()) { + if (datum.scalar()->is_valid) { + RETURN_NOT_OK(append_scalar(raw_builder.get(), *datum.scalar())); + set = true; + break; + } + } else { + const ArrayData& source = *datum.array(); + if (!source.MayHaveNulls() || + BitUtil::GetBit(source.buffers[0]->data(), source.offset + i)) { + RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1)); + set = true; + break; + } + } + } + if (!set) RETURN_NOT_OK(raw_builder->AppendNull()); + } + ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish()); + *output = *temp_output->data(); + output->type = batch[0].type(); + return Status::OK(); +} + +static Status ExecVarWidthCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out, + std::function reserve_data) { + return ExecVarWidthCoalesceImpl(ctx, batch, out, std::move(reserve_data), + [](ArrayBuilder* builder, const Scalar& scalar) { + return builder->AppendScalar(scalar); + }); +} + template struct CoalesceFunctor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { @@ -2093,51 +2152,173 @@ struct CoalesceFunctor> { } static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // Special case: grab any leading non-null scalar or array arguments + return ExecVarWidthCoalesceImpl( + ctx, batch, out, + [&](ArrayBuilder* builder) { + int64_t reservation = 0; + for (const auto& datum : batch.values) { + if (datum.is_array()) { + const ArrayType array(datum.array()); + reservation = std::max(reservation, array.total_values_length()); + } else { + const auto& scalar = *datum.scalar(); + if (scalar.is_valid) { + const int64_t size = UnboxScalar::Unbox(scalar).size(); + reservation = std::max(reservation, batch.length * size); + } + } + } + return checked_cast(builder)->ReserveData(reservation); + }, + [&](ArrayBuilder* builder, const Scalar& scalar) { + return checked_cast(builder)->Append( + UnboxScalar::Unbox(scalar)); + }); + } +}; + +template <> +struct CoalesceFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { for (const auto& datum : batch.values) { - if (datum.is_scalar()) { - if (!datum.scalar()->is_valid) continue; - ARROW_ASSIGN_OR_RAISE( - *out, MakeArrayFromScalar(*datum.scalar(), batch.length, ctx->memory_pool())); - return Status::OK(); - } else if (datum.is_array() && !datum.array()->MayHaveNulls()) { - *out = datum; - return Status::OK(); + if (datum.is_array()) { + return ExecArray(ctx, batch, out); + } + } + return ExecScalarCoalesce(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return ExecVarWidthCoalesce(ctx, batch, out, + [&](ArrayBuilder* builder) { return Status::OK(); }); + } +}; + +template +struct CoalesceFunctor> { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (const auto& datum : batch.values) { + if (datum.is_array()) { + return ExecArray(ctx, batch, out); } - break; } + return ExecScalarCoalesce(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return ExecVarWidthCoalesce(ctx, batch, out, + [&](ArrayBuilder* builder) { return Status::OK(); }); + } +}; + +template <> +struct CoalesceFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (const auto& datum : batch.values) { + if (datum.is_array()) { + return ExecArray(ctx, batch, out); + } + } + return ExecScalarCoalesce(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return ExecVarWidthCoalesce(ctx, batch, out, + [&](ArrayBuilder* builder) { return Status::OK(); }); + } +}; + +template <> +struct CoalesceFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (const auto& datum : batch.values) { + if (datum.is_array()) { + return ExecArray(ctx, batch, out); + } + } + return ExecScalarCoalesce(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return ExecVarWidthCoalesce(ctx, batch, out, + [&](ArrayBuilder* builder) { return Status::OK(); }); + } +}; + +template +struct CoalesceFunctor> { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // Unions don't have top-level nulls, so a specialized implementation is needed + for (const auto& datum : batch.values) { + if (datum.is_array()) { + return ExecArray(ctx, batch, out); + } + } + return ExecScalar(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { ArrayData* output = out->mutable_array(); - BuilderType builder(batch[0].type(), ctx->memory_pool()); - RETURN_NOT_OK(builder.Reserve(batch.length)); + std::unique_ptr raw_builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(raw_builder->Reserve(batch.length)); + + // TODO: make sure differing union types are rejected + const UnionType& type = checked_cast(*out->type()); for (int64_t i = 0; i < batch.length; i++) { bool set = false; for (const auto& datum : batch.values) { if (datum.is_scalar()) { - if (datum.scalar()->is_valid) { - RETURN_NOT_OK(builder.Append(UnboxScalar::Unbox(*datum.scalar()))); + const auto& scalar = checked_cast(*datum.scalar()); + if (scalar.is_valid && scalar.value->is_valid) { + RETURN_NOT_OK(raw_builder->AppendScalar(scalar)); set = true; break; } } else { const ArrayData& source = *datum.array(); - if (!source.MayHaveNulls() || - BitUtil::GetBit(source.buffers[0]->data(), source.offset + i)) { - const uint8_t* data = source.buffers[2]->data(); - const offset_type* offsets = source.GetValues(1); - const offset_type offset0 = offsets[i]; - const offset_type offset1 = offsets[i + 1]; - RETURN_NOT_OK(builder.Append(data + offset0, offset1 - offset0)); - set = true; - break; + // Peek at the relevant child array's validity bitmap + if (std::is_same::value) { + const int8_t type_id = source.GetValues(1)[i]; + const int child_id = type.child_ids()[type_id]; + const ArrayData& child = *source.child_data[child_id]; + if (!child.MayHaveNulls() || + BitUtil::GetBit(child.buffers[0]->data(), + source.offset + child.offset + i)) { + RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1)); + set = true; + break; + } + } else { + const int8_t type_id = source.GetValues(1)[i]; + const int32_t offset = source.GetValues(2)[i]; + const int child_id = type.child_ids()[type_id]; + const ArrayData& child = *source.child_data[child_id]; + if (!child.MayHaveNulls() || + BitUtil::GetBit(child.buffers[0]->data(), child.offset + offset)) { + RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1)); + set = true; + break; + } } } } - if (!set) RETURN_NOT_OK(builder.AppendNull()); + if (!set) RETURN_NOT_OK(raw_builder->AppendNull()); } - ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish()); + ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish()); *output = *temp_output->data(); - // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase - output->type = batch[0].type(); + return Status::OK(); + } + + static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (const auto& datum : batch.values) { + const auto& scalar = checked_cast(*datum.scalar()); + // Union scalars can have top-level validity + if (scalar.is_valid && scalar.value->is_valid) { + *out = datum; + break; + } + } return Status::OK(); } }; @@ -2511,6 +2692,14 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { for (const auto& ty : BaseBinaryTypes()) { AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase(ty)); } + AddCoalesceKernel(func, Type::FIXED_SIZE_LIST, + CoalesceFunctor::Exec); + AddCoalesceKernel(func, Type::LIST, CoalesceFunctor::Exec); + AddCoalesceKernel(func, Type::LARGE_LIST, CoalesceFunctor::Exec); + AddCoalesceKernel(func, Type::MAP, CoalesceFunctor::Exec); + AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor::Exec); + AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor::Exec); + AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor::Exec); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 8793cac7619..376c056ca87 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1688,11 +1688,14 @@ template class TestCoalesceNumeric : public ::testing::Test {}; template class TestCoalesceBinary : public ::testing::Test {}; +template +class TestCoalesceList : public ::testing::Test {}; TYPED_TEST_SUITE(TestCoalesceNumeric, NumericBasedTypes); TYPED_TEST_SUITE(TestCoalesceBinary, BinaryArrowTypes); +TYPED_TEST_SUITE(TestCoalesceList, ListArrowTypes); -TYPED_TEST(TestCoalesceNumeric, FixedSize) { +TYPED_TEST(TestCoalesceNumeric, Basics) { auto type = default_type_instance(); auto scalar_null = ScalarFromJSON(type, "null"); auto scalar1 = ScalarFromJSON(type, "20"); @@ -1718,6 +1721,34 @@ TYPED_TEST(TestCoalesceNumeric, FixedSize) { CheckScalar("coalesce", {scalar1, values1}, ArrayFromJSON(type, "[20, 20, 20, 20]")); } +TYPED_TEST(TestCoalesceNumeric, ListOfType) { + auto type = list(default_type_instance()); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "[20, 24]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[null, [10, null, 20], [], [null, null]]"); + auto values2 = ArrayFromJSON(type, "[[23], [14, 24], [null, 15], [16]]"); + auto values3 = ArrayFromJSON(type, "[[17, 18], [19], [], null]"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar("coalesce", {values_null, scalar1}, + ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values_null, values2}, values2); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {values2, values_null}, values2); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + CheckScalar("coalesce", {values2, values1, values_null}, values2); + CheckScalar("coalesce", {values1, scalar1}, + ArrayFromJSON(type, "[[20, 24], [10, null, 20], [], [null, null]]")); + CheckScalar("coalesce", {values1, values2}, + ArrayFromJSON(type, "[[23], [10, null, 20], [], [null, null]]")); + CheckScalar("coalesce", {values1, values2, values3}, + ArrayFromJSON(type, "[[23], [10, null, 20], [], [null, null]]")); + CheckScalar("coalesce", {scalar1, values1}, + ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]")); +} + TYPED_TEST(TestCoalesceBinary, Basics) { auto type = default_type_instance(); auto scalar_null = ScalarFromJSON(type, "null"); @@ -1747,6 +1778,140 @@ TYPED_TEST(TestCoalesceBinary, Basics) { ArrayFromJSON(type, R"(["a", "a", "a", "a"])")); } +TYPED_TEST(TestCoalesceList, ListOfString) { + auto type = std::make_shared(utf8()); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([null, "a"])"); + auto values_null = ArrayFromJSON(type, R"([null, null, null, null])"); + auto values1 = ArrayFromJSON(type, R"([null, ["bc", null], ["def"], []])"); + auto values2 = ArrayFromJSON(type, R"([["klmno"], ["p"], ["qr", null], ["stu"]])"); + auto values3 = ArrayFromJSON(type, R"([["vwxy"], [], ["d"], null])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar( + "coalesce", {values_null, scalar1}, + ArrayFromJSON(type, R"([[null, "a"], [null, "a"], [null, "a"], [null, "a"]])")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values_null, values2}, values2); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {values2, values_null}, values2); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + CheckScalar("coalesce", {values2, values1, values_null}, values2); + CheckScalar("coalesce", {values1, scalar1}, + ArrayFromJSON(type, R"([[null, "a"], ["bc", null], ["def"], []])")); + CheckScalar("coalesce", {values1, values2}, + ArrayFromJSON(type, R"([["klmno"], ["bc", null], ["def"], []])")); + CheckScalar("coalesce", {values1, values2, values3}, + ArrayFromJSON(type, R"([["klmno"], ["bc", null], ["def"], []])")); + CheckScalar( + "coalesce", {scalar1, values1}, + ArrayFromJSON(type, R"([[null, "a"], [null, "a"], [null, "a"], [null, "a"]])")); +} + +// More minimal tests to check type coverage +TYPED_TEST(TestCoalesceList, ListOfBool) { + auto type = std::make_shared(boolean()); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "[true, false, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[null, [true, null, true], [], [null, null]]"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar("coalesce", {values_null, scalar1}, + ArrayFromJSON(type, + "[[true, false, null], [true, false, null], [true, false, " + "null], [true, false, null]]")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); +} + +TYPED_TEST(TestCoalesceList, ListOfInt) { + auto type = std::make_shared(int64()); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "[20, 24]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[null, [10, null, 20], [], [null, null]]"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar("coalesce", {values_null, scalar1}, + ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); +} + +TYPED_TEST(TestCoalesceList, ListOfDayTimeInterval) { + auto type = std::make_shared(day_time_interval()); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "[[20, 24], null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, "[null, [[10, 12], null, [20, 22]], [], [null, null]]"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar( + "coalesce", {values_null, scalar1}, + ArrayFromJSON( + type, + "[[[20, 24], null], [[20, 24], null], [[20, 24], null], [[20, 24], null]]")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); +} + +TYPED_TEST(TestCoalesceList, ListOfDecimal) { + for (auto ty : {decimal128(3, 2), decimal256(3, 2)}) { + auto type = std::make_shared(ty); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"(["0.42", null])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([null, ["1.23"], [], [null, null]])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar( + "coalesce", {values_null, scalar1}, + ArrayFromJSON( + type, R"([["0.42", null], ["0.42", null], ["0.42", null], ["0.42", null]])")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + } +} + +TYPED_TEST(TestCoalesceList, ListOfFixedSizeBinary) { + auto type = std::make_shared(fixed_size_binary(3)); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"(["ab!", null])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([null, ["def"], [], [null, null]])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar( + "coalesce", {values_null, scalar1}, + ArrayFromJSON(type, + R"([["ab!", null], ["ab!", null], ["ab!", null], ["ab!", null]])")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); +} + +TYPED_TEST(TestCoalesceList, ListOfListOfInt) { + auto type = std::make_shared(std::make_shared(int64())); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "[[20], null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[null, [[10, 12], null, []], [], [null, null]]"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar( + "coalesce", {values_null, scalar1}, + ArrayFromJSON(type, "[[[20], null], [[20], null], [[20], null], [[20], null]]")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); +} + TEST(TestCoalesce, Null) { auto type = null(); auto scalar_null = ScalarFromJSON(type, "null"); @@ -1870,6 +2035,164 @@ TEST(TestCoalesce, FixedSizeBinary) { ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])")); } +TEST(TestCoalesce, FixedSizeListOfInt) { + auto type = fixed_size_list(uint8(), 2); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([42, null])"); + auto values_null = ArrayFromJSON(type, R"([null, null, null, null])"); + auto values1 = ArrayFromJSON(type, R"([null, [2, null], [4, 8], [null, null]])"); + auto values2 = ArrayFromJSON(type, R"([[1, 5], [16, 32], [64, null], [null, 128]])"); + auto values3 = ArrayFromJSON(type, R"([[null, null], [1, 3], [9, 27], null])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar("coalesce", {values_null, scalar1}, + ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values_null, values2}, values2); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {values2, values_null}, values2); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + CheckScalar("coalesce", {values2, values1, values_null}, values2); + CheckScalar("coalesce", {values1, scalar1}, + ArrayFromJSON(type, R"([[42, null], [2, null], [4, 8], [null, null]])")); + CheckScalar("coalesce", {values1, values2}, + ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])")); + CheckScalar("coalesce", {values1, values2, values3}, + ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])")); + CheckScalar("coalesce", {scalar1, values1}, + ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])")); +} + +TEST(TestCoalesce, FixedSizeListOfString) { + auto type = fixed_size_list(utf8(), 2); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"(["abc", null])"); + auto values_null = ArrayFromJSON(type, R"([null, null, null, null])"); + auto values1 = + ArrayFromJSON(type, R"([null, ["d", null], ["ghi", "jkl"], [null, null]])"); + auto values2 = ArrayFromJSON( + type, R"([["mno", "pq"], ["pqr", "ab"], ["stu", null], [null, "vwx"]])"); + auto values3 = + ArrayFromJSON(type, R"([[null, null], ["a", "bcd"], ["d", "efg"], null])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar( + "coalesce", {values_null, scalar1}, + ArrayFromJSON(type, + R"([["abc", null], ["abc", null], ["abc", null], ["abc", null]])")); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values_null, values2}, values2); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {values2, values_null}, values2); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + CheckScalar("coalesce", {values2, values1, values_null}, values2); + CheckScalar("coalesce", {values1, scalar1}, + ArrayFromJSON( + type, R"([["abc", null], ["d", null], ["ghi", "jkl"], [null, null]])")); + CheckScalar("coalesce", {values1, values2}, + ArrayFromJSON( + type, R"([["mno", "pq"], ["d", null], ["ghi", "jkl"], [null, null]])")); + CheckScalar("coalesce", {values1, values2, values3}, + ArrayFromJSON( + type, R"([["mno", "pq"], ["d", null], ["ghi", "jkl"], [null, null]])")); + CheckScalar( + "coalesce", {scalar1, values1}, + ArrayFromJSON(type, + R"([["abc", null], ["abc", null], ["abc", null], ["abc", null]])")); +} + +TEST(TestCoalesce, Map) { + auto type = map(int64(), utf8()); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([[1, "a"], [5, "bc"]])"); + auto values_null = ArrayFromJSON(type, R"([null, null, null, null])"); + auto values1 = + ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[3, "test"]], []])"); + auto values2 = ArrayFromJSON( + type, R"([[[1, "b"]], [[2, "c"]], [[5, "c"], [6, "d"]], [[7, "abc"]]])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values_null, values2}, values2); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {values2, values_null}, values2); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + CheckScalar("coalesce", {values2, values1, values_null}, values2); + CheckScalar( + "coalesce", {values1, scalar1}, + ArrayFromJSON( + type, + R"([[[1, "a"], [5, "bc"]], [[2, "foo"], [4, null]], [[3, "test"]], []])")); + CheckScalar( + "coalesce", {values1, values2}, + ArrayFromJSON(type, R"([[[1, "b"]], [[2, "foo"], [4, null]], [[3, "test"]], []])")); + CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4)); +} + +TEST(TestCoalesce, Struct) { + auto type = struct_( + {field("int", uint32()), field("str", utf8()), field("list", list(int8()))}); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([42, "spam", [null, -1]])"); + auto values_null = ArrayFromJSON(type, R"([null, null, null, null])"); + auto values1 = ArrayFromJSON( + type, R"([null, [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])"); + auto values2 = ArrayFromJSON( + type, + R"([[21, "foobar", [1, null, 2]], [5, "bar", []], [20, null, null], [1, "", [null]]])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values_null, values2}, values2); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {values2, values_null}, values2); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + CheckScalar("coalesce", {values2, values1, values_null}, values2); + CheckScalar( + "coalesce", {values1, scalar1}, + ArrayFromJSON( + type, + R"([[42, "spam", [null, -1]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])")); + CheckScalar( + "coalesce", {values1, values2}, + ArrayFromJSON( + type, + R"([[21, "foobar", [1, null, 2]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])")); + CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4)); +} + +TEST(TestCoalesce, UnionBoolString) { + for (const auto& type : { + sparse_union({field("a", boolean()), field("b", utf8())}, {2, 7}), + dense_union({field("a", boolean()), field("b", utf8())}, {2, 7}), + }) { + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([7, "foo"])"); + auto values_null = ArrayFromJSON(type, R"([null, null, null, null])"); + auto values1 = ArrayFromJSON(type, R"([null, [2, false], [7, "bar"], [7, "baz"]])"); + auto values2 = + ArrayFromJSON(type, R"([[2, true], [2, false], [7, "foo"], [7, "bar"]])"); + CheckScalar("coalesce", {values_null}, values_null); + CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("coalesce", {values_null, values1}, values1); + CheckScalar("coalesce", {values_null, values2}, values2); + CheckScalar("coalesce", {values1, values_null}, values1); + CheckScalar("coalesce", {values2, values_null}, values2); + CheckScalar("coalesce", {scalar_null, values1}, values1); + CheckScalar("coalesce", {values1, scalar_null}, values1); + CheckScalar("coalesce", {values2, values1, values_null}, values2); + CheckScalar( + "coalesce", {values1, scalar1}, + ArrayFromJSON(type, R"([[7, "foo"], [2, false], [7, "bar"], [7, "baz"]])")); + CheckScalar( + "coalesce", {values1, values2}, + ArrayFromJSON(type, R"([[2, true], [2, false], [7, "bar"], [7, "baz"]])")); + CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4)); + } +} + template class TestChooseNumeric : public ::testing::Test {}; template From f744fcac91a51eeb8106c40c84e6ad65cb7197c2 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 14:41:44 -0400 Subject: [PATCH 03/11] ARROW-13390: [Ruby] Update scalar to_s tests --- c_glib/test/test-dense-union-scalar.rb | 2 +- c_glib/test/test-sparse-union-scalar.rb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/c_glib/test/test-dense-union-scalar.rb b/c_glib/test/test-dense-union-scalar.rb index ec2053b3fe9..b6173d98232 100644 --- a/c_glib/test/test-dense-union-scalar.rb +++ b/c_glib/test/test-dense-union-scalar.rb @@ -49,7 +49,7 @@ def test_equal end def test_to_s - assert_equal("...", @scalar.to_s) + assert_equal("(2: number: int8 = -29)", @scalar.to_s) end def test_value diff --git a/c_glib/test/test-sparse-union-scalar.rb b/c_glib/test/test-sparse-union-scalar.rb index acb8531560b..1f26d8df869 100644 --- a/c_glib/test/test-sparse-union-scalar.rb +++ b/c_glib/test/test-sparse-union-scalar.rb @@ -49,7 +49,7 @@ def test_equal end def test_to_s - assert_equal("...", @scalar.to_s) + assert_equal("(2: number: int8 = -29)", @scalar.to_s) end def test_value From 754e0ce9b160fb7d7176933a966a5d45f5ae9bb0 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 17:44:34 -0400 Subject: [PATCH 04/11] ARROW-13390: [C++] Try to satisfy MinGW --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index d7ec9c0b84b..23ac666931e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -2189,8 +2189,8 @@ struct CoalesceFunctor { } static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - return ExecVarWidthCoalesce(ctx, batch, out, - [&](ArrayBuilder* builder) { return Status::OK(); }); + std::function reserve_data = ReserveNoData; + return ExecVarWidthCoalesce(ctx, batch, out, reserve_data); } }; @@ -2206,8 +2206,8 @@ struct CoalesceFunctor> { } static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - return ExecVarWidthCoalesce(ctx, batch, out, - [&](ArrayBuilder* builder) { return Status::OK(); }); + std::function reserve_data = ReserveNoData; + return ExecVarWidthCoalesce(ctx, batch, out, reserve_data); } }; @@ -2223,8 +2223,8 @@ struct CoalesceFunctor { } static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - return ExecVarWidthCoalesce(ctx, batch, out, - [&](ArrayBuilder* builder) { return Status::OK(); }); + std::function reserve_data = ReserveNoData; + return ExecVarWidthCoalesce(ctx, batch, out, reserve_data); } }; @@ -2240,8 +2240,8 @@ struct CoalesceFunctor { } static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - return ExecVarWidthCoalesce(ctx, batch, out, - [&](ArrayBuilder* builder) { return Status::OK(); }); + std::function reserve_data = ReserveNoData; + return ExecVarWidthCoalesce(ctx, batch, out, reserve_data); } }; From b51c2510c4465dbdef911ad2e9760c9fe578aff4 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 16 Sep 2021 16:46:55 -0400 Subject: [PATCH 05/11] ARROW-13390: [C++] Address feedback, add dispatch tests --- c_glib/test/test-dense-union-scalar.rb | 2 +- c_glib/test/test-sparse-union-scalar.rb | 2 +- cpp/src/arrow/compute/kernels/CMakeLists.txt | 7 ++ .../arrow/compute/kernels/codegen_internal.cc | 78 ++++++++++++++-- .../arrow/compute/kernels/codegen_internal.h | 11 +++ .../compute/kernels/codegen_internal_test.cc | 86 ++++++++++++++++++ .../arrow/compute/kernels/scalar_if_else.cc | 33 ++++++- .../compute/kernels/scalar_if_else_test.cc | 88 +++++++++++++++++++ cpp/src/arrow/scalar.cc | 5 +- 9 files changed, 299 insertions(+), 13 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/codegen_internal_test.cc diff --git a/c_glib/test/test-dense-union-scalar.rb b/c_glib/test/test-dense-union-scalar.rb index b6173d98232..4a3e5c0dee7 100644 --- a/c_glib/test/test-dense-union-scalar.rb +++ b/c_glib/test/test-dense-union-scalar.rb @@ -49,7 +49,7 @@ def test_equal end def test_to_s - assert_equal("(2: number: int8 = -29)", @scalar.to_s) + assert_equal("union{number: int8 = -29}", @scalar.to_s) end def test_value diff --git a/c_glib/test/test-sparse-union-scalar.rb b/c_glib/test/test-sparse-union-scalar.rb index 1f26d8df869..a7f1b06953e 100644 --- a/c_glib/test/test-sparse-union-scalar.rb +++ b/c_glib/test/test-sparse-union-scalar.rb @@ -49,7 +49,7 @@ def test_equal end def test_to_s - assert_equal("(2: number: int8 = -29)", @scalar.to_s) + assert_equal("union{number: int8 = -29}", @scalar.to_s) end def test_value diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index ce7a85f1557..cd6bc19c869 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -71,3 +71,10 @@ add_arrow_compute_test(aggregate_test hash_aggregate_test.cc test_util.cc) add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute") + +# ---------------------------------------------------------------------- +# Utilities + +add_arrow_compute_test(kernel_utility_test + SOURCES + codegen_internal_test.cc) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index fe4b593b481..b8f3d165ae2 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -95,8 +95,14 @@ void ReplaceNullWithOtherType(std::vector* descrs) { void ReplaceTypes(const std::shared_ptr& type, std::vector* descrs) { - for (auto& descr : *descrs) { - descr.type = type; + ReplaceTypes(type, descrs->data(), descrs->size()); +} + +void ReplaceTypes(const std::shared_ptr& type, ValueDescr* begin, + size_t count) { + auto* end = begin + count; + for (auto* it = begin; it != end; it++) { + it->type = type; } } @@ -160,6 +166,7 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { std::shared_ptr CommonTimestamp(const std::vector& descrs) { TimeUnit::type finest_unit = TimeUnit::SECOND; + const std::string* timezone = nullptr; for (const auto& descr : descrs) { auto id = descr.type->id(); @@ -168,16 +175,22 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs) case Type::DATE32: case Type::DATE64: continue; - case Type::TIMESTAMP: - finest_unit = - std::max(finest_unit, checked_cast(*descr.type).unit()); + case Type::TIMESTAMP: { + const auto& ty = checked_cast(*descr.type); + // Don't cast to common timezone by default (may not make + // sense for all kernels) + if (timezone && *timezone != ty.timezone()) return nullptr; + timezone = &ty.timezone(); + finest_unit = std::max(finest_unit, ty.unit()); continue; + } default: return nullptr; } } - return timestamp(finest_unit); + // Don't cast if we see only dates, no timestamps + return timezone ? timestamp(finest_unit, *timezone) : nullptr; } std::shared_ptr CommonBinary(const std::vector& descrs) { @@ -290,6 +303,59 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion, return Status::OK(); } +Status CastDecimalArgs(ValueDescr* begin, size_t count) { + Type::type casted_type_id = Type::DECIMAL128; + auto* end = begin + count; + + int32_t max_scale = 0; + for (auto* it = begin; it != end; ++it) { + const auto& ty = *it->type; + if (is_floating(ty.id())) { + // Decimal + float = float + ReplaceTypes(float64(), begin, count); + return Status::OK(); + } else if (is_integer(ty.id())) { + // Nothing to do here + } else if (is_decimal(ty.id())) { + max_scale = std::max(max_scale, checked_cast(ty).scale()); + if (ty.id() == Type::DECIMAL256) { + casted_type_id = Type::DECIMAL256; + } + } else { + // Non-numeric, can't cast + return Status::OK(); + } + } + + // All integer and decimal, rescale + int32_t common_precision = 0; + for (auto* it = begin; it != end; ++it) { + const auto& ty = *it->type; + if (is_integer(ty.id())) { + ARROW_ASSIGN_OR_RAISE(auto precision, MaxDecimalDigitsForInteger(ty.id())); + precision += max_scale; + common_precision = std::max(common_precision, precision); + } else if (is_decimal(ty.id())) { + const auto& decimal_ty = checked_cast(ty); + auto precision = decimal_ty.precision(); + const auto scale = decimal_ty.scale(); + precision += max_scale - scale; + common_precision = std::max(common_precision, precision); + } + } + + if (common_precision > BasicDecimal128::kMaxPrecision) { + casted_type_id = Type::DECIMAL256; + } + + for (auto* it = begin; it != end; ++it) { + ARROW_ASSIGN_OR_RAISE(it->type, + DecimalType::Make(casted_type_id, common_precision, max_scale)); + } + + return Status::OK(); +} + bool HasDecimal(const std::vector& descrs) { for (const auto& descr : descrs) { if (is_decimal(descr.type->id())) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index f9ce34b06e0..0648ac61e92 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1279,6 +1279,9 @@ void ReplaceNullWithOtherType(std::vector* descrs); ARROW_EXPORT void ReplaceTypes(const std::shared_ptr&, std::vector* descrs); +ARROW_EXPORT +void ReplaceTypes(const std::shared_ptr&, ValueDescr* descrs, size_t count); + ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); @@ -1298,9 +1301,17 @@ enum class DecimalPromotion : uint8_t { kDivide, }; +/// Given two arguments, at least one of which is decimal, promote all +/// to not necessarily identical types, but types which are compatible +/// for the given operator (add/multiply/divide). ARROW_EXPORT Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector* descrs); +/// Given one or more arguments, at least one of which is decimal, +/// promote all to an identical type. +ARROW_EXPORT +Status CastDecimalArgs(ValueDescr* begin, size_t count); + ARROW_EXPORT bool HasDecimal(const std::vector& descrs); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc new file mode 100644 index 00000000000..9ba5377de08 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace compute { +namespace internal { + +TEST(TestDispatchBest, CastDecimalArgs) { + std::vector args; + + // Any float -> all float + args = {decimal128(3, 2), float64()}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(*args[0].type, *float64()); + AssertTypeEqual(*args[1].type, *float64()); + + // Promote to common decimal width + args = {decimal128(3, 2), decimal256(3, 2)}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(*args[0].type, *decimal256(3, 2)); + AssertTypeEqual(*args[1].type, *decimal256(3, 2)); + + // Rescale so all have common scale/precision + args = {decimal128(3, 2), decimal128(3, 0)}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(*args[0].type, *decimal128(5, 2)); + AssertTypeEqual(*args[1].type, *decimal128(5, 2)); + + // Integer -> decimal with appropriate precision + args = {decimal128(3, 0), int64()}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(*args[0].type, *decimal128(19, 0)); + AssertTypeEqual(*args[1].type, *decimal128(19, 0)); + + args = {decimal128(3, 1), int64()}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(*args[0].type, *decimal128(20, 1)); + AssertTypeEqual(*args[1].type, *decimal128(20, 1)); + + // Overflow decimal128 max precision -> promote to decimal256 + args = {decimal128(38, 0), decimal128(37, 2)}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(*args[0].type, *decimal256(40, 2)); + AssertTypeEqual(*args[1].type, *decimal256(40, 2)); +} + +TEST(TestDispatchBest, CommonTimestamp) { + AssertTypeEqual( + timestamp(TimeUnit::NANO), + CommonTimestamp({timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)})); + AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"), + CommonTimestamp({timestamp(TimeUnit::SECOND, "UTC"), + timestamp(TimeUnit::NANO, "UTC")})); + AssertTypeEqual(timestamp(TimeUnit::NANO), + CommonTimestamp({date32(), timestamp(TimeUnit::NANO)})); + ASSERT_EQ(nullptr, CommonTimestamp({date32(), date64()})); + ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND, "UTC")})); + ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND, "America/Phoenix"), + timestamp(TimeUnit::SECOND, "UTC")})); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 23ac666931e..a5ff191a7b4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1737,11 +1737,20 @@ struct CoalesceFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); using arrow::compute::detail::DispatchExactImpl; - if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + // Do not DispatchExact here since we want to rescale decimals if necessary EnsureDictionaryDecoded(values); if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values); } + if (auto type = CommonBinary(*values)) { + ReplaceTypes(type, values); + } + if (auto type = CommonTimestamp(*values)) { + ReplaceTypes(type, values); + } + if (HasDecimal(*values)) { + RETURN_NOT_OK(CastDecimalArgs(values->data(), values->size())); + } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *values); } @@ -2077,9 +2086,24 @@ static Status ExecVarWidthCoalesce(KernelContext* ctx, const ExecBatch& batch, D }); } +// Ensure parameterized types are identical. +static Status CheckIdenticalTypes(const ExecBatch& batch) { + auto ty = batch[0].type(); + for (auto it = batch.values.begin() + 1; it != batch.values.end(); ++it) { + if (!ty->Equals(*it->type())) { + return Status::TypeError("coalesce: all types must be identical, expected: ", *ty, + ", but got: ", *it->type()); + } + } + return Status::OK(); +} + template struct CoalesceFunctor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (!TypeTraits::is_parameter_free) { + RETURN_NOT_OK(CheckIdenticalTypes(batch)); + } // Special case for two arguments (since "fill_null" is a common operation) if (batch.num_values() == 2) { return ExecBinaryCoalesce(ctx, batch[0], batch[1], batch.length, out); @@ -2180,6 +2204,7 @@ struct CoalesceFunctor> { template <> struct CoalesceFunctor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + RETURN_NOT_OK(CheckIdenticalTypes(batch)); for (const auto& datum : batch.values) { if (datum.is_array()) { return ExecArray(ctx, batch, out); @@ -2197,6 +2222,7 @@ struct CoalesceFunctor { template struct CoalesceFunctor> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + RETURN_NOT_OK(CheckIdenticalTypes(batch)); for (const auto& datum : batch.values) { if (datum.is_array()) { return ExecArray(ctx, batch, out); @@ -2214,6 +2240,7 @@ struct CoalesceFunctor> { template <> struct CoalesceFunctor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + RETURN_NOT_OK(CheckIdenticalTypes(batch)); for (const auto& datum : batch.values) { if (datum.is_array()) { return ExecArray(ctx, batch, out); @@ -2231,6 +2258,7 @@ struct CoalesceFunctor { template <> struct CoalesceFunctor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + RETURN_NOT_OK(CheckIdenticalTypes(batch)); for (const auto& datum : batch.values) { if (datum.is_array()) { return ExecArray(ctx, batch, out); @@ -2249,6 +2277,8 @@ template struct CoalesceFunctor> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // Unions don't have top-level nulls, so a specialized implementation is needed + RETURN_NOT_OK(CheckIdenticalTypes(batch)); + for (const auto& datum : batch.values) { if (datum.is_array()) { return ExecArray(ctx, batch, out); @@ -2263,7 +2293,6 @@ struct CoalesceFunctor> { RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); RETURN_NOT_OK(raw_builder->Reserve(batch.length)); - // TODO: make sure differing union types are rejected const UnionType& type = checked_cast(*out->type()); for (int64_t i = 0; i < batch.length; i++) { bool set = false; diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 376c056ca87..84c0404c049 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1912,6 +1912,17 @@ TYPED_TEST(TestCoalesceList, ListOfListOfInt) { CheckScalar("coalesce", {values1, scalar_null}, values1); } +TYPED_TEST(TestCoalesceList, Errors) { + auto type1 = std::make_shared(int64()); + auto type2 = std::make_shared(utf8()); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, ::testing::HasSubstr("coalesce: all types must be identical"), + CallFunction("coalesce", { + ArrayFromJSON(type1, "[null]"), + ArrayFromJSON(type2, "[null]"), + })); +} + TEST(TestCoalesce, Null) { auto type = null(); auto scalar_null = ScalarFromJSON(type, "null"); @@ -2005,6 +2016,13 @@ TEST(TestCoalesce, Decimal) { CheckScalar("coalesce", {scalar1, values1}, ArrayFromJSON(type, R"(["1.23", "1.23", "1.23", "1.23"])")); } + // Ensure promotion + CheckScalar("coalesce", + { + ArrayFromJSON(decimal128(3, 2), R"(["1.23", null])"), + ArrayFromJSON(decimal128(4, 1), R"([null, "1.0"])"), + }, + ArrayFromJSON(decimal128(5, 2), R"(["1.23", "1.00"])")); } TEST(TestCoalesce, FixedSizeBinary) { @@ -2033,6 +2051,15 @@ TEST(TestCoalesce, FixedSizeBinary) { ArrayFromJSON(type, R"(["mno", "def", "ghi", "jkl"])")); CheckScalar("coalesce", {scalar1, values1}, ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])")); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("coalesce: all types must be identical, expected: " + "fixed_size_binary[3], but got: fixed_size_binary[2]"), + CallFunction("coalesce", { + ArrayFromJSON(type, "[null]"), + ArrayFromJSON(fixed_size_binary(2), "[null]"), + })); } TEST(TestCoalesce, FixedSizeListOfInt) { @@ -2061,6 +2088,16 @@ TEST(TestCoalesce, FixedSizeListOfInt) { ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])")); CheckScalar("coalesce", {scalar1, values1}, ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])")); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr( + "coalesce: all types must be identical, expected: fixed_size_list[2], but got: fixed_size_list[3]"), + CallFunction("coalesce", { + ArrayFromJSON(type, "[null]"), + ArrayFromJSON(fixed_size_list(uint8(), 3), "[null]"), + })); } TEST(TestCoalesce, FixedSizeListOfString) { @@ -2128,6 +2165,15 @@ TEST(TestCoalesce, Map) { "coalesce", {values1, values2}, ArrayFromJSON(type, R"([[[1, "b"]], [[2, "foo"], [4, null]], [[3, "test"]], []])")); CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4)); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("coalesce: all types must be identical, expected: map, but got: map"), + CallFunction("coalesce", { + ArrayFromJSON(type, "[null]"), + ArrayFromJSON(map(int64(), int32()), "[null]"), + })); } TEST(TestCoalesce, Struct) { @@ -2161,6 +2207,16 @@ TEST(TestCoalesce, Struct) { type, R"([[21, "foobar", [1, null, 2]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])")); CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4)); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("coalesce: all types must be identical, expected: struct, but got: struct"), + CallFunction("coalesce", + { + ArrayFromJSON(struct_({field("str", utf8())}), "[null]"), + ArrayFromJSON(struct_({field("int", uint16())}), "[null]"), + })); } TEST(TestCoalesce, UnionBoolString) { @@ -2191,6 +2247,38 @@ TEST(TestCoalesce, UnionBoolString) { ArrayFromJSON(type, R"([[2, true], [2, false], [7, "bar"], [7, "baz"]])")); CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4)); } + + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("coalesce: all types must be identical, expected: " + "sparse_union, but got: sparse_union"), + CallFunction( + "coalesce", + { + ArrayFromJSON(sparse_union({field("a", boolean())}), "[[0, true]]"), + ArrayFromJSON(sparse_union({field("a", int64())}), "[[0, 1]]"), + })); +} + +TEST(TestCoalesce, DispatchBest) { + CheckDispatchBest("coalesce", {int8(), float64()}, {float64(), float64()}); + CheckDispatchBest("coalesce", {int8(), uint32()}, {int64(), int64()}); + CheckDispatchBest("coalesce", {binary(), utf8()}, {binary(), binary()}); + CheckDispatchBest("coalesce", {binary(), large_binary()}, + {large_binary(), large_binary()}); + CheckDispatchBest("coalesce", {int32(), decimal128(3, 2)}, + {decimal128(12, 2), decimal128(12, 2)}); + CheckDispatchBest("coalesce", {float32(), decimal128(3, 2)}, {float64(), float64()}); + CheckDispatchBest("coalesce", {decimal128(3, 2), decimal256(3, 2)}, + {decimal256(3, 2), decimal256(3, 2)}); + CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), date32()}, + {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND)}); + CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)}, + {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)}); + CheckDispatchFails("coalesce", { + sparse_union({field("a", boolean())}), + dense_union({field("a", boolean())}), + }); } template diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index d054a313373..b441af55545 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -923,9 +923,8 @@ Status CastImpl(const StructScalar& from, StringScalar* to) { Status CastImpl(const UnionScalar& from, StringScalar* to) { const auto& union_ty = checked_cast(*from.type); std::stringstream ss; - ss << '(' << static_cast(from.type_code) << ": " - << union_ty.field(union_ty.child_ids()[from.type_code])->ToString() << " = " - << from.value->ToString() << ')'; + ss << "union{" << union_ty.field(union_ty.child_ids()[from.type_code])->ToString() + << " = " << from.value->ToString() << '}'; to->value = Buffer::FromString(ss.str()); return Status::OK(); } From c11c4bd907cb8609862362268e299bd054b97504 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Sep 2021 09:48:43 -0400 Subject: [PATCH 06/11] ARROW-13390: [C++] Add/fix more dispatch tests --- .../arrow/compute/kernels/codegen_internal.cc | 17 ++++- .../compute/kernels/codegen_internal_test.cc | 49 ++++++++++---- .../compute/kernels/scalar_arithmetic_test.cc | 65 ++++++++++++------- .../arrow/compute/kernels/scalar_compare.cc | 5 +- .../compute/kernels/scalar_compare_test.cc | 48 +++++--------- cpp/src/arrow/compute/kernels/test_util.cc | 6 ++ 6 files changed, 120 insertions(+), 70 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index b8f3d165ae2..f5850c80765 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -167,13 +167,20 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { std::shared_ptr CommonTimestamp(const std::vector& descrs) { TimeUnit::type finest_unit = TimeUnit::SECOND; const std::string* timezone = nullptr; + bool saw_date32 = false; + bool saw_date64 = false; for (const auto& descr : descrs) { auto id = descr.type->id(); // a common timestamp is only possible if all types are timestamp like switch (id) { case Type::DATE32: + // Date32's unit is days, but the coarsest we have is seconds + saw_date32 = true; + continue; case Type::DATE64: + finest_unit = std::max(finest_unit, TimeUnit::MILLI); + saw_date64 = true; continue; case Type::TIMESTAMP: { const auto& ty = checked_cast(*descr.type); @@ -189,8 +196,14 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs) } } - // Don't cast if we see only dates, no timestamps - return timezone ? timestamp(finest_unit, *timezone) : nullptr; + if (timezone) { + // At least one timestamp seen + return timestamp(finest_unit, *timezone); + } else if (saw_date32 && saw_date64) { + // Saw mixed date types + return date64(); + } + return nullptr; } std::shared_ptr CommonBinary(const std::vector& descrs) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index 9ba5377de08..6b19837f3ac 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -26,43 +26,63 @@ namespace arrow { namespace compute { namespace internal { +TEST(TestDispatchBest, CastBinaryDecimalArgs) { + std::vector args; + std::vector modes = { + DecimalPromotion::kAdd, DecimalPromotion::kMultiply, DecimalPromotion::kDivide}; + + // Any float -> all float + for (auto mode : modes) { + args = {decimal128(3, 2), float64()}; + ASSERT_OK(CastBinaryDecimalArgs(mode, &args)); + AssertTypeEqual(args[0].type, float64()); + AssertTypeEqual(args[1].type, float64()); + } + + // Integer -> decimal with common scale + args = {decimal128(1, 0), int64()}; + ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args)); + AssertTypeEqual(args[0].type, decimal128(1, 0)); + AssertTypeEqual(args[1].type, decimal128(19, 0)); +} + TEST(TestDispatchBest, CastDecimalArgs) { std::vector args; // Any float -> all float args = {decimal128(3, 2), float64()}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); - AssertTypeEqual(*args[0].type, *float64()); - AssertTypeEqual(*args[1].type, *float64()); + AssertTypeEqual(args[0].type, float64()); + AssertTypeEqual(args[1].type, float64()); // Promote to common decimal width args = {decimal128(3, 2), decimal256(3, 2)}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); - AssertTypeEqual(*args[0].type, *decimal256(3, 2)); - AssertTypeEqual(*args[1].type, *decimal256(3, 2)); + AssertTypeEqual(args[0].type, decimal256(3, 2)); + AssertTypeEqual(args[1].type, decimal256(3, 2)); // Rescale so all have common scale/precision args = {decimal128(3, 2), decimal128(3, 0)}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); - AssertTypeEqual(*args[0].type, *decimal128(5, 2)); - AssertTypeEqual(*args[1].type, *decimal128(5, 2)); + AssertTypeEqual(args[0].type, decimal128(5, 2)); + AssertTypeEqual(args[1].type, decimal128(5, 2)); // Integer -> decimal with appropriate precision args = {decimal128(3, 0), int64()}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); - AssertTypeEqual(*args[0].type, *decimal128(19, 0)); - AssertTypeEqual(*args[1].type, *decimal128(19, 0)); + AssertTypeEqual(args[0].type, decimal128(19, 0)); + AssertTypeEqual(args[1].type, decimal128(19, 0)); args = {decimal128(3, 1), int64()}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); - AssertTypeEqual(*args[0].type, *decimal128(20, 1)); - AssertTypeEqual(*args[1].type, *decimal128(20, 1)); + AssertTypeEqual(args[0].type, decimal128(20, 1)); + AssertTypeEqual(args[1].type, decimal128(20, 1)); // Overflow decimal128 max precision -> promote to decimal256 args = {decimal128(38, 0), decimal128(37, 2)}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); - AssertTypeEqual(*args[0].type, *decimal256(40, 2)); - AssertTypeEqual(*args[1].type, *decimal256(40, 2)); + AssertTypeEqual(args[0].type, decimal256(40, 2)); + AssertTypeEqual(args[1].type, decimal256(40, 2)); } TEST(TestDispatchBest, CommonTimestamp) { @@ -74,7 +94,10 @@ TEST(TestDispatchBest, CommonTimestamp) { timestamp(TimeUnit::NANO, "UTC")})); AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTimestamp({date32(), timestamp(TimeUnit::NANO)})); - ASSERT_EQ(nullptr, CommonTimestamp({date32(), date64()})); + AssertTypeEqual(timestamp(TimeUnit::MILLI), + CommonTimestamp({date64(), timestamp(TimeUnit::SECOND)})); + AssertTypeEqual(date64(), CommonTimestamp({date32(), date64()})); + ASSERT_EQ(nullptr, CommonTimestamp({date32(), date32()})); ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC")})); ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND, "America/Phoenix"), diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 35b734e29f6..3aa345bd9da 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -1622,31 +1622,29 @@ TEST(TestBinaryDecimalArithmetic, DispatchBest) { } } - // decimal, integer - for (std::string name : {"add", "subtract", "multiply", "divide"}) { + // decimal, decimal -> decimal and decimal, integer -> decimal + for (std::string name : {"add", "subtract"}) { for (std::string suffix : {"", "_checked"}) { name += suffix; CheckDispatchBest(name, {int64(), decimal128(1, 0)}, - {decimal128(1, 0), decimal128(1, 0)}); + {decimal128(19, 0), decimal128(1, 0)}); CheckDispatchBest(name, {decimal128(1, 0), int64()}, - {decimal128(1, 0), decimal128(1, 0)}); - } - } - - // decimal, decimal - for (std::string name : {"add", "subtract"}) { - for (std::string suffix : {"", "_checked"}) { - name += suffix; + {decimal128(1, 0), decimal128(19, 0)}); CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, - {decimal128(3, 1), decimal128(3, 1)}); + {decimal128(2, 1), decimal128(2, 1)}); CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, - {decimal256(3, 1), decimal256(3, 1)}); + {decimal256(2, 1), decimal256(2, 1)}); CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, - {decimal256(3, 1), decimal256(3, 1)}); + {decimal256(2, 1), decimal256(2, 1)}); CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, - {decimal256(3, 1), decimal256(3, 1)}); + {decimal256(2, 1), decimal256(2, 1)}); + + CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)}, + {decimal128(3, 1), decimal128(2, 1)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)}, + {decimal128(2, 1), decimal128(3, 1)}); } } { @@ -1654,29 +1652,50 @@ TEST(TestBinaryDecimalArithmetic, DispatchBest) { for (std::string suffix : {"", "_checked"}) { name += suffix; + CheckDispatchBest(name, {int64(), decimal128(1, 0)}, + {decimal128(19, 0), decimal128(1, 0)}); + CheckDispatchBest(name, {decimal128(1, 0), int64()}, + {decimal128(1, 0), decimal128(19, 0)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, - {decimal128(5, 2), decimal128(5, 2)}); + {decimal128(2, 1), decimal128(2, 1)}); CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, - {decimal256(5, 2), decimal256(5, 2)}); + {decimal256(2, 1), decimal256(2, 1)}); CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, - {decimal256(5, 2), decimal256(5, 2)}); + {decimal256(2, 1), decimal256(2, 1)}); CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, - {decimal256(5, 2), decimal256(5, 2)}); + {decimal256(2, 1), decimal256(2, 1)}); + + CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)}, + {decimal128(2, 0), decimal128(2, 1)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)}, + {decimal128(2, 1), decimal128(2, 0)}); } } { std::string name = "divide"; for (std::string suffix : {"", "_checked"}) { name += suffix; + SCOPED_TRACE(name); + + CheckDispatchBest(name, {int64(), decimal128(1, 0)}, + {decimal128(23, 4), decimal128(1, 0)}); + CheckDispatchBest(name, {decimal128(1, 0), int64()}, + {decimal128(21, 20), decimal128(19, 0)}); CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, - {decimal128(6, 4), decimal128(6, 4)}); + {decimal128(6, 5), decimal128(2, 1)}); CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, - {decimal256(6, 4), decimal256(6, 4)}); + {decimal256(6, 5), decimal256(2, 1)}); CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, - {decimal256(6, 4), decimal256(6, 4)}); + {decimal256(6, 5), decimal256(2, 1)}); CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, - {decimal256(6, 4), decimal256(6, 4)}); + {decimal256(6, 5), decimal256(2, 1)}); + + CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)}, + {decimal128(7, 5), decimal128(2, 1)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)}, + {decimal128(5, 4), decimal128(2, 0)}); } } } diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 17eae5adbbd..9751127f833 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -159,6 +159,9 @@ struct CompareFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); + if (HasDecimal(*values)) { + RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values)); + } using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -172,8 +175,6 @@ struct CompareFunction : ScalarFunction { ReplaceTypes(type, values); } else if (auto type = CommonBinary(*values)) { ReplaceTypes(type, values); - } else if (HasDecimal(*values)) { - RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values)); } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index a5bc89d87f3..3d9eb018e72 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -29,6 +29,7 @@ #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/testing/random.h" #include "arrow/type.h" #include "arrow/type_traits.h" @@ -601,9 +602,9 @@ TEST(TestCompareKernel, DispatchBest) { CheckDispatchBest(name, {decimal128(3, 2), float64()}, {float64(), float64()}); CheckDispatchBest(name, {float64(), decimal128(3, 2)}, {float64(), float64()}); CheckDispatchBest(name, {decimal128(3, 2), int64()}, - {decimal128(3, 2), decimal128(3, 2)}); + {decimal128(3, 2), decimal128(21, 2)}); CheckDispatchBest(name, {int64(), decimal128(3, 2)}, - {decimal128(3, 2), decimal128(3, 2)}); + {decimal128(21, 2), decimal128(3, 2)}); } } @@ -1083,34 +1084,21 @@ TYPED_TEST(TestVarArgsCompareParametricTemporal, MaxElementWise) { } TEST(TestMaxElementWiseMinElementWise, CommonTimestamp) { - { - auto t1 = std::make_shared(TimeUnit::SECOND); - auto t2 = std::make_shared(TimeUnit::MILLI); - auto expected = MakeScalar(t2, 1000).ValueOrDie(); - ASSERT_OK_AND_ASSIGN(auto actual, - MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()), - Datum(MakeScalar(t2, 12000).ValueOrDie())})); - AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true); - } - { - auto t1 = std::make_shared(); - auto t2 = std::make_shared(TimeUnit::SECOND); - auto expected = MakeScalar(t2, 86401).ValueOrDie(); - ASSERT_OK_AND_ASSIGN(auto actual, - MaxElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()), - Datum(MakeScalar(t2, 86401).ValueOrDie())})); - AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true); - } - { - auto t1 = std::make_shared(); - auto t2 = std::make_shared(); - auto t3 = std::make_shared(TimeUnit::SECOND); - auto expected = MakeScalar(t3, 86400).ValueOrDie(); - ASSERT_OK_AND_ASSIGN( - auto actual, MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()), - Datum(MakeScalar(t2, 2 * 86400000).ValueOrDie())})); - AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true); - } + EXPECT_THAT(MinElementWise({ + ScalarFromJSON(timestamp(TimeUnit::SECOND), "1"), + ScalarFromJSON(timestamp(TimeUnit::MILLI), "12000"), + }), + ResultWith(ScalarFromJSON(timestamp(TimeUnit::MILLI), "1000"))); + EXPECT_THAT(MaxElementWise({ + ScalarFromJSON(date32(), "1"), + ScalarFromJSON(timestamp(TimeUnit::SECOND), "86401"), + }), + ResultWith(ScalarFromJSON(timestamp(TimeUnit::SECOND), "86401"))); + EXPECT_THAT(MinElementWise({ + ScalarFromJSON(date32(), "1"), + ScalarFromJSON(date64(), "172800000"), + }), + ResultWith(ScalarFromJSON(date64(), "86400000"))); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index cedc03698a1..e72c3dce294 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -344,6 +344,12 @@ void CheckDispatchBest(std::string func_name, std::vector original_v << actual_kernel->signature->ToString() << "\n" << " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => " << expected_kernel->signature->ToString(); + EXPECT_EQ(values.size(), expected_equivalent_values.size()); + for (size_t i = 0; i < values.size(); i++) { + EXPECT_EQ(values[i].shape, expected_equivalent_values[i].shape) + << "Argument " << i << " should have the same shape"; + AssertTypeEqual(values[i].type, expected_equivalent_values[i].type); + } } void CheckDispatchFails(std::string func_name, std::vector values) { From 56a975ceb9fefbbd9be04a898cc708bbf0ecc958 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Sep 2021 10:41:00 -0400 Subject: [PATCH 07/11] ARROW-13390: [C++] Fix lint error --- cpp/src/arrow/compute/kernels/CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index cd6bc19c869..28686a9cafa 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -75,6 +75,4 @@ add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute") # ---------------------------------------------------------------------- # Utilities -add_arrow_compute_test(kernel_utility_test - SOURCES - codegen_internal_test.cc) +add_arrow_compute_test(kernel_utility_test SOURCES codegen_internal_test.cc) From 30585595c88bd25a07ff95a204cf2a5a3909a2ac Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 21 Sep 2021 09:43:50 -0400 Subject: [PATCH 08/11] ARROW-13390: [C++] Address review suggestions --- .../arrow/compute/kernels/codegen_internal.cc | 21 ++++-- .../arrow/compute/kernels/codegen_internal.h | 2 +- .../compute/kernels/codegen_internal_test.cc | 73 +++++++++++++++---- .../arrow/compute/kernels/scalar_compare.cc | 4 +- .../compute/kernels/scalar_compare_test.cc | 2 +- .../arrow/compute/kernels/scalar_if_else.cc | 61 +--------------- .../compute/kernels/scalar_if_else_test.cc | 15 ++-- cpp/src/arrow/util/basic_decimal.cc | 8 ++ 8 files changed, 98 insertions(+), 88 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index f5850c80765..8a8292c1c89 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -164,7 +164,7 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { return int8(); } -std::shared_ptr CommonTimestamp(const std::vector& descrs) { +std::shared_ptr CommonTemporal(const std::vector& descrs) { TimeUnit::type finest_unit = TimeUnit::SECOND; const std::string* timezone = nullptr; bool saw_date32 = false; @@ -199,9 +199,10 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs) if (timezone) { // At least one timestamp seen return timestamp(finest_unit, *timezone); - } else if (saw_date32 && saw_date64) { - // Saw mixed date types + } else if (saw_date64) { return date64(); + } else if (saw_date32) { + return date32(); } return nullptr; } @@ -321,12 +322,12 @@ Status CastDecimalArgs(ValueDescr* begin, size_t count) { auto* end = begin + count; int32_t max_scale = 0; + bool any_floating = false; for (auto* it = begin; it != end; ++it) { const auto& ty = *it->type; if (is_floating(ty.id())) { // Decimal + float = float - ReplaceTypes(float64(), begin, count); - return Status::OK(); + any_floating = true; } else if (is_integer(ty.id())) { // Nothing to do here } else if (is_decimal(ty.id())) { @@ -339,6 +340,10 @@ Status CastDecimalArgs(ValueDescr* begin, size_t count) { return Status::OK(); } } + if (any_floating) { + ReplaceTypes(float64(), begin, count); + return Status::OK(); + } // All integer and decimal, rescale int32_t common_precision = 0; @@ -357,7 +362,11 @@ Status CastDecimalArgs(ValueDescr* begin, size_t count) { } } - if (common_precision > BasicDecimal128::kMaxPrecision) { + if (common_precision > BasicDecimal256::kMaxPrecision) { + return Status::Invalid("Result precision (", common_precision, + ") exceeds max precision of Decimal256 (", + BasicDecimal256::kMaxPrecision, ")"); + } else if (common_precision > BasicDecimal128::kMaxPrecision) { casted_type_id = Type::DECIMAL256; } diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 0648ac61e92..5bd65d8c53c 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1289,7 +1289,7 @@ ARROW_EXPORT std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count); ARROW_EXPORT -std::shared_ptr CommonTimestamp(const std::vector& descrs); +std::shared_ptr CommonTemporal(const std::vector& descrs); ARROW_EXPORT std::shared_ptr CommonBinary(const std::vector& descrs); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index 6b19837f3ac..89fc35b38d2 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include #include #include "arrow/compute/kernels/codegen_internal.h" @@ -44,6 +45,12 @@ TEST(TestDispatchBest, CastBinaryDecimalArgs) { ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args)); AssertTypeEqual(args[0].type, decimal128(1, 0)); AssertTypeEqual(args[1].type, decimal128(19, 0)); + + // Add: rescale so all have common scale + args = {decimal128(3, 2), decimal128(3, -2)}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("Decimals with negative scales not supported"), + CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args)); } TEST(TestDispatchBest, CastDecimalArgs) { @@ -55,6 +62,12 @@ TEST(TestDispatchBest, CastDecimalArgs) { AssertTypeEqual(args[0].type, float64()); AssertTypeEqual(args[1].type, float64()); + args = {float32(), float64(), decimal128(3, 2)}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(args[0].type, float64()); + AssertTypeEqual(args[1].type, float64()); + AssertTypeEqual(args[2].type, float64()); + // Promote to common decimal width args = {decimal128(3, 2), decimal256(3, 2)}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); @@ -67,6 +80,17 @@ TEST(TestDispatchBest, CastDecimalArgs) { AssertTypeEqual(args[0].type, decimal128(5, 2)); AssertTypeEqual(args[1].type, decimal128(5, 2)); + args = {decimal128(3, 2), decimal128(3, -2)}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(args[0].type, decimal128(7, 2)); + AssertTypeEqual(args[1].type, decimal128(7, 2)); + + args = {decimal128(3, 0), decimal128(3, 1), decimal128(3, 2)}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(args[0].type, decimal128(5, 2)); + AssertTypeEqual(args[1].type, decimal128(5, 2)); + AssertTypeEqual(args[2].type, decimal128(5, 2)); + // Integer -> decimal with appropriate precision args = {decimal128(3, 0), int64()}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); @@ -78,30 +102,51 @@ TEST(TestDispatchBest, CastDecimalArgs) { AssertTypeEqual(args[0].type, decimal128(20, 1)); AssertTypeEqual(args[1].type, decimal128(20, 1)); + args = {decimal128(3, -1), int64()}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(args[0].type, decimal128(19, 0)); + AssertTypeEqual(args[1].type, decimal128(19, 0)); + // Overflow decimal128 max precision -> promote to decimal256 args = {decimal128(38, 0), decimal128(37, 2)}; ASSERT_OK(CastDecimalArgs(args.data(), args.size())); AssertTypeEqual(args[0].type, decimal256(40, 2)); AssertTypeEqual(args[1].type, decimal256(40, 2)); + + // Overflow decimal256 max precision + args = {decimal256(76, 0), decimal256(75, 1)}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "Result precision (77) exceeds max precision of Decimal256 (76)"), + CastDecimalArgs(args.data(), args.size())); + + // Incompatible, no cast + args = {decimal256(3, 2), float64(), utf8()}; + ASSERT_OK(CastDecimalArgs(args.data(), args.size())); + AssertTypeEqual(args[0].type, decimal256(3, 2)); + AssertTypeEqual(args[1].type, float64()); + AssertTypeEqual(args[2].type, utf8()); } -TEST(TestDispatchBest, CommonTimestamp) { - AssertTypeEqual( - timestamp(TimeUnit::NANO), - CommonTimestamp({timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)})); +TEST(TestDispatchBest, CommonTemporal) { + AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal({timestamp(TimeUnit::SECOND), + timestamp(TimeUnit::NANO)})); AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"), - CommonTimestamp({timestamp(TimeUnit::SECOND, "UTC"), - timestamp(TimeUnit::NANO, "UTC")})); + CommonTemporal({timestamp(TimeUnit::SECOND, "UTC"), + timestamp(TimeUnit::NANO, "UTC")})); AssertTypeEqual(timestamp(TimeUnit::NANO), - CommonTimestamp({date32(), timestamp(TimeUnit::NANO)})); + CommonTemporal({date32(), timestamp(TimeUnit::NANO)})); AssertTypeEqual(timestamp(TimeUnit::MILLI), - CommonTimestamp({date64(), timestamp(TimeUnit::SECOND)})); - AssertTypeEqual(date64(), CommonTimestamp({date32(), date64()})); - ASSERT_EQ(nullptr, CommonTimestamp({date32(), date32()})); - ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND), - timestamp(TimeUnit::SECOND, "UTC")})); - ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND, "America/Phoenix"), - timestamp(TimeUnit::SECOND, "UTC")})); + CommonTemporal({date64(), timestamp(TimeUnit::SECOND)})); + AssertTypeEqual(date32(), CommonTemporal({date32(), date32()})); + AssertTypeEqual(date64(), CommonTemporal({date64(), date64()})); + AssertTypeEqual(date64(), CommonTemporal({date32(), date64()})); + ASSERT_EQ(nullptr, CommonTemporal({})); + ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND, "UTC")})); + ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"), + timestamp(TimeUnit::SECOND, "UTC")})); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 9751127f833..c42b0ddac24 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -171,7 +171,7 @@ struct CompareFunction : ScalarFunction { if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values); - } else if (auto type = CommonTimestamp(*values)) { + } else if (auto type = CommonTemporal(*values)) { ReplaceTypes(type, values); } else if (auto type = CommonBinary(*values)) { ReplaceTypes(type, values); @@ -195,7 +195,7 @@ struct VarArgsCompareFunction : ScalarFunction { if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values); - } else if (auto type = CommonTimestamp(*values)) { + } else if (auto type = CommonTemporal(*values)) { ReplaceTypes(type, values); } diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 3d9eb018e72..dae89a3f518 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -1083,7 +1083,7 @@ TYPED_TEST(TestVarArgsCompareParametricTemporal, MaxElementWise) { {this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")}); } -TEST(TestMaxElementWiseMinElementWise, CommonTimestamp) { +TEST(TestMaxElementWiseMinElementWise, CommonTemporal) { EXPECT_THAT(MinElementWise({ ScalarFromJSON(timestamp(TimeUnit::SECOND), "1"), ScalarFromJSON(timestamp(TimeUnit::MILLI), "12000"), diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index a5ff191a7b4..bab97c8dc2a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1745,7 +1745,7 @@ struct CoalesceFunction : ScalarFunction { if (auto type = CommonBinary(*values)) { ReplaceTypes(type, values); } - if (auto type = CommonTimestamp(*values)) { + if (auto type = CommonTemporal(*values)) { ReplaceTypes(type, values); } if (HasDecimal(*values)) { @@ -2091,7 +2091,7 @@ static Status CheckIdenticalTypes(const ExecBatch& batch) { auto ty = batch[0].type(); for (auto it = batch.values.begin() + 1; it != batch.values.end(); ++it) { if (!ty->Equals(*it->type())) { - return Status::TypeError("coalesce: all types must be identical, expected: ", *ty, + return Status::TypeError("coalesce: all types must be compatible, expected: ", *ty, ", but got: ", *it->type()); } } @@ -2201,62 +2201,9 @@ struct CoalesceFunctor> { } }; -template <> -struct CoalesceFunctor { - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - RETURN_NOT_OK(CheckIdenticalTypes(batch)); - for (const auto& datum : batch.values) { - if (datum.is_array()) { - return ExecArray(ctx, batch, out); - } - } - return ExecScalarCoalesce(ctx, batch, out); - } - - static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - std::function reserve_data = ReserveNoData; - return ExecVarWidthCoalesce(ctx, batch, out, reserve_data); - } -}; - template -struct CoalesceFunctor> { - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - RETURN_NOT_OK(CheckIdenticalTypes(batch)); - for (const auto& datum : batch.values) { - if (datum.is_array()) { - return ExecArray(ctx, batch, out); - } - } - return ExecScalarCoalesce(ctx, batch, out); - } - - static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - std::function reserve_data = ReserveNoData; - return ExecVarWidthCoalesce(ctx, batch, out, reserve_data); - } -}; - -template <> -struct CoalesceFunctor { - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - RETURN_NOT_OK(CheckIdenticalTypes(batch)); - for (const auto& datum : batch.values) { - if (datum.is_array()) { - return ExecArray(ctx, batch, out); - } - } - return ExecScalarCoalesce(ctx, batch, out); - } - - static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - std::function reserve_data = ReserveNoData; - return ExecVarWidthCoalesce(ctx, batch, out, reserve_data); - } -}; - -template <> -struct CoalesceFunctor { +struct CoalesceFunctor< + Type, enable_if_t::value && !is_union_type::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { RETURN_NOT_OK(CheckIdenticalTypes(batch)); for (const auto& datum : batch.values) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 84c0404c049..52e145803dd 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1916,7 +1916,7 @@ TYPED_TEST(TestCoalesceList, Errors) { auto type1 = std::make_shared(int64()); auto type2 = std::make_shared(utf8()); EXPECT_RAISES_WITH_MESSAGE_THAT( - TypeError, ::testing::HasSubstr("coalesce: all types must be identical"), + TypeError, ::testing::HasSubstr("coalesce: all types must be compatible"), CallFunction("coalesce", { ArrayFromJSON(type1, "[null]"), ArrayFromJSON(type2, "[null]"), @@ -2054,7 +2054,7 @@ TEST(TestCoalesce, FixedSizeBinary) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("coalesce: all types must be identical, expected: " + ::testing::HasSubstr("coalesce: all types must be compatible, expected: " "fixed_size_binary[3], but got: fixed_size_binary[2]"), CallFunction("coalesce", { ArrayFromJSON(type, "[null]"), @@ -2092,7 +2092,7 @@ TEST(TestCoalesce, FixedSizeListOfInt) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, ::testing::HasSubstr( - "coalesce: all types must be identical, expected: fixed_size_list[2], but got: fixed_size_list[3]"), CallFunction("coalesce", { ArrayFromJSON(type, "[null]"), @@ -2168,7 +2168,7 @@ TEST(TestCoalesce, Map) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("coalesce: all types must be identical, expected: map, but got: map"), CallFunction("coalesce", { ArrayFromJSON(type, "[null]"), @@ -2210,8 +2210,9 @@ TEST(TestCoalesce, Struct) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("coalesce: all types must be identical, expected: struct, but got: struct"), + ::testing::HasSubstr( + "coalesce: all types must be compatible, expected: struct, but got: struct"), CallFunction("coalesce", { ArrayFromJSON(struct_({field("str", utf8())}), "[null]"), @@ -2250,7 +2251,7 @@ TEST(TestCoalesce, UnionBoolString) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("coalesce: all types must be identical, expected: " + ::testing::HasSubstr("coalesce: all types must be compatible, expected: " "sparse_union, but got: sparse_union"), CallFunction( "coalesce", diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc index edc25e25db8..24c193dff9f 100644 --- a/cpp/src/arrow/util/basic_decimal.cc +++ b/cpp/src/arrow/util/basic_decimal.cc @@ -377,6 +377,10 @@ BasicDecimal128::BasicDecimal128(const uint8_t* bytes) reinterpret_cast(bytes)[1]) {} #endif +constexpr int BasicDecimal128::kBitWidth; +constexpr int BasicDecimal128::kMaxPrecision; +constexpr int BasicDecimal128::kMaxScale; + std::array BasicDecimal128::ToBytes() const { std::array out{{0}}; ToBytes(out.data()); @@ -1153,6 +1157,10 @@ BasicDecimal256::BasicDecimal256(const uint8_t* bytes) reinterpret_cast(bytes)[2], reinterpret_cast(bytes)[3]}) {} +constexpr int BasicDecimal256::kBitWidth; +constexpr int BasicDecimal256::kMaxPrecision; +constexpr int BasicDecimal256::kMaxScale; + BasicDecimal256& BasicDecimal256::Negate() { auto array_le = BitUtil::LittleEndianArray::Make(&array_); uint64_t carry = 1; From 90c6a102e22138f08a9890503881761069c32c81 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 22 Sep 2021 14:44:40 -0400 Subject: [PATCH 09/11] Update cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 3aa345bd9da..bd3fecfd6d9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -1622,7 +1622,8 @@ TEST(TestBinaryDecimalArithmetic, DispatchBest) { } } - // decimal, decimal -> decimal and decimal, integer -> decimal + // decimal, decimal -> decimal + // decimal, integer -> decimal for (std::string name : {"add", "subtract"}) { for (std::string suffix : {"", "_checked"}) { name += suffix; From 3a519e4b0460e0181a09c20c68520dff8b61fb95 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 22 Sep 2021 14:44:58 -0400 Subject: [PATCH 10/11] ARROW-13390: [C++] Add one last test case --- cpp/src/arrow/compute/kernels/codegen_internal_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index 89fc35b38d2..a830d0c7636 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -143,6 +143,7 @@ TEST(TestDispatchBest, CommonTemporal) { AssertTypeEqual(date64(), CommonTemporal({date64(), date64()})); AssertTypeEqual(date64(), CommonTemporal({date32(), date64()})); ASSERT_EQ(nullptr, CommonTemporal({})); + ASSERT_EQ(nullptr, CommonTemporal({float64(), int32()})); ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC")})); ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"), From 63b2fa66549b5c6375c87db238372c32f4773a0e Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 23 Sep 2021 14:00:07 -0400 Subject: [PATCH 11/11] ARROW-13390: [C++] Reconcile with ARROW-13358 --- .../arrow/compute/kernels/scalar_if_else.cc | 32 ++++++++++--------- .../compute/kernels/scalar_if_else_test.cc | 15 ++++----- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index bab97c8dc2a..3f3933256f4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -63,6 +63,20 @@ inline Bitmap GetBitmap(const Datum& datum, int i) { return Bitmap{a.buffers[i], a.offset, a.length}; } +// Ensure parameterized types are identical. +Status CheckIdenticalTypes(const Datum* begin, size_t count) { + const auto& ty = begin->type(); + const auto* end = begin + count; + for (auto it = begin + 1; it != end; ++it) { + const DataType& other_ty = *it->type(); + if (!ty->Equals(other_ty)) { + return Status::TypeError("All types must be compatible, expected: ", *ty, + ", but got: ", other_ty); + } + } + return Status::OK(); +} + // if the condition is null then output is null otherwise we take validity from the // selected argument // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) @@ -2086,23 +2100,11 @@ static Status ExecVarWidthCoalesce(KernelContext* ctx, const ExecBatch& batch, D }); } -// Ensure parameterized types are identical. -static Status CheckIdenticalTypes(const ExecBatch& batch) { - auto ty = batch[0].type(); - for (auto it = batch.values.begin() + 1; it != batch.values.end(); ++it) { - if (!ty->Equals(*it->type())) { - return Status::TypeError("coalesce: all types must be compatible, expected: ", *ty, - ", but got: ", *it->type()); - } - } - return Status::OK(); -} - template struct CoalesceFunctor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (!TypeTraits::is_parameter_free) { - RETURN_NOT_OK(CheckIdenticalTypes(batch)); + RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size())); } // Special case for two arguments (since "fill_null" is a common operation) if (batch.num_values() == 2) { @@ -2205,7 +2207,7 @@ template struct CoalesceFunctor< Type, enable_if_t::value && !is_union_type::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - RETURN_NOT_OK(CheckIdenticalTypes(batch)); + RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size())); for (const auto& datum : batch.values) { if (datum.is_array()) { return ExecArray(ctx, batch, out); @@ -2224,7 +2226,7 @@ template struct CoalesceFunctor> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // Unions don't have top-level nulls, so a specialized implementation is needed - RETURN_NOT_OK(CheckIdenticalTypes(batch)); + RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size())); for (const auto& datum : batch.values) { if (datum.is_array()) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 52e145803dd..907c7b0e638 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1916,7 +1916,7 @@ TYPED_TEST(TestCoalesceList, Errors) { auto type1 = std::make_shared(int64()); auto type2 = std::make_shared(utf8()); EXPECT_RAISES_WITH_MESSAGE_THAT( - TypeError, ::testing::HasSubstr("coalesce: all types must be compatible"), + TypeError, ::testing::HasSubstr("All types must be compatible"), CallFunction("coalesce", { ArrayFromJSON(type1, "[null]"), ArrayFromJSON(type2, "[null]"), @@ -2054,7 +2054,7 @@ TEST(TestCoalesce, FixedSizeBinary) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("coalesce: all types must be compatible, expected: " + ::testing::HasSubstr("All types must be compatible, expected: " "fixed_size_binary[3], but got: fixed_size_binary[2]"), CallFunction("coalesce", { ArrayFromJSON(type, "[null]"), @@ -2092,7 +2092,7 @@ TEST(TestCoalesce, FixedSizeListOfInt) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, ::testing::HasSubstr( - "coalesce: all types must be compatible, expected: fixed_size_list[2], but got: fixed_size_list[3]"), CallFunction("coalesce", { ArrayFromJSON(type, "[null]"), @@ -2168,7 +2168,7 @@ TEST(TestCoalesce, Map) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("coalesce: all types must be compatible, expected: map, but got: map"), CallFunction("coalesce", { ArrayFromJSON(type, "[null]"), @@ -2210,9 +2210,8 @@ TEST(TestCoalesce, Struct) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr( - "coalesce: all types must be compatible, expected: struct, but got: struct"), + ::testing::HasSubstr("All types must be compatible, expected: struct, but got: struct"), CallFunction("coalesce", { ArrayFromJSON(struct_({field("str", utf8())}), "[null]"), @@ -2251,7 +2250,7 @@ TEST(TestCoalesce, UnionBoolString) { EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("coalesce: all types must be compatible, expected: " + ::testing::HasSubstr("All types must be compatible, expected: " "sparse_union, but got: sparse_union"), CallFunction( "coalesce",