From 882e86718324e3f82529a21f7b56ea20044b096d Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 3 Sep 2021 15:11:06 -0400
Subject: [PATCH 01/11] ARROW-13390: [C++] Implement ToString for union scalars
---
cpp/src/arrow/scalar.cc | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index adfc50182cb..d054a313373 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -920,6 +920,16 @@ Status CastImpl(const StructScalar& from, StringScalar* to) {
return Status::OK();
}
+Status CastImpl(const UnionScalar& from, StringScalar* to) {
+ const auto& union_ty = checked_cast(*from.type);
+ std::stringstream ss;
+ ss << '(' << static_cast(from.type_code) << ": "
+ << union_ty.field(union_ty.child_ids()[from.type_code])->ToString() << " = "
+ << from.value->ToString() << ')';
+ to->value = Buffer::FromString(ss.str());
+ return Status::OK();
+}
+
struct CastImplVisitor {
Status NotImplemented() {
return Status::NotImplemented("cast to ", *to_type_, " from ", *from_.type);
@@ -953,8 +963,6 @@ struct FromTypeVisitor : CastImplVisitor {
}
Status Visit(const NullType&) { return NotImplemented(); }
- Status Visit(const SparseUnionType&) { return NotImplemented(); }
- Status Visit(const DenseUnionType&) { return NotImplemented(); }
Status Visit(const DictionaryType&) { return NotImplemented(); }
Status Visit(const ExtensionType&) { return NotImplemented(); }
};
@@ -983,8 +991,6 @@ struct ToTypeVisitor : CastImplVisitor {
return Int32Scalar(0).CastTo(dict_type.index_type()).Value(&out.index);
}
- Status Visit(const SparseUnionType&) { return NotImplemented(); }
- Status Visit(const DenseUnionType&) { return NotImplemented(); }
Status Visit(const ExtensionType&) { return NotImplemented(); }
};
From 46912ee22336609c109455b3f9bbfcd7949b6bb3 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 3 Sep 2021 15:11:28 -0400
Subject: [PATCH 02/11] ARROW-13390: [C++] Implement Coalesce for remaining
types
---
.../arrow/compute/kernels/scalar_if_else.cc | 243 +++++++++++--
.../compute/kernels/scalar_if_else_test.cc | 325 +++++++++++++++++-
2 files changed, 540 insertions(+), 28 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 35bb6248f23..d7ec9c0b84b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -2018,6 +2018,65 @@ Status ExecBinaryCoalesce(KernelContext* ctx, Datum left, Datum right, int64_t l
return Status::OK();
}
+template
+static Status ExecVarWidthCoalesceImpl(KernelContext* ctx, const ExecBatch& batch,
+ Datum* out,
+ std::function reserve_data,
+ AppendScalar append_scalar) {
+ // Special case: grab any leading non-null scalar or array arguments
+ for (const auto& datum : batch.values) {
+ if (datum.is_scalar()) {
+ if (!datum.scalar()->is_valid) continue;
+ ARROW_ASSIGN_OR_RAISE(
+ *out, MakeArrayFromScalar(*datum.scalar(), batch.length, ctx->memory_pool()));
+ return Status::OK();
+ } else if (datum.is_array() && !datum.array()->MayHaveNulls()) {
+ *out = datum;
+ return Status::OK();
+ }
+ break;
+ }
+ ArrayData* output = out->mutable_array();
+ std::unique_ptr raw_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(batch.length));
+ RETURN_NOT_OK(reserve_data(raw_builder.get()));
+
+ for (int64_t i = 0; i < batch.length; i++) {
+ bool set = false;
+ for (const auto& datum : batch.values) {
+ if (datum.is_scalar()) {
+ if (datum.scalar()->is_valid) {
+ RETURN_NOT_OK(append_scalar(raw_builder.get(), *datum.scalar()));
+ set = true;
+ break;
+ }
+ } else {
+ const ArrayData& source = *datum.array();
+ if (!source.MayHaveNulls() ||
+ BitUtil::GetBit(source.buffers[0]->data(), source.offset + i)) {
+ RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1));
+ set = true;
+ break;
+ }
+ }
+ }
+ if (!set) RETURN_NOT_OK(raw_builder->AppendNull());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish());
+ *output = *temp_output->data();
+ output->type = batch[0].type();
+ return Status::OK();
+}
+
+static Status ExecVarWidthCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out,
+ std::function reserve_data) {
+ return ExecVarWidthCoalesceImpl(ctx, batch, out, std::move(reserve_data),
+ [](ArrayBuilder* builder, const Scalar& scalar) {
+ return builder->AppendScalar(scalar);
+ });
+}
+
template
struct CoalesceFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
@@ -2093,51 +2152,173 @@ struct CoalesceFunctor> {
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- // Special case: grab any leading non-null scalar or array arguments
+ return ExecVarWidthCoalesceImpl(
+ ctx, batch, out,
+ [&](ArrayBuilder* builder) {
+ int64_t reservation = 0;
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ const ArrayType array(datum.array());
+ reservation = std::max(reservation, array.total_values_length());
+ } else {
+ const auto& scalar = *datum.scalar();
+ if (scalar.is_valid) {
+ const int64_t size = UnboxScalar::Unbox(scalar).size();
+ reservation = std::max(reservation, batch.length * size);
+ }
+ }
+ }
+ return checked_cast(builder)->ReserveData(reservation);
+ },
+ [&](ArrayBuilder* builder, const Scalar& scalar) {
+ return checked_cast(builder)->Append(
+ UnboxScalar::Unbox(scalar));
+ });
+ }
+};
+
+template <>
+struct CoalesceFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
for (const auto& datum : batch.values) {
- if (datum.is_scalar()) {
- if (!datum.scalar()->is_valid) continue;
- ARROW_ASSIGN_OR_RAISE(
- *out, MakeArrayFromScalar(*datum.scalar(), batch.length, ctx->memory_pool()));
- return Status::OK();
- } else if (datum.is_array() && !datum.array()->MayHaveNulls()) {
- *out = datum;
- return Status::OK();
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
+ }
+ }
+ return ExecScalarCoalesce(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthCoalesce(ctx, batch, out,
+ [&](ArrayBuilder* builder) { return Status::OK(); });
+ }
+};
+
+template
+struct CoalesceFunctor> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
}
- break;
}
+ return ExecScalarCoalesce(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthCoalesce(ctx, batch, out,
+ [&](ArrayBuilder* builder) { return Status::OK(); });
+ }
+};
+
+template <>
+struct CoalesceFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
+ }
+ }
+ return ExecScalarCoalesce(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthCoalesce(ctx, batch, out,
+ [&](ArrayBuilder* builder) { return Status::OK(); });
+ }
+};
+
+template <>
+struct CoalesceFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
+ }
+ }
+ return ExecScalarCoalesce(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthCoalesce(ctx, batch, out,
+ [&](ArrayBuilder* builder) { return Status::OK(); });
+ }
+};
+
+template
+struct CoalesceFunctor> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Unions don't have top-level nulls, so a specialized implementation is needed
+ for (const auto& datum : batch.values) {
+ if (datum.is_array()) {
+ return ExecArray(ctx, batch, out);
+ }
+ }
+ return ExecScalar(ctx, batch, out);
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ArrayData* output = out->mutable_array();
- BuilderType builder(batch[0].type(), ctx->memory_pool());
- RETURN_NOT_OK(builder.Reserve(batch.length));
+ std::unique_ptr raw_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(batch.length));
+
+ // TODO: make sure differing union types are rejected
+ const UnionType& type = checked_cast(*out->type());
for (int64_t i = 0; i < batch.length; i++) {
bool set = false;
for (const auto& datum : batch.values) {
if (datum.is_scalar()) {
- if (datum.scalar()->is_valid) {
- RETURN_NOT_OK(builder.Append(UnboxScalar::Unbox(*datum.scalar())));
+ const auto& scalar = checked_cast(*datum.scalar());
+ if (scalar.is_valid && scalar.value->is_valid) {
+ RETURN_NOT_OK(raw_builder->AppendScalar(scalar));
set = true;
break;
}
} else {
const ArrayData& source = *datum.array();
- if (!source.MayHaveNulls() ||
- BitUtil::GetBit(source.buffers[0]->data(), source.offset + i)) {
- const uint8_t* data = source.buffers[2]->data();
- const offset_type* offsets = source.GetValues(1);
- const offset_type offset0 = offsets[i];
- const offset_type offset1 = offsets[i + 1];
- RETURN_NOT_OK(builder.Append(data + offset0, offset1 - offset0));
- set = true;
- break;
+ // Peek at the relevant child array's validity bitmap
+ if (std::is_same::value) {
+ const int8_t type_id = source.GetValues(1)[i];
+ const int child_id = type.child_ids()[type_id];
+ const ArrayData& child = *source.child_data[child_id];
+ if (!child.MayHaveNulls() ||
+ BitUtil::GetBit(child.buffers[0]->data(),
+ source.offset + child.offset + i)) {
+ RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1));
+ set = true;
+ break;
+ }
+ } else {
+ const int8_t type_id = source.GetValues(1)[i];
+ const int32_t offset = source.GetValues(2)[i];
+ const int child_id = type.child_ids()[type_id];
+ const ArrayData& child = *source.child_data[child_id];
+ if (!child.MayHaveNulls() ||
+ BitUtil::GetBit(child.buffers[0]->data(), child.offset + offset)) {
+ RETURN_NOT_OK(raw_builder->AppendArraySlice(source, i, /*length=*/1));
+ set = true;
+ break;
+ }
}
}
}
- if (!set) RETURN_NOT_OK(builder.AppendNull());
+ if (!set) RETURN_NOT_OK(raw_builder->AppendNull());
}
- ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish());
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish());
*output = *temp_output->data();
- // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
- output->type = batch[0].type();
+ return Status::OK();
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ for (const auto& datum : batch.values) {
+ const auto& scalar = checked_cast(*datum.scalar());
+ // Union scalars can have top-level validity
+ if (scalar.is_valid && scalar.value->is_valid) {
+ *out = datum;
+ break;
+ }
+ }
return Status::OK();
}
};
@@ -2511,6 +2692,14 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
for (const auto& ty : BaseBinaryTypes()) {
AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase(ty));
}
+ AddCoalesceKernel(func, Type::FIXED_SIZE_LIST,
+ CoalesceFunctor::Exec);
+ AddCoalesceKernel(func, Type::LIST, CoalesceFunctor::Exec);
+ AddCoalesceKernel(func, Type::LARGE_LIST, CoalesceFunctor::Exec);
+ AddCoalesceKernel(func, Type::MAP, CoalesceFunctor::Exec);
+ AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor::Exec);
+ AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor::Exec);
+ AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 8793cac7619..376c056ca87 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -1688,11 +1688,14 @@ template
class TestCoalesceNumeric : public ::testing::Test {};
template
class TestCoalesceBinary : public ::testing::Test {};
+template
+class TestCoalesceList : public ::testing::Test {};
TYPED_TEST_SUITE(TestCoalesceNumeric, NumericBasedTypes);
TYPED_TEST_SUITE(TestCoalesceBinary, BinaryArrowTypes);
+TYPED_TEST_SUITE(TestCoalesceList, ListArrowTypes);
-TYPED_TEST(TestCoalesceNumeric, FixedSize) {
+TYPED_TEST(TestCoalesceNumeric, Basics) {
auto type = default_type_instance();
auto scalar_null = ScalarFromJSON(type, "null");
auto scalar1 = ScalarFromJSON(type, "20");
@@ -1718,6 +1721,34 @@ TYPED_TEST(TestCoalesceNumeric, FixedSize) {
CheckScalar("coalesce", {scalar1, values1}, ArrayFromJSON(type, "[20, 20, 20, 20]"));
}
+TYPED_TEST(TestCoalesceNumeric, ListOfType) {
+ auto type = list(default_type_instance());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[20, 24]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [10, null, 20], [], [null, null]]");
+ auto values2 = ArrayFromJSON(type, "[[23], [14, 24], [null, 15], [16]]");
+ auto values3 = ArrayFromJSON(type, "[[17, 18], [19], [], null]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, "[[20, 24], [10, null, 20], [], [null, null]]"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, "[[23], [10, null, 20], [], [null, null]]"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, "[[23], [10, null, 20], [], [null, null]]"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]"));
+}
+
TYPED_TEST(TestCoalesceBinary, Basics) {
auto type = default_type_instance();
auto scalar_null = ScalarFromJSON(type, "null");
@@ -1747,6 +1778,140 @@ TYPED_TEST(TestCoalesceBinary, Basics) {
ArrayFromJSON(type, R"(["a", "a", "a", "a"])"));
}
+TYPED_TEST(TestCoalesceList, ListOfString) {
+ auto type = std::make_shared(utf8());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([null, "a"])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, ["bc", null], ["def"], []])");
+ auto values2 = ArrayFromJSON(type, R"([["klmno"], ["p"], ["qr", null], ["stu"]])");
+ auto values3 = ArrayFromJSON(type, R"([["vwxy"], [], ["d"], null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, R"([[null, "a"], [null, "a"], [null, "a"], [null, "a"]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"([[null, "a"], ["bc", null], ["def"], []])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([["klmno"], ["bc", null], ["def"], []])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, R"([["klmno"], ["bc", null], ["def"], []])"));
+ CheckScalar(
+ "coalesce", {scalar1, values1},
+ ArrayFromJSON(type, R"([[null, "a"], [null, "a"], [null, "a"], [null, "a"]])"));
+}
+
+// More minimal tests to check type coverage
+TYPED_TEST(TestCoalesceList, ListOfBool) {
+ auto type = std::make_shared(boolean());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[true, false, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [true, null, true], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type,
+ "[[true, false, null], [true, false, null], [true, false, "
+ "null], [true, false, null]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfInt) {
+ auto type = std::make_shared(int64());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[20, 24]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [10, null, 20], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[[20, 24], [20, 24], [20, 24], [20, 24]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfDayTimeInterval) {
+ auto type = std::make_shared(day_time_interval());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[[20, 24], null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, "[null, [[10, 12], null, [20, 22]], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(
+ type,
+ "[[[20, 24], null], [[20, 24], null], [[20, 24], null], [[20, 24], null]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfDecimal) {
+ for (auto ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ auto type = std::make_shared(ty);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["0.42", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([null, ["1.23"], [], [null, null]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(
+ type, R"([["0.42", null], ["0.42", null], ["0.42", null], ["0.42", null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ }
+}
+
+TYPED_TEST(TestCoalesceList, ListOfFixedSizeBinary) {
+ auto type = std::make_shared(fixed_size_binary(3));
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["ab!", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([null, ["def"], [], [null, null]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type,
+ R"([["ab!", null], ["ab!", null], ["ab!", null], ["ab!", null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
+TYPED_TEST(TestCoalesceList, ListOfListOfInt) {
+ auto type = std::make_shared(std::make_shared(int64()));
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, "[[20], null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, "[null, [[10, 12], null, []], [], [null, null]]");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, "[[[20], null], [[20], null], [[20], null], [[20], null]]"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+}
+
TEST(TestCoalesce, Null) {
auto type = null();
auto scalar_null = ScalarFromJSON(type, "null");
@@ -1870,6 +2035,164 @@ TEST(TestCoalesce, FixedSizeBinary) {
ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])"));
}
+TEST(TestCoalesce, FixedSizeListOfInt) {
+ auto type = fixed_size_list(uint8(), 2);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([42, null])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, [2, null], [4, 8], [null, null]])");
+ auto values2 = ArrayFromJSON(type, R"([[1, 5], [16, 32], [64, null], [null, 128]])");
+ auto values3 = ArrayFromJSON(type, R"([[null, null], [1, 3], [9, 27], null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1},
+ ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"([[42, null], [2, null], [4, 8], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])"));
+ CheckScalar("coalesce", {scalar1, values1},
+ ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])"));
+}
+
+TEST(TestCoalesce, FixedSizeListOfString) {
+ auto type = fixed_size_list(utf8(), 2);
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["abc", null])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 =
+ ArrayFromJSON(type, R"([null, ["d", null], ["ghi", "jkl"], [null, null]])");
+ auto values2 = ArrayFromJSON(
+ type, R"([["mno", "pq"], ["pqr", "ab"], ["stu", null], [null, "vwx"]])");
+ auto values3 =
+ ArrayFromJSON(type, R"([[null, null], ["a", "bcd"], ["d", "efg"], null])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar(
+ "coalesce", {values_null, scalar1},
+ ArrayFromJSON(type,
+ R"([["abc", null], ["abc", null], ["abc", null], ["abc", null]])"));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar("coalesce", {values1, scalar1},
+ ArrayFromJSON(
+ type, R"([["abc", null], ["d", null], ["ghi", "jkl"], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2},
+ ArrayFromJSON(
+ type, R"([["mno", "pq"], ["d", null], ["ghi", "jkl"], [null, null]])"));
+ CheckScalar("coalesce", {values1, values2, values3},
+ ArrayFromJSON(
+ type, R"([["mno", "pq"], ["d", null], ["ghi", "jkl"], [null, null]])"));
+ CheckScalar(
+ "coalesce", {scalar1, values1},
+ ArrayFromJSON(type,
+ R"([["abc", null], ["abc", null], ["abc", null], ["abc", null]])"));
+}
+
+TEST(TestCoalesce, Map) {
+ auto type = map(int64(), utf8());
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([[1, "a"], [5, "bc"]])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 =
+ ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[3, "test"]], []])");
+ auto values2 = ArrayFromJSON(
+ type, R"([[[1, "b"]], [[2, "c"]], [[5, "c"], [6, "d"]], [[7, "abc"]]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar(
+ "coalesce", {values1, scalar1},
+ ArrayFromJSON(
+ type,
+ R"([[[1, "a"], [5, "bc"]], [[2, "foo"], [4, null]], [[3, "test"]], []])"));
+ CheckScalar(
+ "coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([[[1, "b"]], [[2, "foo"], [4, null]], [[3, "test"]], []])"));
+ CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+}
+
+TEST(TestCoalesce, Struct) {
+ auto type = struct_(
+ {field("int", uint32()), field("str", utf8()), field("list", list(int8()))});
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([42, "spam", [null, -1]])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(
+ type, R"([null, [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])");
+ auto values2 = ArrayFromJSON(
+ type,
+ R"([[21, "foobar", [1, null, 2]], [5, "bar", []], [20, null, null], [1, "", [null]]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar(
+ "coalesce", {values1, scalar1},
+ ArrayFromJSON(
+ type,
+ R"([[42, "spam", [null, -1]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])"));
+ CheckScalar(
+ "coalesce", {values1, values2},
+ ArrayFromJSON(
+ type,
+ R"([[21, "foobar", [1, null, 2]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])"));
+ CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+}
+
+TEST(TestCoalesce, UnionBoolString) {
+ for (const auto& type : {
+ sparse_union({field("a", boolean()), field("b", utf8())}, {2, 7}),
+ dense_union({field("a", boolean()), field("b", utf8())}, {2, 7}),
+ }) {
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([7, "foo"])");
+ auto values_null = ArrayFromJSON(type, R"([null, null, null, null])");
+ auto values1 = ArrayFromJSON(type, R"([null, [2, false], [7, "bar"], [7, "baz"]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([[2, true], [2, false], [7, "foo"], [7, "bar"]])");
+ CheckScalar("coalesce", {values_null}, values_null);
+ CheckScalar("coalesce", {values_null, scalar1}, *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("coalesce", {values_null, values1}, values1);
+ CheckScalar("coalesce", {values_null, values2}, values2);
+ CheckScalar("coalesce", {values1, values_null}, values1);
+ CheckScalar("coalesce", {values2, values_null}, values2);
+ CheckScalar("coalesce", {scalar_null, values1}, values1);
+ CheckScalar("coalesce", {values1, scalar_null}, values1);
+ CheckScalar("coalesce", {values2, values1, values_null}, values2);
+ CheckScalar(
+ "coalesce", {values1, scalar1},
+ ArrayFromJSON(type, R"([[7, "foo"], [2, false], [7, "bar"], [7, "baz"]])"));
+ CheckScalar(
+ "coalesce", {values1, values2},
+ ArrayFromJSON(type, R"([[2, true], [2, false], [7, "bar"], [7, "baz"]])"));
+ CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+ }
+}
+
template
class TestChooseNumeric : public ::testing::Test {};
template
From f744fcac91a51eeb8106c40c84e6ad65cb7197c2 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 7 Sep 2021 14:41:44 -0400
Subject: [PATCH 03/11] ARROW-13390: [Ruby] Update scalar to_s tests
---
c_glib/test/test-dense-union-scalar.rb | 2 +-
c_glib/test/test-sparse-union-scalar.rb | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/c_glib/test/test-dense-union-scalar.rb b/c_glib/test/test-dense-union-scalar.rb
index ec2053b3fe9..b6173d98232 100644
--- a/c_glib/test/test-dense-union-scalar.rb
+++ b/c_glib/test/test-dense-union-scalar.rb
@@ -49,7 +49,7 @@ def test_equal
end
def test_to_s
- assert_equal("...", @scalar.to_s)
+ assert_equal("(2: number: int8 = -29)", @scalar.to_s)
end
def test_value
diff --git a/c_glib/test/test-sparse-union-scalar.rb b/c_glib/test/test-sparse-union-scalar.rb
index acb8531560b..1f26d8df869 100644
--- a/c_glib/test/test-sparse-union-scalar.rb
+++ b/c_glib/test/test-sparse-union-scalar.rb
@@ -49,7 +49,7 @@ def test_equal
end
def test_to_s
- assert_equal("...", @scalar.to_s)
+ assert_equal("(2: number: int8 = -29)", @scalar.to_s)
end
def test_value
From 754e0ce9b160fb7d7176933a966a5d45f5ae9bb0 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 7 Sep 2021 17:44:34 -0400
Subject: [PATCH 04/11] ARROW-13390: [C++] Try to satisfy MinGW
---
cpp/src/arrow/compute/kernels/scalar_if_else.cc | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index d7ec9c0b84b..23ac666931e 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -2189,8 +2189,8 @@ struct CoalesceFunctor {
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- return ExecVarWidthCoalesce(ctx, batch, out,
- [&](ArrayBuilder* builder) { return Status::OK(); });
+ std::function reserve_data = ReserveNoData;
+ return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
}
};
@@ -2206,8 +2206,8 @@ struct CoalesceFunctor> {
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- return ExecVarWidthCoalesce(ctx, batch, out,
- [&](ArrayBuilder* builder) { return Status::OK(); });
+ std::function reserve_data = ReserveNoData;
+ return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
}
};
@@ -2223,8 +2223,8 @@ struct CoalesceFunctor {
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- return ExecVarWidthCoalesce(ctx, batch, out,
- [&](ArrayBuilder* builder) { return Status::OK(); });
+ std::function reserve_data = ReserveNoData;
+ return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
}
};
@@ -2240,8 +2240,8 @@ struct CoalesceFunctor {
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- return ExecVarWidthCoalesce(ctx, batch, out,
- [&](ArrayBuilder* builder) { return Status::OK(); });
+ std::function reserve_data = ReserveNoData;
+ return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
}
};
From b51c2510c4465dbdef911ad2e9760c9fe578aff4 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 16 Sep 2021 16:46:55 -0400
Subject: [PATCH 05/11] ARROW-13390: [C++] Address feedback, add dispatch tests
---
c_glib/test/test-dense-union-scalar.rb | 2 +-
c_glib/test/test-sparse-union-scalar.rb | 2 +-
cpp/src/arrow/compute/kernels/CMakeLists.txt | 7 ++
.../arrow/compute/kernels/codegen_internal.cc | 78 ++++++++++++++--
.../arrow/compute/kernels/codegen_internal.h | 11 +++
.../compute/kernels/codegen_internal_test.cc | 86 ++++++++++++++++++
.../arrow/compute/kernels/scalar_if_else.cc | 33 ++++++-
.../compute/kernels/scalar_if_else_test.cc | 88 +++++++++++++++++++
cpp/src/arrow/scalar.cc | 5 +-
9 files changed, 299 insertions(+), 13 deletions(-)
create mode 100644 cpp/src/arrow/compute/kernels/codegen_internal_test.cc
diff --git a/c_glib/test/test-dense-union-scalar.rb b/c_glib/test/test-dense-union-scalar.rb
index b6173d98232..4a3e5c0dee7 100644
--- a/c_glib/test/test-dense-union-scalar.rb
+++ b/c_glib/test/test-dense-union-scalar.rb
@@ -49,7 +49,7 @@ def test_equal
end
def test_to_s
- assert_equal("(2: number: int8 = -29)", @scalar.to_s)
+ assert_equal("union{number: int8 = -29}", @scalar.to_s)
end
def test_value
diff --git a/c_glib/test/test-sparse-union-scalar.rb b/c_glib/test/test-sparse-union-scalar.rb
index 1f26d8df869..a7f1b06953e 100644
--- a/c_glib/test/test-sparse-union-scalar.rb
+++ b/c_glib/test/test-sparse-union-scalar.rb
@@ -49,7 +49,7 @@ def test_equal
end
def test_to_s
- assert_equal("(2: number: int8 = -29)", @scalar.to_s)
+ assert_equal("union{number: int8 = -29}", @scalar.to_s)
end
def test_value
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index ce7a85f1557..cd6bc19c869 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -71,3 +71,10 @@ add_arrow_compute_test(aggregate_test
hash_aggregate_test.cc
test_util.cc)
add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute")
+
+# ----------------------------------------------------------------------
+# Utilities
+
+add_arrow_compute_test(kernel_utility_test
+ SOURCES
+ codegen_internal_test.cc)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index fe4b593b481..b8f3d165ae2 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -95,8 +95,14 @@ void ReplaceNullWithOtherType(std::vector* descrs) {
void ReplaceTypes(const std::shared_ptr& type,
std::vector* descrs) {
- for (auto& descr : *descrs) {
- descr.type = type;
+ ReplaceTypes(type, descrs->data(), descrs->size());
+}
+
+void ReplaceTypes(const std::shared_ptr& type, ValueDescr* begin,
+ size_t count) {
+ auto* end = begin + count;
+ for (auto* it = begin; it != end; it++) {
+ it->type = type;
}
}
@@ -160,6 +166,7 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) {
std::shared_ptr CommonTimestamp(const std::vector& descrs) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
+ const std::string* timezone = nullptr;
for (const auto& descr : descrs) {
auto id = descr.type->id();
@@ -168,16 +175,22 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs)
case Type::DATE32:
case Type::DATE64:
continue;
- case Type::TIMESTAMP:
- finest_unit =
- std::max(finest_unit, checked_cast(*descr.type).unit());
+ case Type::TIMESTAMP: {
+ const auto& ty = checked_cast(*descr.type);
+ // Don't cast to common timezone by default (may not make
+ // sense for all kernels)
+ if (timezone && *timezone != ty.timezone()) return nullptr;
+ timezone = &ty.timezone();
+ finest_unit = std::max(finest_unit, ty.unit());
continue;
+ }
default:
return nullptr;
}
}
- return timestamp(finest_unit);
+ // Don't cast if we see only dates, no timestamps
+ return timezone ? timestamp(finest_unit, *timezone) : nullptr;
}
std::shared_ptr CommonBinary(const std::vector& descrs) {
@@ -290,6 +303,59 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion,
return Status::OK();
}
+Status CastDecimalArgs(ValueDescr* begin, size_t count) {
+ Type::type casted_type_id = Type::DECIMAL128;
+ auto* end = begin + count;
+
+ int32_t max_scale = 0;
+ for (auto* it = begin; it != end; ++it) {
+ const auto& ty = *it->type;
+ if (is_floating(ty.id())) {
+ // Decimal + float = float
+ ReplaceTypes(float64(), begin, count);
+ return Status::OK();
+ } else if (is_integer(ty.id())) {
+ // Nothing to do here
+ } else if (is_decimal(ty.id())) {
+ max_scale = std::max(max_scale, checked_cast(ty).scale());
+ if (ty.id() == Type::DECIMAL256) {
+ casted_type_id = Type::DECIMAL256;
+ }
+ } else {
+ // Non-numeric, can't cast
+ return Status::OK();
+ }
+ }
+
+ // All integer and decimal, rescale
+ int32_t common_precision = 0;
+ for (auto* it = begin; it != end; ++it) {
+ const auto& ty = *it->type;
+ if (is_integer(ty.id())) {
+ ARROW_ASSIGN_OR_RAISE(auto precision, MaxDecimalDigitsForInteger(ty.id()));
+ precision += max_scale;
+ common_precision = std::max(common_precision, precision);
+ } else if (is_decimal(ty.id())) {
+ const auto& decimal_ty = checked_cast(ty);
+ auto precision = decimal_ty.precision();
+ const auto scale = decimal_ty.scale();
+ precision += max_scale - scale;
+ common_precision = std::max(common_precision, precision);
+ }
+ }
+
+ if (common_precision > BasicDecimal128::kMaxPrecision) {
+ casted_type_id = Type::DECIMAL256;
+ }
+
+ for (auto* it = begin; it != end; ++it) {
+ ARROW_ASSIGN_OR_RAISE(it->type,
+ DecimalType::Make(casted_type_id, common_precision, max_scale));
+ }
+
+ return Status::OK();
+}
+
bool HasDecimal(const std::vector& descrs) {
for (const auto& descr : descrs) {
if (is_decimal(descr.type->id())) {
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index f9ce34b06e0..0648ac61e92 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1279,6 +1279,9 @@ void ReplaceNullWithOtherType(std::vector* descrs);
ARROW_EXPORT
void ReplaceTypes(const std::shared_ptr&, std::vector* descrs);
+ARROW_EXPORT
+void ReplaceTypes(const std::shared_ptr&, ValueDescr* descrs, size_t count);
+
ARROW_EXPORT
std::shared_ptr CommonNumeric(const std::vector& descrs);
@@ -1298,9 +1301,17 @@ enum class DecimalPromotion : uint8_t {
kDivide,
};
+/// Given two arguments, at least one of which is decimal, promote all
+/// to not necessarily identical types, but types which are compatible
+/// for the given operator (add/multiply/divide).
ARROW_EXPORT
Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector* descrs);
+/// Given one or more arguments, at least one of which is decimal,
+/// promote all to an identical type.
+ARROW_EXPORT
+Status CastDecimalArgs(ValueDescr* begin, size_t count);
+
ARROW_EXPORT
bool HasDecimal(const std::vector& descrs);
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
new file mode 100644
index 00000000000..9ba5377de08
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
@@ -0,0 +1,86 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include
+
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+TEST(TestDispatchBest, CastDecimalArgs) {
+ std::vector args;
+
+ // Any float -> all float
+ args = {decimal128(3, 2), float64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(*args[0].type, *float64());
+ AssertTypeEqual(*args[1].type, *float64());
+
+ // Promote to common decimal width
+ args = {decimal128(3, 2), decimal256(3, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(*args[0].type, *decimal256(3, 2));
+ AssertTypeEqual(*args[1].type, *decimal256(3, 2));
+
+ // Rescale so all have common scale/precision
+ args = {decimal128(3, 2), decimal128(3, 0)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(*args[0].type, *decimal128(5, 2));
+ AssertTypeEqual(*args[1].type, *decimal128(5, 2));
+
+ // Integer -> decimal with appropriate precision
+ args = {decimal128(3, 0), int64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(*args[0].type, *decimal128(19, 0));
+ AssertTypeEqual(*args[1].type, *decimal128(19, 0));
+
+ args = {decimal128(3, 1), int64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(*args[0].type, *decimal128(20, 1));
+ AssertTypeEqual(*args[1].type, *decimal128(20, 1));
+
+ // Overflow decimal128 max precision -> promote to decimal256
+ args = {decimal128(38, 0), decimal128(37, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(*args[0].type, *decimal256(40, 2));
+ AssertTypeEqual(*args[1].type, *decimal256(40, 2));
+}
+
+TEST(TestDispatchBest, CommonTimestamp) {
+ AssertTypeEqual(
+ timestamp(TimeUnit::NANO),
+ CommonTimestamp({timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)}));
+ AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
+ CommonTimestamp({timestamp(TimeUnit::SECOND, "UTC"),
+ timestamp(TimeUnit::NANO, "UTC")}));
+ AssertTypeEqual(timestamp(TimeUnit::NANO),
+ CommonTimestamp({date32(), timestamp(TimeUnit::NANO)}));
+ ASSERT_EQ(nullptr, CommonTimestamp({date32(), date64()}));
+ ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND),
+ timestamp(TimeUnit::SECOND, "UTC")}));
+ ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND, "America/Phoenix"),
+ timestamp(TimeUnit::SECOND, "UTC")}));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 23ac666931e..a5ff191a7b4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1737,11 +1737,20 @@ struct CoalesceFunction : ScalarFunction {
Result DispatchBest(std::vector* values) const override {
RETURN_NOT_OK(CheckArity(*values));
using arrow::compute::detail::DispatchExactImpl;
- if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ // Do not DispatchExact here since we want to rescale decimals if necessary
EnsureDictionaryDecoded(values);
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
}
+ if (auto type = CommonBinary(*values)) {
+ ReplaceTypes(type, values);
+ }
+ if (auto type = CommonTimestamp(*values)) {
+ ReplaceTypes(type, values);
+ }
+ if (HasDecimal(*values)) {
+ RETURN_NOT_OK(CastDecimalArgs(values->data(), values->size()));
+ }
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *values);
}
@@ -2077,9 +2086,24 @@ static Status ExecVarWidthCoalesce(KernelContext* ctx, const ExecBatch& batch, D
});
}
+// Ensure parameterized types are identical.
+static Status CheckIdenticalTypes(const ExecBatch& batch) {
+ auto ty = batch[0].type();
+ for (auto it = batch.values.begin() + 1; it != batch.values.end(); ++it) {
+ if (!ty->Equals(*it->type())) {
+ return Status::TypeError("coalesce: all types must be identical, expected: ", *ty,
+ ", but got: ", *it->type());
+ }
+ }
+ return Status::OK();
+}
+
template
struct CoalesceFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (!TypeTraits::is_parameter_free) {
+ RETURN_NOT_OK(CheckIdenticalTypes(batch));
+ }
// Special case for two arguments (since "fill_null" is a common operation)
if (batch.num_values() == 2) {
return ExecBinaryCoalesce(ctx, batch[0], batch[1], batch.length, out);
@@ -2180,6 +2204,7 @@ struct CoalesceFunctor> {
template <>
struct CoalesceFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ RETURN_NOT_OK(CheckIdenticalTypes(batch));
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArray(ctx, batch, out);
@@ -2197,6 +2222,7 @@ struct CoalesceFunctor {
template
struct CoalesceFunctor> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ RETURN_NOT_OK(CheckIdenticalTypes(batch));
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArray(ctx, batch, out);
@@ -2214,6 +2240,7 @@ struct CoalesceFunctor> {
template <>
struct CoalesceFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ RETURN_NOT_OK(CheckIdenticalTypes(batch));
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArray(ctx, batch, out);
@@ -2231,6 +2258,7 @@ struct CoalesceFunctor {
template <>
struct CoalesceFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ RETURN_NOT_OK(CheckIdenticalTypes(batch));
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArray(ctx, batch, out);
@@ -2249,6 +2277,8 @@ template
struct CoalesceFunctor> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
// Unions don't have top-level nulls, so a specialized implementation is needed
+ RETURN_NOT_OK(CheckIdenticalTypes(batch));
+
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArray(ctx, batch, out);
@@ -2263,7 +2293,6 @@ struct CoalesceFunctor> {
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
RETURN_NOT_OK(raw_builder->Reserve(batch.length));
- // TODO: make sure differing union types are rejected
const UnionType& type = checked_cast(*out->type());
for (int64_t i = 0; i < batch.length; i++) {
bool set = false;
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 376c056ca87..84c0404c049 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -1912,6 +1912,17 @@ TYPED_TEST(TestCoalesceList, ListOfListOfInt) {
CheckScalar("coalesce", {values1, scalar_null}, values1);
}
+TYPED_TEST(TestCoalesceList, Errors) {
+ auto type1 = std::make_shared(int64());
+ auto type2 = std::make_shared(utf8());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError, ::testing::HasSubstr("coalesce: all types must be identical"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type1, "[null]"),
+ ArrayFromJSON(type2, "[null]"),
+ }));
+}
+
TEST(TestCoalesce, Null) {
auto type = null();
auto scalar_null = ScalarFromJSON(type, "null");
@@ -2005,6 +2016,13 @@ TEST(TestCoalesce, Decimal) {
CheckScalar("coalesce", {scalar1, values1},
ArrayFromJSON(type, R"(["1.23", "1.23", "1.23", "1.23"])"));
}
+ // Ensure promotion
+ CheckScalar("coalesce",
+ {
+ ArrayFromJSON(decimal128(3, 2), R"(["1.23", null])"),
+ ArrayFromJSON(decimal128(4, 1), R"([null, "1.0"])"),
+ },
+ ArrayFromJSON(decimal128(5, 2), R"(["1.23", "1.00"])"));
}
TEST(TestCoalesce, FixedSizeBinary) {
@@ -2033,6 +2051,15 @@ TEST(TestCoalesce, FixedSizeBinary) {
ArrayFromJSON(type, R"(["mno", "def", "ghi", "jkl"])"));
CheckScalar("coalesce", {scalar1, values1},
ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])"));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("coalesce: all types must be identical, expected: "
+ "fixed_size_binary[3], but got: fixed_size_binary[2]"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type, "[null]"),
+ ArrayFromJSON(fixed_size_binary(2), "[null]"),
+ }));
}
TEST(TestCoalesce, FixedSizeListOfInt) {
@@ -2061,6 +2088,16 @@ TEST(TestCoalesce, FixedSizeListOfInt) {
ArrayFromJSON(type, R"([[1, 5], [2, null], [4, 8], [null, null]])"));
CheckScalar("coalesce", {scalar1, values1},
ArrayFromJSON(type, R"([[42, null], [42, null], [42, null], [42, null]])"));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr(
+ "coalesce: all types must be identical, expected: fixed_size_list[2], but got: fixed_size_list[3]"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type, "[null]"),
+ ArrayFromJSON(fixed_size_list(uint8(), 3), "[null]"),
+ }));
}
TEST(TestCoalesce, FixedSizeListOfString) {
@@ -2128,6 +2165,15 @@ TEST(TestCoalesce, Map) {
"coalesce", {values1, values2},
ArrayFromJSON(type, R"([[[1, "b"]], [[2, "foo"], [4, null]], [[3, "test"]], []])"));
CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("coalesce: all types must be identical, expected: map, but got: map"),
+ CallFunction("coalesce", {
+ ArrayFromJSON(type, "[null]"),
+ ArrayFromJSON(map(int64(), int32()), "[null]"),
+ }));
}
TEST(TestCoalesce, Struct) {
@@ -2161,6 +2207,16 @@ TEST(TestCoalesce, Struct) {
type,
R"([[21, "foobar", [1, null, 2]], [null, "eggs", []], [0, "", [null]], [32, "abc", [1, 2, 3]]])"));
CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("coalesce: all types must be identical, expected: struct, but got: struct"),
+ CallFunction("coalesce",
+ {
+ ArrayFromJSON(struct_({field("str", utf8())}), "[null]"),
+ ArrayFromJSON(struct_({field("int", uint16())}), "[null]"),
+ }));
}
TEST(TestCoalesce, UnionBoolString) {
@@ -2191,6 +2247,38 @@ TEST(TestCoalesce, UnionBoolString) {
ArrayFromJSON(type, R"([[2, true], [2, false], [7, "bar"], [7, "baz"]])"));
CheckScalar("coalesce", {scalar1, values1}, *MakeArrayFromScalar(*scalar1, 4));
}
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("coalesce: all types must be identical, expected: "
+ "sparse_union, but got: sparse_union"),
+ CallFunction(
+ "coalesce",
+ {
+ ArrayFromJSON(sparse_union({field("a", boolean())}), "[[0, true]]"),
+ ArrayFromJSON(sparse_union({field("a", int64())}), "[[0, 1]]"),
+ }));
+}
+
+TEST(TestCoalesce, DispatchBest) {
+ CheckDispatchBest("coalesce", {int8(), float64()}, {float64(), float64()});
+ CheckDispatchBest("coalesce", {int8(), uint32()}, {int64(), int64()});
+ CheckDispatchBest("coalesce", {binary(), utf8()}, {binary(), binary()});
+ CheckDispatchBest("coalesce", {binary(), large_binary()},
+ {large_binary(), large_binary()});
+ CheckDispatchBest("coalesce", {int32(), decimal128(3, 2)},
+ {decimal128(12, 2), decimal128(12, 2)});
+ CheckDispatchBest("coalesce", {float32(), decimal128(3, 2)}, {float64(), float64()});
+ CheckDispatchBest("coalesce", {decimal128(3, 2), decimal256(3, 2)},
+ {decimal256(3, 2), decimal256(3, 2)});
+ CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), date32()},
+ {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND)});
+ CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)},
+ {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+ CheckDispatchFails("coalesce", {
+ sparse_union({field("a", boolean())}),
+ dense_union({field("a", boolean())}),
+ });
}
template
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index d054a313373..b441af55545 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -923,9 +923,8 @@ Status CastImpl(const StructScalar& from, StringScalar* to) {
Status CastImpl(const UnionScalar& from, StringScalar* to) {
const auto& union_ty = checked_cast(*from.type);
std::stringstream ss;
- ss << '(' << static_cast(from.type_code) << ": "
- << union_ty.field(union_ty.child_ids()[from.type_code])->ToString() << " = "
- << from.value->ToString() << ')';
+ ss << "union{" << union_ty.field(union_ty.child_ids()[from.type_code])->ToString()
+ << " = " << from.value->ToString() << '}';
to->value = Buffer::FromString(ss.str());
return Status::OK();
}
From c11c4bd907cb8609862362268e299bd054b97504 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 17 Sep 2021 09:48:43 -0400
Subject: [PATCH 06/11] ARROW-13390: [C++] Add/fix more dispatch tests
---
.../arrow/compute/kernels/codegen_internal.cc | 17 ++++-
.../compute/kernels/codegen_internal_test.cc | 49 ++++++++++----
.../compute/kernels/scalar_arithmetic_test.cc | 65 ++++++++++++-------
.../arrow/compute/kernels/scalar_compare.cc | 5 +-
.../compute/kernels/scalar_compare_test.cc | 48 +++++---------
cpp/src/arrow/compute/kernels/test_util.cc | 6 ++
6 files changed, 120 insertions(+), 70 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index b8f3d165ae2..f5850c80765 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -167,13 +167,20 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) {
std::shared_ptr CommonTimestamp(const std::vector& descrs) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
const std::string* timezone = nullptr;
+ bool saw_date32 = false;
+ bool saw_date64 = false;
for (const auto& descr : descrs) {
auto id = descr.type->id();
// a common timestamp is only possible if all types are timestamp like
switch (id) {
case Type::DATE32:
+ // Date32's unit is days, but the coarsest we have is seconds
+ saw_date32 = true;
+ continue;
case Type::DATE64:
+ finest_unit = std::max(finest_unit, TimeUnit::MILLI);
+ saw_date64 = true;
continue;
case Type::TIMESTAMP: {
const auto& ty = checked_cast(*descr.type);
@@ -189,8 +196,14 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs)
}
}
- // Don't cast if we see only dates, no timestamps
- return timezone ? timestamp(finest_unit, *timezone) : nullptr;
+ if (timezone) {
+ // At least one timestamp seen
+ return timestamp(finest_unit, *timezone);
+ } else if (saw_date32 && saw_date64) {
+ // Saw mixed date types
+ return date64();
+ }
+ return nullptr;
}
std::shared_ptr CommonBinary(const std::vector& descrs) {
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
index 9ba5377de08..6b19837f3ac 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
@@ -26,43 +26,63 @@ namespace arrow {
namespace compute {
namespace internal {
+TEST(TestDispatchBest, CastBinaryDecimalArgs) {
+ std::vector args;
+ std::vector modes = {
+ DecimalPromotion::kAdd, DecimalPromotion::kMultiply, DecimalPromotion::kDivide};
+
+ // Any float -> all float
+ for (auto mode : modes) {
+ args = {decimal128(3, 2), float64()};
+ ASSERT_OK(CastBinaryDecimalArgs(mode, &args));
+ AssertTypeEqual(args[0].type, float64());
+ AssertTypeEqual(args[1].type, float64());
+ }
+
+ // Integer -> decimal with common scale
+ args = {decimal128(1, 0), int64()};
+ ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
+ AssertTypeEqual(args[0].type, decimal128(1, 0));
+ AssertTypeEqual(args[1].type, decimal128(19, 0));
+}
+
TEST(TestDispatchBest, CastDecimalArgs) {
std::vector args;
// Any float -> all float
args = {decimal128(3, 2), float64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
- AssertTypeEqual(*args[0].type, *float64());
- AssertTypeEqual(*args[1].type, *float64());
+ AssertTypeEqual(args[0].type, float64());
+ AssertTypeEqual(args[1].type, float64());
// Promote to common decimal width
args = {decimal128(3, 2), decimal256(3, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
- AssertTypeEqual(*args[0].type, *decimal256(3, 2));
- AssertTypeEqual(*args[1].type, *decimal256(3, 2));
+ AssertTypeEqual(args[0].type, decimal256(3, 2));
+ AssertTypeEqual(args[1].type, decimal256(3, 2));
// Rescale so all have common scale/precision
args = {decimal128(3, 2), decimal128(3, 0)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
- AssertTypeEqual(*args[0].type, *decimal128(5, 2));
- AssertTypeEqual(*args[1].type, *decimal128(5, 2));
+ AssertTypeEqual(args[0].type, decimal128(5, 2));
+ AssertTypeEqual(args[1].type, decimal128(5, 2));
// Integer -> decimal with appropriate precision
args = {decimal128(3, 0), int64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
- AssertTypeEqual(*args[0].type, *decimal128(19, 0));
- AssertTypeEqual(*args[1].type, *decimal128(19, 0));
+ AssertTypeEqual(args[0].type, decimal128(19, 0));
+ AssertTypeEqual(args[1].type, decimal128(19, 0));
args = {decimal128(3, 1), int64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
- AssertTypeEqual(*args[0].type, *decimal128(20, 1));
- AssertTypeEqual(*args[1].type, *decimal128(20, 1));
+ AssertTypeEqual(args[0].type, decimal128(20, 1));
+ AssertTypeEqual(args[1].type, decimal128(20, 1));
// Overflow decimal128 max precision -> promote to decimal256
args = {decimal128(38, 0), decimal128(37, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
- AssertTypeEqual(*args[0].type, *decimal256(40, 2));
- AssertTypeEqual(*args[1].type, *decimal256(40, 2));
+ AssertTypeEqual(args[0].type, decimal256(40, 2));
+ AssertTypeEqual(args[1].type, decimal256(40, 2));
}
TEST(TestDispatchBest, CommonTimestamp) {
@@ -74,7 +94,10 @@ TEST(TestDispatchBest, CommonTimestamp) {
timestamp(TimeUnit::NANO, "UTC")}));
AssertTypeEqual(timestamp(TimeUnit::NANO),
CommonTimestamp({date32(), timestamp(TimeUnit::NANO)}));
- ASSERT_EQ(nullptr, CommonTimestamp({date32(), date64()}));
+ AssertTypeEqual(timestamp(TimeUnit::MILLI),
+ CommonTimestamp({date64(), timestamp(TimeUnit::SECOND)}));
+ AssertTypeEqual(date64(), CommonTimestamp({date32(), date64()}));
+ ASSERT_EQ(nullptr, CommonTimestamp({date32(), date32()}));
ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND, "UTC")}));
ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND, "America/Phoenix"),
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index 35b734e29f6..3aa345bd9da 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -1622,31 +1622,29 @@ TEST(TestBinaryDecimalArithmetic, DispatchBest) {
}
}
- // decimal, integer
- for (std::string name : {"add", "subtract", "multiply", "divide"}) {
+ // decimal, decimal -> decimal and decimal, integer -> decimal
+ for (std::string name : {"add", "subtract"}) {
for (std::string suffix : {"", "_checked"}) {
name += suffix;
CheckDispatchBest(name, {int64(), decimal128(1, 0)},
- {decimal128(1, 0), decimal128(1, 0)});
+ {decimal128(19, 0), decimal128(1, 0)});
CheckDispatchBest(name, {decimal128(1, 0), int64()},
- {decimal128(1, 0), decimal128(1, 0)});
- }
- }
-
- // decimal, decimal
- for (std::string name : {"add", "subtract"}) {
- for (std::string suffix : {"", "_checked"}) {
- name += suffix;
+ {decimal128(1, 0), decimal128(19, 0)});
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
- {decimal128(3, 1), decimal128(3, 1)});
+ {decimal128(2, 1), decimal128(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
- {decimal256(3, 1), decimal256(3, 1)});
+ {decimal256(2, 1), decimal256(2, 1)});
CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
- {decimal256(3, 1), decimal256(3, 1)});
+ {decimal256(2, 1), decimal256(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
- {decimal256(3, 1), decimal256(3, 1)});
+ {decimal256(2, 1), decimal256(2, 1)});
+
+ CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
+ {decimal128(3, 1), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
+ {decimal128(2, 1), decimal128(3, 1)});
}
}
{
@@ -1654,29 +1652,50 @@ TEST(TestBinaryDecimalArithmetic, DispatchBest) {
for (std::string suffix : {"", "_checked"}) {
name += suffix;
+ CheckDispatchBest(name, {int64(), decimal128(1, 0)},
+ {decimal128(19, 0), decimal128(1, 0)});
+ CheckDispatchBest(name, {decimal128(1, 0), int64()},
+ {decimal128(1, 0), decimal128(19, 0)});
+
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
- {decimal128(5, 2), decimal128(5, 2)});
+ {decimal128(2, 1), decimal128(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
- {decimal256(5, 2), decimal256(5, 2)});
+ {decimal256(2, 1), decimal256(2, 1)});
CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
- {decimal256(5, 2), decimal256(5, 2)});
+ {decimal256(2, 1), decimal256(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
- {decimal256(5, 2), decimal256(5, 2)});
+ {decimal256(2, 1), decimal256(2, 1)});
+
+ CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
+ {decimal128(2, 0), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
+ {decimal128(2, 1), decimal128(2, 0)});
}
}
{
std::string name = "divide";
for (std::string suffix : {"", "_checked"}) {
name += suffix;
+ SCOPED_TRACE(name);
+
+ CheckDispatchBest(name, {int64(), decimal128(1, 0)},
+ {decimal128(23, 4), decimal128(1, 0)});
+ CheckDispatchBest(name, {decimal128(1, 0), int64()},
+ {decimal128(21, 20), decimal128(19, 0)});
CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)},
- {decimal128(6, 4), decimal128(6, 4)});
+ {decimal128(6, 5), decimal128(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)},
- {decimal256(6, 4), decimal256(6, 4)});
+ {decimal256(6, 5), decimal256(2, 1)});
CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)},
- {decimal256(6, 4), decimal256(6, 4)});
+ {decimal256(6, 5), decimal256(2, 1)});
CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)},
- {decimal256(6, 4), decimal256(6, 4)});
+ {decimal256(6, 5), decimal256(2, 1)});
+
+ CheckDispatchBest(name, {decimal128(2, 0), decimal128(2, 1)},
+ {decimal128(7, 5), decimal128(2, 1)});
+ CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 0)},
+ {decimal128(5, 4), decimal128(2, 0)});
}
}
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index 17eae5adbbd..9751127f833 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -159,6 +159,9 @@ struct CompareFunction : ScalarFunction {
Result DispatchBest(std::vector* values) const override {
RETURN_NOT_OK(CheckArity(*values));
+ if (HasDecimal(*values)) {
+ RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values));
+ }
using arrow::compute::detail::DispatchExactImpl;
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
@@ -172,8 +175,6 @@ struct CompareFunction : ScalarFunction {
ReplaceTypes(type, values);
} else if (auto type = CommonBinary(*values)) {
ReplaceTypes(type, values);
- } else if (HasDecimal(*values)) {
- RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values));
}
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
index a5bc89d87f3..3d9eb018e72 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -29,6 +29,7 @@
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
#include "arrow/testing/random.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
@@ -601,9 +602,9 @@ TEST(TestCompareKernel, DispatchBest) {
CheckDispatchBest(name, {decimal128(3, 2), float64()}, {float64(), float64()});
CheckDispatchBest(name, {float64(), decimal128(3, 2)}, {float64(), float64()});
CheckDispatchBest(name, {decimal128(3, 2), int64()},
- {decimal128(3, 2), decimal128(3, 2)});
+ {decimal128(3, 2), decimal128(21, 2)});
CheckDispatchBest(name, {int64(), decimal128(3, 2)},
- {decimal128(3, 2), decimal128(3, 2)});
+ {decimal128(21, 2), decimal128(3, 2)});
}
}
@@ -1083,34 +1084,21 @@ TYPED_TEST(TestVarArgsCompareParametricTemporal, MaxElementWise) {
}
TEST(TestMaxElementWiseMinElementWise, CommonTimestamp) {
- {
- auto t1 = std::make_shared(TimeUnit::SECOND);
- auto t2 = std::make_shared(TimeUnit::MILLI);
- auto expected = MakeScalar(t2, 1000).ValueOrDie();
- ASSERT_OK_AND_ASSIGN(auto actual,
- MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()),
- Datum(MakeScalar(t2, 12000).ValueOrDie())}));
- AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true);
- }
- {
- auto t1 = std::make_shared();
- auto t2 = std::make_shared(TimeUnit::SECOND);
- auto expected = MakeScalar(t2, 86401).ValueOrDie();
- ASSERT_OK_AND_ASSIGN(auto actual,
- MaxElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()),
- Datum(MakeScalar(t2, 86401).ValueOrDie())}));
- AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true);
- }
- {
- auto t1 = std::make_shared();
- auto t2 = std::make_shared();
- auto t3 = std::make_shared(TimeUnit::SECOND);
- auto expected = MakeScalar(t3, 86400).ValueOrDie();
- ASSERT_OK_AND_ASSIGN(
- auto actual, MinElementWise({Datum(MakeScalar(t1, 1).ValueOrDie()),
- Datum(MakeScalar(t2, 2 * 86400000).ValueOrDie())}));
- AssertScalarsEqual(*expected, *actual.scalar(), /*verbose=*/true);
- }
+ EXPECT_THAT(MinElementWise({
+ ScalarFromJSON(timestamp(TimeUnit::SECOND), "1"),
+ ScalarFromJSON(timestamp(TimeUnit::MILLI), "12000"),
+ }),
+ ResultWith(ScalarFromJSON(timestamp(TimeUnit::MILLI), "1000")));
+ EXPECT_THAT(MaxElementWise({
+ ScalarFromJSON(date32(), "1"),
+ ScalarFromJSON(timestamp(TimeUnit::SECOND), "86401"),
+ }),
+ ResultWith(ScalarFromJSON(timestamp(TimeUnit::SECOND), "86401")));
+ EXPECT_THAT(MinElementWise({
+ ScalarFromJSON(date32(), "1"),
+ ScalarFromJSON(date64(), "172800000"),
+ }),
+ ResultWith(ScalarFromJSON(date64(), "86400000")));
}
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc
index cedc03698a1..e72c3dce294 100644
--- a/cpp/src/arrow/compute/kernels/test_util.cc
+++ b/cpp/src/arrow/compute/kernels/test_util.cc
@@ -344,6 +344,12 @@ void CheckDispatchBest(std::string func_name, std::vector original_v
<< actual_kernel->signature->ToString() << "\n"
<< " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => "
<< expected_kernel->signature->ToString();
+ EXPECT_EQ(values.size(), expected_equivalent_values.size());
+ for (size_t i = 0; i < values.size(); i++) {
+ EXPECT_EQ(values[i].shape, expected_equivalent_values[i].shape)
+ << "Argument " << i << " should have the same shape";
+ AssertTypeEqual(values[i].type, expected_equivalent_values[i].type);
+ }
}
void CheckDispatchFails(std::string func_name, std::vector values) {
From 56a975ceb9fefbbd9be04a898cc708bbf0ecc958 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 17 Sep 2021 10:41:00 -0400
Subject: [PATCH 07/11] ARROW-13390: [C++] Fix lint error
---
cpp/src/arrow/compute/kernels/CMakeLists.txt | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index cd6bc19c869..28686a9cafa 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -75,6 +75,4 @@ add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute")
# ----------------------------------------------------------------------
# Utilities
-add_arrow_compute_test(kernel_utility_test
- SOURCES
- codegen_internal_test.cc)
+add_arrow_compute_test(kernel_utility_test SOURCES codegen_internal_test.cc)
From 30585595c88bd25a07ff95a204cf2a5a3909a2ac Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 21 Sep 2021 09:43:50 -0400
Subject: [PATCH 08/11] ARROW-13390: [C++] Address review suggestions
---
.../arrow/compute/kernels/codegen_internal.cc | 21 ++++--
.../arrow/compute/kernels/codegen_internal.h | 2 +-
.../compute/kernels/codegen_internal_test.cc | 73 +++++++++++++++----
.../arrow/compute/kernels/scalar_compare.cc | 4 +-
.../compute/kernels/scalar_compare_test.cc | 2 +-
.../arrow/compute/kernels/scalar_if_else.cc | 61 +---------------
.../compute/kernels/scalar_if_else_test.cc | 15 ++--
cpp/src/arrow/util/basic_decimal.cc | 8 ++
8 files changed, 98 insertions(+), 88 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index f5850c80765..8a8292c1c89 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -164,7 +164,7 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) {
return int8();
}
-std::shared_ptr CommonTimestamp(const std::vector& descrs) {
+std::shared_ptr CommonTemporal(const std::vector& descrs) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
const std::string* timezone = nullptr;
bool saw_date32 = false;
@@ -199,9 +199,10 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs)
if (timezone) {
// At least one timestamp seen
return timestamp(finest_unit, *timezone);
- } else if (saw_date32 && saw_date64) {
- // Saw mixed date types
+ } else if (saw_date64) {
return date64();
+ } else if (saw_date32) {
+ return date32();
}
return nullptr;
}
@@ -321,12 +322,12 @@ Status CastDecimalArgs(ValueDescr* begin, size_t count) {
auto* end = begin + count;
int32_t max_scale = 0;
+ bool any_floating = false;
for (auto* it = begin; it != end; ++it) {
const auto& ty = *it->type;
if (is_floating(ty.id())) {
// Decimal + float = float
- ReplaceTypes(float64(), begin, count);
- return Status::OK();
+ any_floating = true;
} else if (is_integer(ty.id())) {
// Nothing to do here
} else if (is_decimal(ty.id())) {
@@ -339,6 +340,10 @@ Status CastDecimalArgs(ValueDescr* begin, size_t count) {
return Status::OK();
}
}
+ if (any_floating) {
+ ReplaceTypes(float64(), begin, count);
+ return Status::OK();
+ }
// All integer and decimal, rescale
int32_t common_precision = 0;
@@ -357,7 +362,11 @@ Status CastDecimalArgs(ValueDescr* begin, size_t count) {
}
}
- if (common_precision > BasicDecimal128::kMaxPrecision) {
+ if (common_precision > BasicDecimal256::kMaxPrecision) {
+ return Status::Invalid("Result precision (", common_precision,
+ ") exceeds max precision of Decimal256 (",
+ BasicDecimal256::kMaxPrecision, ")");
+ } else if (common_precision > BasicDecimal128::kMaxPrecision) {
casted_type_id = Type::DECIMAL256;
}
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 0648ac61e92..5bd65d8c53c 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1289,7 +1289,7 @@ ARROW_EXPORT
std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count);
ARROW_EXPORT
-std::shared_ptr CommonTimestamp(const std::vector& descrs);
+std::shared_ptr CommonTemporal(const std::vector& descrs);
ARROW_EXPORT
std::shared_ptr CommonBinary(const std::vector& descrs);
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
index 6b19837f3ac..89fc35b38d2 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+#include
#include
#include "arrow/compute/kernels/codegen_internal.h"
@@ -44,6 +45,12 @@ TEST(TestDispatchBest, CastBinaryDecimalArgs) {
ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
AssertTypeEqual(args[0].type, decimal128(1, 0));
AssertTypeEqual(args[1].type, decimal128(19, 0));
+
+ // Add: rescale so all have common scale
+ args = {decimal128(3, 2), decimal128(3, -2)};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Decimals with negative scales not supported"),
+ CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
}
TEST(TestDispatchBest, CastDecimalArgs) {
@@ -55,6 +62,12 @@ TEST(TestDispatchBest, CastDecimalArgs) {
AssertTypeEqual(args[0].type, float64());
AssertTypeEqual(args[1].type, float64());
+ args = {float32(), float64(), decimal128(3, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, float64());
+ AssertTypeEqual(args[1].type, float64());
+ AssertTypeEqual(args[2].type, float64());
+
// Promote to common decimal width
args = {decimal128(3, 2), decimal256(3, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
@@ -67,6 +80,17 @@ TEST(TestDispatchBest, CastDecimalArgs) {
AssertTypeEqual(args[0].type, decimal128(5, 2));
AssertTypeEqual(args[1].type, decimal128(5, 2));
+ args = {decimal128(3, 2), decimal128(3, -2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(7, 2));
+ AssertTypeEqual(args[1].type, decimal128(7, 2));
+
+ args = {decimal128(3, 0), decimal128(3, 1), decimal128(3, 2)};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(5, 2));
+ AssertTypeEqual(args[1].type, decimal128(5, 2));
+ AssertTypeEqual(args[2].type, decimal128(5, 2));
+
// Integer -> decimal with appropriate precision
args = {decimal128(3, 0), int64()};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
@@ -78,30 +102,51 @@ TEST(TestDispatchBest, CastDecimalArgs) {
AssertTypeEqual(args[0].type, decimal128(20, 1));
AssertTypeEqual(args[1].type, decimal128(20, 1));
+ args = {decimal128(3, -1), int64()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal128(19, 0));
+ AssertTypeEqual(args[1].type, decimal128(19, 0));
+
// Overflow decimal128 max precision -> promote to decimal256
args = {decimal128(38, 0), decimal128(37, 2)};
ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
AssertTypeEqual(args[0].type, decimal256(40, 2));
AssertTypeEqual(args[1].type, decimal256(40, 2));
+
+ // Overflow decimal256 max precision
+ args = {decimal256(76, 0), decimal256(75, 1)};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Result precision (77) exceeds max precision of Decimal256 (76)"),
+ CastDecimalArgs(args.data(), args.size()));
+
+ // Incompatible, no cast
+ args = {decimal256(3, 2), float64(), utf8()};
+ ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+ AssertTypeEqual(args[0].type, decimal256(3, 2));
+ AssertTypeEqual(args[1].type, float64());
+ AssertTypeEqual(args[2].type, utf8());
}
-TEST(TestDispatchBest, CommonTimestamp) {
- AssertTypeEqual(
- timestamp(TimeUnit::NANO),
- CommonTimestamp({timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)}));
+TEST(TestDispatchBest, CommonTemporal) {
+ AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal({timestamp(TimeUnit::SECOND),
+ timestamp(TimeUnit::NANO)}));
AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
- CommonTimestamp({timestamp(TimeUnit::SECOND, "UTC"),
- timestamp(TimeUnit::NANO, "UTC")}));
+ CommonTemporal({timestamp(TimeUnit::SECOND, "UTC"),
+ timestamp(TimeUnit::NANO, "UTC")}));
AssertTypeEqual(timestamp(TimeUnit::NANO),
- CommonTimestamp({date32(), timestamp(TimeUnit::NANO)}));
+ CommonTemporal({date32(), timestamp(TimeUnit::NANO)}));
AssertTypeEqual(timestamp(TimeUnit::MILLI),
- CommonTimestamp({date64(), timestamp(TimeUnit::SECOND)}));
- AssertTypeEqual(date64(), CommonTimestamp({date32(), date64()}));
- ASSERT_EQ(nullptr, CommonTimestamp({date32(), date32()}));
- ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND),
- timestamp(TimeUnit::SECOND, "UTC")}));
- ASSERT_EQ(nullptr, CommonTimestamp({timestamp(TimeUnit::SECOND, "America/Phoenix"),
- timestamp(TimeUnit::SECOND, "UTC")}));
+ CommonTemporal({date64(), timestamp(TimeUnit::SECOND)}));
+ AssertTypeEqual(date32(), CommonTemporal({date32(), date32()}));
+ AssertTypeEqual(date64(), CommonTemporal({date64(), date64()}));
+ AssertTypeEqual(date64(), CommonTemporal({date32(), date64()}));
+ ASSERT_EQ(nullptr, CommonTemporal({}));
+ ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND),
+ timestamp(TimeUnit::SECOND, "UTC")}));
+ ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"),
+ timestamp(TimeUnit::SECOND, "UTC")}));
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index 9751127f833..c42b0ddac24 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -171,7 +171,7 @@ struct CompareFunction : ScalarFunction {
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
- } else if (auto type = CommonTimestamp(*values)) {
+ } else if (auto type = CommonTemporal(*values)) {
ReplaceTypes(type, values);
} else if (auto type = CommonBinary(*values)) {
ReplaceTypes(type, values);
@@ -195,7 +195,7 @@ struct VarArgsCompareFunction : ScalarFunction {
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
- } else if (auto type = CommonTimestamp(*values)) {
+ } else if (auto type = CommonTemporal(*values)) {
ReplaceTypes(type, values);
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
index 3d9eb018e72..dae89a3f518 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -1083,7 +1083,7 @@ TYPED_TEST(TestVarArgsCompareParametricTemporal, MaxElementWise) {
{this->array("[1, null, 3, 4]"), this->array("[2, 2, null, 2]")});
}
-TEST(TestMaxElementWiseMinElementWise, CommonTimestamp) {
+TEST(TestMaxElementWiseMinElementWise, CommonTemporal) {
EXPECT_THAT(MinElementWise({
ScalarFromJSON(timestamp(TimeUnit::SECOND), "1"),
ScalarFromJSON(timestamp(TimeUnit::MILLI), "12000"),
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index a5ff191a7b4..bab97c8dc2a 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1745,7 +1745,7 @@ struct CoalesceFunction : ScalarFunction {
if (auto type = CommonBinary(*values)) {
ReplaceTypes(type, values);
}
- if (auto type = CommonTimestamp(*values)) {
+ if (auto type = CommonTemporal(*values)) {
ReplaceTypes(type, values);
}
if (HasDecimal(*values)) {
@@ -2091,7 +2091,7 @@ static Status CheckIdenticalTypes(const ExecBatch& batch) {
auto ty = batch[0].type();
for (auto it = batch.values.begin() + 1; it != batch.values.end(); ++it) {
if (!ty->Equals(*it->type())) {
- return Status::TypeError("coalesce: all types must be identical, expected: ", *ty,
+ return Status::TypeError("coalesce: all types must be compatible, expected: ", *ty,
", but got: ", *it->type());
}
}
@@ -2201,62 +2201,9 @@ struct CoalesceFunctor> {
}
};
-template <>
-struct CoalesceFunctor {
- static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- RETURN_NOT_OK(CheckIdenticalTypes(batch));
- for (const auto& datum : batch.values) {
- if (datum.is_array()) {
- return ExecArray(ctx, batch, out);
- }
- }
- return ExecScalarCoalesce(ctx, batch, out);
- }
-
- static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- std::function reserve_data = ReserveNoData;
- return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
- }
-};
-
template
-struct CoalesceFunctor> {
- static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- RETURN_NOT_OK(CheckIdenticalTypes(batch));
- for (const auto& datum : batch.values) {
- if (datum.is_array()) {
- return ExecArray(ctx, batch, out);
- }
- }
- return ExecScalarCoalesce(ctx, batch, out);
- }
-
- static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- std::function reserve_data = ReserveNoData;
- return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
- }
-};
-
-template <>
-struct CoalesceFunctor {
- static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- RETURN_NOT_OK(CheckIdenticalTypes(batch));
- for (const auto& datum : batch.values) {
- if (datum.is_array()) {
- return ExecArray(ctx, batch, out);
- }
- }
- return ExecScalarCoalesce(ctx, batch, out);
- }
-
- static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- std::function reserve_data = ReserveNoData;
- return ExecVarWidthCoalesce(ctx, batch, out, reserve_data);
- }
-};
-
-template <>
-struct CoalesceFunctor {
+struct CoalesceFunctor<
+ Type, enable_if_t::value && !is_union_type::value>> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
RETURN_NOT_OK(CheckIdenticalTypes(batch));
for (const auto& datum : batch.values) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 84c0404c049..52e145803dd 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -1916,7 +1916,7 @@ TYPED_TEST(TestCoalesceList, Errors) {
auto type1 = std::make_shared(int64());
auto type2 = std::make_shared(utf8());
EXPECT_RAISES_WITH_MESSAGE_THAT(
- TypeError, ::testing::HasSubstr("coalesce: all types must be identical"),
+ TypeError, ::testing::HasSubstr("coalesce: all types must be compatible"),
CallFunction("coalesce", {
ArrayFromJSON(type1, "[null]"),
ArrayFromJSON(type2, "[null]"),
@@ -2054,7 +2054,7 @@ TEST(TestCoalesce, FixedSizeBinary) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("coalesce: all types must be identical, expected: "
+ ::testing::HasSubstr("coalesce: all types must be compatible, expected: "
"fixed_size_binary[3], but got: fixed_size_binary[2]"),
CallFunction("coalesce", {
ArrayFromJSON(type, "[null]"),
@@ -2092,7 +2092,7 @@ TEST(TestCoalesce, FixedSizeListOfInt) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr(
- "coalesce: all types must be identical, expected: fixed_size_list[2], but got: fixed_size_list[3]"),
CallFunction("coalesce", {
ArrayFromJSON(type, "[null]"),
@@ -2168,7 +2168,7 @@ TEST(TestCoalesce, Map) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("coalesce: all types must be identical, expected: map, but got: map"),
CallFunction("coalesce", {
ArrayFromJSON(type, "[null]"),
@@ -2210,8 +2210,9 @@ TEST(TestCoalesce, Struct) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("coalesce: all types must be identical, expected: struct, but got: struct"),
+ ::testing::HasSubstr(
+ "coalesce: all types must be compatible, expected: struct, but got: struct"),
CallFunction("coalesce",
{
ArrayFromJSON(struct_({field("str", utf8())}), "[null]"),
@@ -2250,7 +2251,7 @@ TEST(TestCoalesce, UnionBoolString) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("coalesce: all types must be identical, expected: "
+ ::testing::HasSubstr("coalesce: all types must be compatible, expected: "
"sparse_union, but got: sparse_union"),
CallFunction(
"coalesce",
diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc
index edc25e25db8..24c193dff9f 100644
--- a/cpp/src/arrow/util/basic_decimal.cc
+++ b/cpp/src/arrow/util/basic_decimal.cc
@@ -377,6 +377,10 @@ BasicDecimal128::BasicDecimal128(const uint8_t* bytes)
reinterpret_cast(bytes)[1]) {}
#endif
+constexpr int BasicDecimal128::kBitWidth;
+constexpr int BasicDecimal128::kMaxPrecision;
+constexpr int BasicDecimal128::kMaxScale;
+
std::array BasicDecimal128::ToBytes() const {
std::array out{{0}};
ToBytes(out.data());
@@ -1153,6 +1157,10 @@ BasicDecimal256::BasicDecimal256(const uint8_t* bytes)
reinterpret_cast(bytes)[2],
reinterpret_cast(bytes)[3]}) {}
+constexpr int BasicDecimal256::kBitWidth;
+constexpr int BasicDecimal256::kMaxPrecision;
+constexpr int BasicDecimal256::kMaxScale;
+
BasicDecimal256& BasicDecimal256::Negate() {
auto array_le = BitUtil::LittleEndianArray::Make(&array_);
uint64_t carry = 1;
From 90c6a102e22138f08a9890503881761069c32c81 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 22 Sep 2021 14:44:40 -0400
Subject: [PATCH 09/11] Update
cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Co-authored-by: Benjamin Kietzman
---
cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index 3aa345bd9da..bd3fecfd6d9 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -1622,7 +1622,8 @@ TEST(TestBinaryDecimalArithmetic, DispatchBest) {
}
}
- // decimal, decimal -> decimal and decimal, integer -> decimal
+ // decimal, decimal -> decimal
+ // decimal, integer -> decimal
for (std::string name : {"add", "subtract"}) {
for (std::string suffix : {"", "_checked"}) {
name += suffix;
From 3a519e4b0460e0181a09c20c68520dff8b61fb95 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 22 Sep 2021 14:44:58 -0400
Subject: [PATCH 10/11] ARROW-13390: [C++] Add one last test case
---
cpp/src/arrow/compute/kernels/codegen_internal_test.cc | 1 +
1 file changed, 1 insertion(+)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
index 89fc35b38d2..a830d0c7636 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
@@ -143,6 +143,7 @@ TEST(TestDispatchBest, CommonTemporal) {
AssertTypeEqual(date64(), CommonTemporal({date64(), date64()}));
AssertTypeEqual(date64(), CommonTemporal({date32(), date64()}));
ASSERT_EQ(nullptr, CommonTemporal({}));
+ ASSERT_EQ(nullptr, CommonTemporal({float64(), int32()}));
ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND, "UTC")}));
ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"),
From 63b2fa66549b5c6375c87db238372c32f4773a0e Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 23 Sep 2021 14:00:07 -0400
Subject: [PATCH 11/11] ARROW-13390: [C++] Reconcile with ARROW-13358
---
.../arrow/compute/kernels/scalar_if_else.cc | 32 ++++++++++---------
.../compute/kernels/scalar_if_else_test.cc | 15 ++++-----
2 files changed, 24 insertions(+), 23 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index bab97c8dc2a..3f3933256f4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -63,6 +63,20 @@ inline Bitmap GetBitmap(const Datum& datum, int i) {
return Bitmap{a.buffers[i], a.offset, a.length};
}
+// Ensure parameterized types are identical.
+Status CheckIdenticalTypes(const Datum* begin, size_t count) {
+ const auto& ty = begin->type();
+ const auto* end = begin + count;
+ for (auto it = begin + 1; it != end; ++it) {
+ const DataType& other_ty = *it->type();
+ if (!ty->Equals(other_ty)) {
+ return Status::TypeError("All types must be compatible, expected: ", *ty,
+ ", but got: ", other_ty);
+ }
+ }
+ return Status::OK();
+}
+
// if the condition is null then output is null otherwise we take validity from the
// selected argument
// ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid)
@@ -2086,23 +2100,11 @@ static Status ExecVarWidthCoalesce(KernelContext* ctx, const ExecBatch& batch, D
});
}
-// Ensure parameterized types are identical.
-static Status CheckIdenticalTypes(const ExecBatch& batch) {
- auto ty = batch[0].type();
- for (auto it = batch.values.begin() + 1; it != batch.values.end(); ++it) {
- if (!ty->Equals(*it->type())) {
- return Status::TypeError("coalesce: all types must be compatible, expected: ", *ty,
- ", but got: ", *it->type());
- }
- }
- return Status::OK();
-}
-
template
struct CoalesceFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (!TypeTraits::is_parameter_free) {
- RETURN_NOT_OK(CheckIdenticalTypes(batch));
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size()));
}
// Special case for two arguments (since "fill_null" is a common operation)
if (batch.num_values() == 2) {
@@ -2205,7 +2207,7 @@ template
struct CoalesceFunctor<
Type, enable_if_t::value && !is_union_type::value>> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- RETURN_NOT_OK(CheckIdenticalTypes(batch));
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size()));
for (const auto& datum : batch.values) {
if (datum.is_array()) {
return ExecArray(ctx, batch, out);
@@ -2224,7 +2226,7 @@ template
struct CoalesceFunctor> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
// Unions don't have top-level nulls, so a specialized implementation is needed
- RETURN_NOT_OK(CheckIdenticalTypes(batch));
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size()));
for (const auto& datum : batch.values) {
if (datum.is_array()) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 52e145803dd..907c7b0e638 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -1916,7 +1916,7 @@ TYPED_TEST(TestCoalesceList, Errors) {
auto type1 = std::make_shared(int64());
auto type2 = std::make_shared(utf8());
EXPECT_RAISES_WITH_MESSAGE_THAT(
- TypeError, ::testing::HasSubstr("coalesce: all types must be compatible"),
+ TypeError, ::testing::HasSubstr("All types must be compatible"),
CallFunction("coalesce", {
ArrayFromJSON(type1, "[null]"),
ArrayFromJSON(type2, "[null]"),
@@ -2054,7 +2054,7 @@ TEST(TestCoalesce, FixedSizeBinary) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("coalesce: all types must be compatible, expected: "
+ ::testing::HasSubstr("All types must be compatible, expected: "
"fixed_size_binary[3], but got: fixed_size_binary[2]"),
CallFunction("coalesce", {
ArrayFromJSON(type, "[null]"),
@@ -2092,7 +2092,7 @@ TEST(TestCoalesce, FixedSizeListOfInt) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr(
- "coalesce: all types must be compatible, expected: fixed_size_list[2], but got: fixed_size_list[3]"),
CallFunction("coalesce", {
ArrayFromJSON(type, "[null]"),
@@ -2168,7 +2168,7 @@ TEST(TestCoalesce, Map) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("coalesce: all types must be compatible, expected: map, but got: map"),
CallFunction("coalesce", {
ArrayFromJSON(type, "[null]"),
@@ -2210,9 +2210,8 @@ TEST(TestCoalesce, Struct) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr(
- "coalesce: all types must be compatible, expected: struct, but got: struct"),
+ ::testing::HasSubstr("All types must be compatible, expected: struct, but got: struct"),
CallFunction("coalesce",
{
ArrayFromJSON(struct_({field("str", utf8())}), "[null]"),
@@ -2251,7 +2250,7 @@ TEST(TestCoalesce, UnionBoolString) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("coalesce: all types must be compatible, expected: "
+ ::testing::HasSubstr("All types must be compatible, expected: "
"sparse_union, but got: sparse_union"),
CallFunction(
"coalesce",