diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 5a532e17519..480c5f1c649 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -377,6 +377,7 @@ TEST_F(TestArray, TestMakeArrayOfNullUnion) { const int64_t union_length = 10; auto s_union_ty = sparse_union({field("a", utf8()), field("b", int32())}, {0, 1}); ASSERT_OK_AND_ASSIGN(auto s_union_nulls, MakeArrayOfNull(s_union_ty, union_length)); + ASSERT_OK(s_union_nulls->ValidateFull()); ASSERT_EQ(s_union_nulls->null_count(), 0); { const auto& typed_union = checked_cast(*s_union_nulls); @@ -388,8 +389,23 @@ TEST_F(TestArray, TestMakeArrayOfNullUnion) { } } + s_union_ty = sparse_union({field("a", utf8()), field("b", int32())}, {2, 7}); + ASSERT_OK_AND_ASSIGN(s_union_nulls, MakeArrayOfNull(s_union_ty, union_length)); + ASSERT_OK(s_union_nulls->ValidateFull()); + ASSERT_EQ(s_union_nulls->null_count(), 0); + { + const auto& typed_union = checked_cast(*s_union_nulls); + ASSERT_EQ(typed_union.field(0)->null_count(), union_length); + + // Check type codes are all 2 + for (int i = 0; i < union_length; ++i) { + ASSERT_EQ(typed_union.raw_type_codes()[i], 2); + } + } + auto d_union_ty = dense_union({field("a", utf8()), field("b", int32())}, {0, 1}); ASSERT_OK_AND_ASSIGN(auto d_union_nulls, MakeArrayOfNull(d_union_ty, union_length)); + ASSERT_OK(d_union_nulls->ValidateFull()); ASSERT_EQ(d_union_nulls->null_count(), 0); { const auto& typed_union = checked_cast(*d_union_nulls); @@ -484,23 +500,30 @@ static ScalarVector GetScalars() { const auto dense_union_ty = ::arrow::dense_union(union_fields, union_type_codes); return { - std::make_shared(false), std::make_shared(3), - std::make_shared(3), std::make_shared(3), - std::make_shared(3), std::make_shared(3.0), - std::make_shared(10), std::make_shared(11), + std::make_shared(false), + std::make_shared(3), + std::make_shared(3), + std::make_shared(3), + std::make_shared(3), + std::make_shared(3.0), + std::make_shared(10), + std::make_shared(11), std::make_shared(1000, time32(TimeUnit::SECOND)), std::make_shared(1111, time64(TimeUnit::MICRO)), std::make_shared(1111, timestamp(TimeUnit::MILLI)), std::make_shared(1), std::make_shared(daytime), std::make_shared(60, duration(TimeUnit::SECOND)), - std::make_shared(hello), std::make_shared(hello), + std::make_shared(hello), + std::make_shared(hello), std::make_shared( hello, fixed_size_binary(static_cast(hello->size()))), std::make_shared(Decimal128(10), decimal(16, 4)), std::make_shared(Decimal256(10), decimal(76, 38)), - std::make_shared(hello), std::make_shared(hello), + std::make_shared(hello), + std::make_shared(hello), std::make_shared(ArrayFromJSON(int8(), "[1, 2, 3]")), + ScalarFromJSON(map(int8(), utf8()), R"([[1, "foo"], [2, "bar"]])"), std::make_shared(ArrayFromJSON(int8(), "[1, 1, 2, 2, 3, 3]")), std::make_shared(ArrayFromJSON(int8(), "[1, 2, 3, 4]")), std::make_shared( @@ -517,7 +540,12 @@ static ScalarVector GetScalars() { std::make_shared(std::make_shared(101), 6, dense_union_ty), std::make_shared(std::make_shared(101), 42, - dense_union_ty)}; + dense_union_ty), + DictionaryScalar::Make(ScalarFromJSON(int8(), "1"), + ArrayFromJSON(utf8(), R"(["foo", "bar"])")), + DictionaryScalar::Make(ScalarFromJSON(uint8(), "1"), + ArrayFromJSON(utf8(), R"(["foo", "bar"])")), + }; } TEST_F(TestArray, TestMakeArrayFromScalar) { @@ -544,6 +572,8 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { } for (auto scalar : scalars) { + // TODO(ARROW-13197): appending dictionary scalars not implemented + if (is_dictionary(scalar->type->id())) continue; AssertAppendScalar(pool_, scalar); } } @@ -598,6 +628,77 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) { AssertAppendScalar(pool_, std::make_shared(scalar)); } +TEST_F(TestArray, TestAppendArraySlice) { + auto scalars = GetScalars(); + for (const auto& scalar : scalars) { + // TODO(ARROW-13573): appending dictionary arrays not implemented + if (is_dictionary(scalar->type->id())) continue; + + ARROW_SCOPED_TRACE(*scalar->type); + ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16)); + ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16)); + + std::unique_ptr builder; + ASSERT_OK(MakeBuilder(pool_, scalar->type, &builder)); + + ASSERT_OK(builder->AppendArraySlice(*array->data(), 0, 4)); + ASSERT_EQ(4, builder->length()); + ASSERT_OK(builder->AppendArraySlice(*array->data(), 0, 0)); + ASSERT_EQ(4, builder->length()); + ASSERT_OK(builder->AppendArraySlice(*array->data(), 1, 0)); + ASSERT_EQ(4, builder->length()); + ASSERT_OK(builder->AppendArraySlice(*array->data(), 1, 4)); + ASSERT_EQ(8, builder->length()); + + ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 0, 4)); + ASSERT_EQ(12, builder->length()); + if (!is_union(scalar->type->id())) { + ASSERT_EQ(4, builder->null_count()); + } + ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 0, 0)); + ASSERT_EQ(12, builder->length()); + if (!is_union(scalar->type->id())) { + ASSERT_EQ(4, builder->null_count()); + } + ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 1, 0)); + ASSERT_EQ(12, builder->length()); + if (!is_union(scalar->type->id())) { + ASSERT_EQ(4, builder->null_count()); + } + ASSERT_OK(builder->AppendArraySlice(*nulls->data(), 1, 4)); + ASSERT_EQ(16, builder->length()); + if (!is_union(scalar->type->id())) { + ASSERT_EQ(8, builder->null_count()); + } + + std::shared_ptr result; + ASSERT_OK(builder->Finish(&result)); + ASSERT_OK(result->ValidateFull()); + ASSERT_EQ(16, result->length()); + if (!is_union(scalar->type->id())) { + ASSERT_EQ(8, result->null_count()); + } + } + + { + ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(null(), 16)); + NullBuilder builder(pool_); + ASSERT_OK(builder.AppendArraySlice(*array->data(), 0, 4)); + ASSERT_EQ(4, builder.length()); + ASSERT_OK(builder.AppendArraySlice(*array->data(), 0, 0)); + ASSERT_EQ(4, builder.length()); + ASSERT_OK(builder.AppendArraySlice(*array->data(), 1, 0)); + ASSERT_EQ(4, builder.length()); + ASSERT_OK(builder.AppendArraySlice(*array->data(), 1, 4)); + ASSERT_EQ(8, builder.length()); + std::shared_ptr result; + ASSERT_OK(builder.Finish(&result)); + ASSERT_OK(result->ValidateFull()); + ASSERT_EQ(8, result->length()); + ASSERT_EQ(8, result->null_count()); + } +} + TEST_F(TestArray, ValidateBuffersPrimitive) { auto empty_buffer = std::make_shared(""); auto null_buffer = Buffer::FromString("\xff"); diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h index c2aba4e959f..67203e79071 100644 --- a/cpp/src/arrow/array/builder_base.h +++ b/cpp/src/arrow/array/builder_base.h @@ -123,6 +123,14 @@ class ARROW_EXPORT ArrayBuilder { Status AppendScalar(const Scalar& scalar, int64_t n_repeats); Status AppendScalars(const ScalarVector& scalars); + /// \brief Append a range of values from an array. + /// + /// The given array must be the same type as the builder. + virtual Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) { + return Status::NotImplemented("AppendArraySlice for builder for ", *type()); + } + /// For cases where raw data was memcpy'd into the internal buffers, allows us /// to advance the length of the builder. It is your responsibility to use /// this function responsibly. @@ -189,6 +197,17 @@ class ARROW_EXPORT ArrayBuilder { null_count_ = null_bitmap_builder_.false_count(); } + // Vector append. Copy from a given bitmap. If bitmap is null assume + // all of length bits are valid. + void UnsafeAppendToBitmap(const uint8_t* bitmap, int64_t offset, int64_t length) { + if (bitmap == NULLPTR) { + return UnsafeSetNotNull(length); + } + null_bitmap_builder_.UnsafeAppend(bitmap, offset, length); + length_ += length; + null_count_ = null_bitmap_builder_.false_count(); + } + // Append the same validity value a given number of times. void UnsafeAppendToBitmap(const int64_t num_bits, bool value) { if (value) { diff --git a/cpp/src/arrow/array/builder_binary.cc b/cpp/src/arrow/array/builder_binary.cc index 6822dc89903..fd1be179816 100644 --- a/cpp/src/arrow/array/builder_binary.cc +++ b/cpp/src/arrow/array/builder_binary.cc @@ -60,6 +60,14 @@ Status FixedSizeBinaryBuilder::AppendValues(const uint8_t* data, int64_t length, return byte_builder_.Append(data, length * byte_width_); } +Status FixedSizeBinaryBuilder::AppendValues(const uint8_t* data, int64_t length, + const uint8_t* validity, + int64_t bitmap_offset) { + RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(validity, bitmap_offset, length); + return byte_builder_.Append(data, length * byte_width_); +} + Status FixedSizeBinaryBuilder::AppendNull() { RETURN_NOT_OK(Reserve(1)); UnsafeAppendNull(); diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index 7653eeca5c4..6ca65113f1c 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -274,6 +274,23 @@ class BaseBinaryBuilder : public ArrayBuilder { return Status::OK(); } + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override { + auto bitmap = array.GetValues(0, 0); + auto offsets = array.GetValues(1); + auto data = array.GetValues(2, 0); + for (int64_t i = 0; i < length; i++) { + if (!bitmap || BitUtil::GetBit(bitmap, array.offset + offset + i)) { + const offset_type start = offsets[offset + i]; + const offset_type end = offsets[offset + i + 1]; + ARROW_RETURN_NOT_OK(Append(data + start, end - start)); + } else { + ARROW_RETURN_NOT_OK(AppendNull()); + } + } + return Status::OK(); + } + void Reset() override { ArrayBuilder::Reset(); offsets_builder_.Reset(); @@ -486,12 +503,22 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { Status AppendValues(const uint8_t* data, int64_t length, const uint8_t* valid_bytes = NULLPTR); + Status AppendValues(const uint8_t* data, int64_t length, const uint8_t* validity, + int64_t bitmap_offset); + Status AppendNull() final; Status AppendNulls(int64_t length) final; Status AppendEmptyValue() final; Status AppendEmptyValues(int64_t length) final; + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override { + return AppendValues( + array.GetValues(1, 0) + ((array.offset + offset) * byte_width_), length, + array.GetValues(0, 0), array.offset + offset); + } + void UnsafeAppend(const uint8_t* value) { UnsafeAppendToBitmap(true); if (ARROW_PREDICT_TRUE(byte_width_ > 0)) { diff --git a/cpp/src/arrow/array/builder_nested.h b/cpp/src/arrow/array/builder_nested.h index 12b999b786e..e53b758efa3 100644 --- a/cpp/src/arrow/array/builder_nested.h +++ b/cpp/src/arrow/array/builder_nested.h @@ -122,6 +122,23 @@ class BaseListBuilder : public ArrayBuilder { return Status::OK(); } + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override { + const offset_type* offsets = array.GetValues(1); + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR; + for (int64_t row = offset; row < offset + length; row++) { + if (!validity || BitUtil::GetBit(validity, array.offset + row)) { + ARROW_RETURN_NOT_OK(Append()); + int64_t slot_length = offsets[row + 1] - offsets[row]; + ARROW_RETURN_NOT_OK(value_builder_->AppendArraySlice(*array.child_data[0], + offsets[row], slot_length)); + } else { + ARROW_RETURN_NOT_OK(AppendNull()); + } + } + return Status::OK(); + } + Status FinishInternal(std::shared_ptr* out) override { ARROW_RETURN_NOT_OK(AppendNextOffset()); @@ -275,6 +292,25 @@ class ARROW_EXPORT MapBuilder : public ArrayBuilder { Status AppendEmptyValues(int64_t length) final; + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override { + const int32_t* offsets = array.GetValues(1); + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR; + for (int64_t row = offset; row < offset + length; row++) { + if (!validity || BitUtil::GetBit(validity, array.offset + row)) { + ARROW_RETURN_NOT_OK(Append()); + const int64_t slot_length = offsets[row + 1] - offsets[row]; + ARROW_RETURN_NOT_OK(key_builder_->AppendArraySlice( + *array.child_data[0]->child_data[0], offsets[row], slot_length)); + ARROW_RETURN_NOT_OK(item_builder_->AppendArraySlice( + *array.child_data[0]->child_data[1], offsets[row], slot_length)); + } else { + ARROW_RETURN_NOT_OK(AppendNull()); + } + } + return Status::OK(); + } + /// \brief Get builder to append keys. /// /// Append a key with this builder should be followed by appending @@ -374,6 +410,20 @@ class ARROW_EXPORT FixedSizeListBuilder : public ArrayBuilder { Status AppendEmptyValues(int64_t length) final; + Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final { + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR; + for (int64_t row = offset; row < offset + length; row++) { + if (!validity || BitUtil::GetBit(validity, array.offset + row)) { + ARROW_RETURN_NOT_OK(value_builder_->AppendArraySlice( + *array.child_data[0], list_size_ * (array.offset + row), list_size_)); + ARROW_RETURN_NOT_OK(Append()); + } else { + ARROW_RETURN_NOT_OK(AppendNull()); + } + } + return Status::OK(); + } + ArrayBuilder* value_builder() const { return value_builder_.get(); } std::shared_ptr type() const override { @@ -467,6 +517,18 @@ class ARROW_EXPORT StructBuilder : public ArrayBuilder { return Status::OK(); } + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override { + for (int i = 0; static_cast(i) < children_.size(); i++) { + ARROW_RETURN_NOT_OK(children_[i]->AppendArraySlice(*array.child_data[i], + array.offset + offset, length)); + } + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0]->data() : NULLPTR; + ARROW_RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(validity, array.offset + offset, length); + return Status::OK(); + } + void Reset() override; ArrayBuilder* field_builder(int i) const { return children_[i].get(); } diff --git a/cpp/src/arrow/array/builder_primitive.cc b/cpp/src/arrow/array/builder_primitive.cc index e403c42411d..769c2f7d07b 100644 --- a/cpp/src/arrow/array/builder_primitive.cc +++ b/cpp/src/arrow/array/builder_primitive.cc @@ -85,6 +85,14 @@ Status BooleanBuilder::AppendValues(const uint8_t* values, int64_t length, return Status::OK(); } +Status BooleanBuilder::AppendValues(const uint8_t* values, int64_t length, + const uint8_t* validity, int64_t offset) { + RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values, offset, length); + ArrayBuilder::UnsafeAppendToBitmap(validity, offset, length); + return Status::OK(); +} + Status BooleanBuilder::AppendValues(const uint8_t* values, int64_t length, const std::vector& is_valid) { RETURN_NOT_OK(Reserve(length)); diff --git a/cpp/src/arrow/array/builder_primitive.h b/cpp/src/arrow/array/builder_primitive.h index e0f39f97967..67d58fc9d13 100644 --- a/cpp/src/arrow/array/builder_primitive.h +++ b/cpp/src/arrow/array/builder_primitive.h @@ -53,6 +53,10 @@ class ARROW_EXPORT NullBuilder : public ArrayBuilder { Status Append(std::nullptr_t) { return AppendNull(); } + Status AppendArraySlice(const ArrayData&, int64_t, int64_t length) override { + return AppendNulls(length); + } + Status FinishInternal(std::shared_ptr* out) override; /// \cond FALSE @@ -153,6 +157,21 @@ class NumericBuilder : public ArrayBuilder { return Status::OK(); } + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous C array of values + /// \param[in] length the number of values to append + /// \param[in] bitmap a validity bitmap to copy (may be null) + /// \param[in] bitmap_offset an offset into the validity bitmap + /// \return Status + Status AppendValues(const value_type* values, int64_t length, const uint8_t* bitmap, + int64_t bitmap_offset) { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values, length); + // length_ is update by these + ArrayBuilder::UnsafeAppendToBitmap(bitmap, bitmap_offset, length); + return Status::OK(); + } + /// \brief Append a sequence of elements in one shot /// \param[in] values a contiguous C array of values /// \param[in] length the number of values to append @@ -256,6 +275,12 @@ class NumericBuilder : public ArrayBuilder { return Status::OK(); } + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override { + return AppendValues(array.GetValues(1) + offset, length, + array.GetValues(0, 0), array.offset + offset); + } + /// Append a single scalar under the assumption that the underlying Buffer is /// large enough. /// @@ -363,6 +388,15 @@ class ARROW_EXPORT BooleanBuilder : public ArrayBuilder { Status AppendValues(const uint8_t* values, int64_t length, const uint8_t* valid_bytes = NULLPTR); + /// \brief Append a sequence of elements in one shot + /// \param[in] values a bitmap of values + /// \param[in] length the number of values to append + /// \param[in] validity a validity bitmap to copy (may be null) + /// \param[in] offset an offset into the values and validity bitmaps + /// \return Status + Status AppendValues(const uint8_t* values, int64_t length, const uint8_t* validity, + int64_t offset); + /// \brief Append a sequence of elements in one shot /// \param[in] values a contiguous C array of values /// \param[in] length the number of values to append @@ -459,6 +493,12 @@ class ARROW_EXPORT BooleanBuilder : public ArrayBuilder { Status AppendValues(int64_t length, bool value); + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override { + return AppendValues(array.GetValues(1, 0), length, + array.GetValues(0, 0), array.offset + offset); + } + Status FinishInternal(std::shared_ptr* out) override; /// \cond FALSE diff --git a/cpp/src/arrow/array/builder_union.cc b/cpp/src/arrow/array/builder_union.cc index 8617cb73fce..6096b76ff21 100644 --- a/cpp/src/arrow/array/builder_union.cc +++ b/cpp/src/arrow/array/builder_union.cc @@ -45,6 +45,21 @@ Status BasicUnionBuilder::FinishInternal(std::shared_ptr* out) { return Status::OK(); } +Status DenseUnionBuilder::AppendArraySlice(const ArrayData& array, const int64_t offset, + const int64_t length) { + const int8_t* type_codes = array.GetValues(1); + const int32_t* offsets = array.GetValues(2); + for (int64_t row = offset; row < offset + length; row++) { + const int8_t type_code = type_codes[row]; + const int child_id = type_id_to_child_id_[type_code]; + const int32_t union_offset = offsets[row]; + RETURN_NOT_OK(Append(type_code)); + RETURN_NOT_OK(type_id_to_children_[type_code]->AppendArraySlice( + *array.child_data[child_id], union_offset, /*length=*/1)); + } + return Status::OK(); +} + Status DenseUnionBuilder::FinishInternal(std::shared_ptr* out) { ARROW_RETURN_NOT_OK(BasicUnionBuilder::FinishInternal(out)); (*out)->buffers.resize(3); @@ -64,6 +79,7 @@ BasicUnionBuilder::BasicUnionBuilder( type_codes_ = union_type.type_codes(); children_ = children; + type_id_to_child_id_.resize(union_type.max_type_code() + 1, -1); type_id_to_children_.resize(union_type.max_type_code() + 1, nullptr); DCHECK_LE( type_id_to_children_.size() - 1, @@ -73,6 +89,7 @@ BasicUnionBuilder::BasicUnionBuilder( child_fields_[i] = union_type.field(static_cast(i)); auto type_id = union_type.type_codes()[i]; + type_id_to_child_id_[type_id] = static_cast(i); type_id_to_children_[type_id] = children[i].get(); } } @@ -82,6 +99,7 @@ int8_t BasicUnionBuilder::AppendChild(const std::shared_ptr& new_c children_.push_back(new_child); auto new_type_id = NextTypeId(); + type_id_to_child_id_[new_type_id] = static_cast(children_.size() - 1); type_id_to_children_[new_type_id] = new_child.get(); child_fields_.push_back(field(field_name, nullptr)); type_codes_.push_back(static_cast(new_type_id)); @@ -114,8 +132,20 @@ int8_t BasicUnionBuilder::NextTypeId() { static_cast(UnionType::kMaxTypeCode)); // type_id_to_children_ is already densely packed, so just append the new child + type_id_to_child_id_.resize(type_id_to_child_id_.size() + 1); type_id_to_children_.resize(type_id_to_children_.size() + 1); return dense_type_id_++; } +Status SparseUnionBuilder::AppendArraySlice(const ArrayData& array, const int64_t offset, + const int64_t length) { + for (size_t i = 0; i < type_codes_.size(); i++) { + RETURN_NOT_OK(type_id_to_children_[type_codes_[i]]->AppendArraySlice( + *array.child_data[i], array.offset + offset, length)); + } + const int8_t* type_codes = array.GetValues(1); + RETURN_NOT_OK(types_builder_.Append(type_codes + offset, length)); + return Status::OK(); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/builder_union.h b/cpp/src/arrow/array/builder_union.h index 060be474fb8..c1a799e56bf 100644 --- a/cpp/src/arrow/array/builder_union.h +++ b/cpp/src/arrow/array/builder_union.h @@ -74,6 +74,7 @@ class ARROW_EXPORT BasicUnionBuilder : public ArrayBuilder { UnionMode::type mode_; std::vector type_id_to_children_; + std::vector type_id_to_child_id_; // for all type_id < dense_type_id_, type_id_to_children_[type_id] != nullptr int8_t dense_type_id_ = 0; TypedBufferBuilder types_builder_; @@ -155,6 +156,9 @@ class ARROW_EXPORT DenseUnionBuilder : public BasicUnionBuilder { return offsets_builder_.Append(offset); } + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override; + Status FinishInternal(std::shared_ptr* out) override; private: @@ -230,6 +234,9 @@ class ARROW_EXPORT SparseUnionBuilder : public BasicUnionBuilder { /// The corresponding child builder must be appended to independently after this method /// is called, and all other child builders must have null or empty value appended. Status Append(int8_t next_type) { return types_builder_.Append(next_type); } + + Status AppendArraySlice(const ArrayData& array, int64_t offset, + int64_t length) override; }; } // namespace arrow diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index fae379e51f4..5e95dc93f56 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -442,9 +442,12 @@ class NullArrayFactory { // First buffer is always null out_->buffers[0] = nullptr; - // Type codes are all zero, so we can use buffer_ which has had it's memory - // zeroed out_->buffers[1] = buffer_; + // buffer_ is zeroed, but 0 may not be a valid type code + if (type.type_codes()[0] != 0) { + ARROW_ASSIGN_OR_RAISE(out_->buffers[1], AllocateBuffer(length_, pool_)); + std::memset(out_->buffers[1]->mutable_data(), type.type_codes()[0], length_); + } // For sparse unions, we now create children with the same length as the // parent diff --git a/cpp/src/arrow/buffer_builder.h b/cpp/src/arrow/buffer_builder.h index eb3f68affc0..7b02ad09a82 100644 --- a/cpp/src/arrow/buffer_builder.h +++ b/cpp/src/arrow/buffer_builder.h @@ -28,6 +28,7 @@ #include "arrow/status.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_generate.h" +#include "arrow/util/bitmap_ops.h" #include "arrow/util/macros.h" #include "arrow/util/ubsan.h" #include "arrow/util/visibility.h" @@ -339,6 +340,7 @@ class TypedBufferBuilder { ++bit_length_; } + /// \brief Append bits from an array of bytes (one value per byte) void UnsafeAppend(const uint8_t* bytes, int64_t num_elements) { if (num_elements == 0) return; int64_t i = 0; @@ -350,6 +352,14 @@ class TypedBufferBuilder { bit_length_ += num_elements; } + /// \brief Append bits from a packed bitmap + void UnsafeAppend(const uint8_t* bitmap, int64_t offset, int64_t num_elements) { + if (num_elements == 0) return; + internal::CopyBitmap(bitmap, offset, num_elements, mutable_data(), bit_length_); + false_count_ += num_elements - internal::CountSetBits(bitmap, offset, num_elements); + bit_length_ += num_elements; + } + void UnsafeAppend(const int64_t num_copies, bool value) { BitUtil::SetBitsTo(mutable_data(), bit_length_, num_copies, value); false_count_ += num_copies * !value; diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index cb261ec59a7..affe9267942 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include +#include #include #include #include @@ -1413,6 +1417,283 @@ struct CaseWhenFunctor { } }; +Status ExecVarWidthScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, + Datum* out) { + const auto& conds = checked_cast(*batch.values[0].scalar()); + Datum result; + for (size_t i = 0; i < batch.values.size() - 1; i++) { + if (i < conds.value.size()) { + const Scalar& cond = *conds.value[i]; + if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) { + result = batch[i + 1]; + break; + } + } else { + // ELSE clause + result = batch[i + 1]; + break; + } + } + if (out->is_scalar()) { + DCHECK(result.is_scalar() || result.kind() == Datum::NONE); + *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type()); + return Status::OK(); + } + ArrayData* output = out->mutable_array(); + if (!result.is_value()) { + // All conditions false, no 'else' argument + ARROW_ASSIGN_OR_RAISE( + auto array, MakeArrayOfNull(output->type, batch.length, ctx->memory_pool())); + *output = *array->data(); + } else if (result.is_scalar()) { + ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*result.scalar(), batch.length, + ctx->memory_pool())); + *output = *array->data(); + } else { + *output = *result.array(); + } + return Status::OK(); +} + +// Use std::function for reserve_data to avoid instantiating template so much +template +static Status ExecVarWidthArrayCaseWhenImpl( + KernelContext* ctx, const ExecBatch& batch, Datum* out, + std::function reserve_data, AppendScalar append_scalar) { + const auto& conds_array = *batch.values[0].array(); + ArrayData* output = out->mutable_array(); + const bool have_else_arg = + static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1); + std::unique_ptr raw_builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(raw_builder->Reserve(batch.length)); + RETURN_NOT_OK(reserve_data(raw_builder.get())); + + for (int64_t row = 0; row < batch.length; row++) { + int64_t selected = have_else_arg ? static_cast(batch.values.size() - 1) : -1; + for (int64_t arg = 0; static_cast(arg) < conds_array.child_data.size(); + arg++) { + const ArrayData& cond_array = *conds_array.child_data[arg]; + if ((!cond_array.buffers[0] || + BitUtil::GetBit(cond_array.buffers[0]->data(), + conds_array.offset + cond_array.offset + row)) && + BitUtil::GetBit(cond_array.buffers[1]->data(), + conds_array.offset + cond_array.offset + row)) { + selected = arg + 1; + break; + } + } + if (selected < 0) { + RETURN_NOT_OK(raw_builder->AppendNull()); + continue; + } + const Datum& source = batch.values[selected]; + if (source.is_scalar()) { + const auto& scalar = *source.scalar(); + if (!scalar.is_valid) { + RETURN_NOT_OK(raw_builder->AppendNull()); + } else { + RETURN_NOT_OK(append_scalar(raw_builder.get(), scalar)); + } + } else { + const auto& array = source.array(); + if (!array->buffers[0] || + BitUtil::GetBit(array->buffers[0]->data(), array->offset + row)) { + RETURN_NOT_OK(raw_builder->AppendArraySlice(*array, row, /*length=*/1)); + } else { + RETURN_NOT_OK(raw_builder->AppendNull()); + } + } + } + + ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish()); + *output = *temp_output->data(); + return Status::OK(); +} + +// Single instantiation using ArrayBuilder::AppendScalar for append_scalar +static Status ExecVarWidthArrayCaseWhen( + KernelContext* ctx, const ExecBatch& batch, Datum* out, + std::function reserve_data) { + return ExecVarWidthArrayCaseWhenImpl( + ctx, batch, out, std::move(reserve_data), + [](ArrayBuilder* raw_builder, const Scalar& scalar) { + return raw_builder->AppendScalar(scalar); + }); +} + +template +struct CaseWhenFunctor> { + using offset_type = typename Type::offset_type; + using BuilderType = typename TypeTraits::BuilderType; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return ExecVarWidthArrayCaseWhenImpl( + ctx, batch, out, + // ReserveData + [&](ArrayBuilder* raw_builder) { + int64_t reservation = 0; + for (size_t arg = 1; arg < batch.values.size(); arg++) { + auto source = batch.values[arg]; + if (source.is_scalar()) { + const auto& scalar = + checked_cast(*source.scalar()); + if (!scalar.value) continue; + reservation = + std::max(reservation, batch.length * scalar.value->size()); + } else { + const auto& array = *source.array(); + const auto& offsets = array.GetValues(1); + reservation = + std::max(reservation, offsets[array.length] - offsets[0]); + } + } + // checked_cast works since (Large)StringBuilder <: (Large)BinaryBuilder + return checked_cast(raw_builder)->ReserveData(reservation); + }, + // AppendScalar + [](ArrayBuilder* raw_builder, const Scalar& raw_scalar) { + const auto& scalar = checked_cast(raw_scalar); + return checked_cast(raw_builder) + ->Append(scalar.value->data(), + static_cast(scalar.value->size())); + }); + } +}; + +template +struct CaseWhenFunctor> { + using offset_type = typename Type::offset_type; + using BuilderType = typename TypeTraits::BuilderType; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return ExecVarWidthArrayCaseWhen( + ctx, batch, out, + // ReserveData + [&](ArrayBuilder* raw_builder) { + auto builder = checked_cast(raw_builder); + auto child_builder = builder->value_builder(); + + int64_t reservation = 0; + for (size_t arg = 1; arg < batch.values.size(); arg++) { + auto source = batch.values[arg]; + if (!source.is_array()) { + const auto& scalar = checked_cast(*source.scalar()); + if (!scalar.value) continue; + reservation = + std::max(reservation, batch.length * scalar.value->length()); + } else { + const auto& array = *source.array(); + reservation = std::max(reservation, array.child_data[0]->length); + } + } + return child_builder->Reserve(reservation); + }); + } +}; + +// No-op reserve function, pulled out to avoid apparent miscompilation on MinGW +Status ReserveNoData(ArrayBuilder*) { return Status::OK(); } + +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + std::function reserve_data = ReserveNoData; + return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data)); + } +}; + +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + std::function reserve_data = ReserveNoData; + return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data)); + } +}; + +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& ty = checked_cast(*out->type()); + const int64_t width = ty.list_size(); + return ExecVarWidthArrayCaseWhen( + ctx, batch, out, + // ReserveData + [&](ArrayBuilder* raw_builder) { + int64_t children = batch.length * width; + return checked_cast(raw_builder) + ->value_builder() + ->Reserve(children); + }); + } +}; + +template +struct CaseWhenFunctor> { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + std::function reserve_data = ReserveNoData; + return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data)); + } +}; + struct CoalesceFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -1841,9 +2122,15 @@ void AddCaseWhenKernel(const std::shared_ptr& scalar_function, OutputType(LastType), /*is_varargs=*/true), exec); - kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; - kernel.mem_allocation = MemAllocation::PREALLOCATE; - kernel.can_write_into_slices = is_fixed_width(get_id.id); + if (is_fixed_width(get_id.id)) { + kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; + kernel.mem_allocation = MemAllocation::PREALLOCATE; + kernel.can_write_into_slices = true; + } else { + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + kernel.can_write_into_slices = false; + } DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); } @@ -1855,6 +2142,14 @@ void AddPrimitiveCaseWhenKernels(const std::shared_ptr& scalar } } +void AddBinaryCaseWhenKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { + for (auto&& type : types) { + auto exec = GenerateTypeAgnosticVarBinaryBase(*type); + AddCaseWhenKernel(scalar_function, type, std::move(exec)); + } +} + void AddCoalesceKernel(const std::shared_ptr& scalar_function, detail::GetTypeId get_id, ArrayKernelExec exec) { ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, OutputType(FirstType), @@ -1957,6 +2252,15 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor::Exec); + AddBinaryCaseWhenKernels(func, BaseBinaryTypes()); + AddCaseWhenKernel(func, Type::FIXED_SIZE_LIST, + CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::LIST, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::LARGE_LIST, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::MAP, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor::Exec); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index 9b59d54c3da..a8041f9086e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -27,32 +27,31 @@ namespace arrow { namespace compute { const int64_t kNumItems = 1024 * 1024; +const int64_t kFewItems = 64 * 1024; template -struct SetBytesProcessed {}; +struct GetBytesProcessed {}; + +template <> +struct GetBytesProcessed { + static int64_t Get(const std::shared_ptr& arr) { return arr->length() / 8; } +}; template -struct SetBytesProcessed> { - static void Set(const std::shared_ptr& cond, const std::shared_ptr& left, - const std::shared_ptr& right, benchmark::State* state) { +struct GetBytesProcessed> { + static int64_t Get(const std::shared_ptr& arr) { using CType = typename Type::c_type; - state->SetBytesProcessed(state->iterations() * - (cond->length() / 8 + 2 * cond->length() * sizeof(CType))); + return arr->length() * sizeof(CType); } }; template -struct SetBytesProcessed> { - static void Set(const std::shared_ptr& cond, const std::shared_ptr& left, - const std::shared_ptr& right, benchmark::State* state) { +struct GetBytesProcessed> { + static int64_t Get(const std::shared_ptr& arr) { using ArrayType = typename TypeTraits::ArrayType; using OffsetType = typename TypeTraits::OffsetType::c_type; - - state->SetBytesProcessed( - state->iterations() * - (cond->length() / 8 + 2 * cond->length() * sizeof(OffsetType) + - std::static_pointer_cast(left)->total_values_length() + - std::static_pointer_cast(right)->total_values_length())); + return arr->length() * sizeof(OffsetType) + + std::static_pointer_cast(arr)->total_values_length(); } }; @@ -80,7 +79,10 @@ static void IfElseBench(benchmark::State& state) { ABORT_NOT_OK(IfElse(cond, left, right)); } - SetBytesProcessed::Set(cond, left, right, &state); + state.SetBytesProcessed(state.iterations() * + (GetBytesProcessed::Get(cond) + + GetBytesProcessed::Get(left) + + GetBytesProcessed::Get(right))); } template @@ -109,7 +111,10 @@ static void IfElseBenchContiguous(benchmark::State& state) { ABORT_NOT_OK(IfElse(cond, left, right)); } - SetBytesProcessed::Set(cond, left, right, &state); + state.SetBytesProcessed(state.iterations() * + (GetBytesProcessed::Get(cond) + + GetBytesProcessed::Get(left) + + GetBytesProcessed::Get(right))); } static void IfElseBench64(benchmark::State& state) { @@ -146,7 +151,6 @@ static void IfElseBenchString32Contiguous(benchmark::State& state) { template static void CaseWhenBench(benchmark::State& state) { - using CType = typename Type::c_type; auto type = TypeTraits::type_singleton(); using ArrayType = typename TypeTraits::ArrayType; @@ -155,12 +159,6 @@ static void CaseWhenBench(benchmark::State& state) { random::RandomArrayGenerator rand(/*seed=*/0); - auto cond1 = std::static_pointer_cast( - rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); - auto cond2 = std::static_pointer_cast( - rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); - auto cond3 = std::static_pointer_cast( - rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); auto cond_field = field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}})); auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}), @@ -180,12 +178,44 @@ static void CaseWhenBench(benchmark::State& state) { val3->Slice(offset), val4->Slice(offset)})); } - state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType)); + // Set bytes processed to ~length of output + state.SetBytesProcessed(state.iterations() * GetBytesProcessed::Get(val1)); + state.SetItemsProcessed(state.iterations() * (len - offset)); +} + +static void CaseWhenBenchList(benchmark::State& state) { + auto type = list(int64()); + auto fld = field("", type); + + int64_t len = state.range(0); + int64_t offset = state.range(1); + + random::RandomArrayGenerator rand(/*seed=*/0); + + auto cond_field = + field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}})); + auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}), + key_value_metadata({{"null_probability", "0.0"}})), + len); + auto val1 = rand.ArrayOf(*fld, len); + auto val2 = rand.ArrayOf(*fld, len); + auto val3 = rand.ArrayOf(*fld, len); + auto val4 = rand.ArrayOf(*fld, len); + for (auto _ : state) { + ABORT_NOT_OK( + CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset), + val3->Slice(offset), val4->Slice(offset)})); + } + + // Set bytes processed to ~length of output + state.SetBytesProcessed(state.iterations() * + GetBytesProcessed::Get( + std::static_pointer_cast(val1)->values())); + state.SetItemsProcessed(state.iterations() * (len - offset)); } template static void CaseWhenBenchContiguous(benchmark::State& state) { - using CType = typename Type::c_type; auto type = TypeTraits::type_singleton(); using ArrayType = typename TypeTraits::ArrayType; @@ -216,7 +246,9 @@ static void CaseWhenBenchContiguous(benchmark::State& state) { val3->Slice(offset)})); } - state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType)); + // Set bytes processed to ~length of output + state.SetBytesProcessed(state.iterations() * GetBytesProcessed::Get(val1)); + state.SetItemsProcessed(state.iterations() * (len - offset)); } static void CaseWhenBench64(benchmark::State& state) { @@ -227,6 +259,14 @@ static void CaseWhenBench64Contiguous(benchmark::State& state) { return CaseWhenBenchContiguous(state); } +static void CaseWhenBenchString(benchmark::State& state) { + return CaseWhenBench(state); +} + +static void CaseWhenBenchStringContiguous(benchmark::State& state) { + return CaseWhenBenchContiguous(state); +} + template static void CoalesceBench(benchmark::State& state) { using CType = typename Type::c_type; @@ -337,6 +377,15 @@ BENCHMARK(CaseWhenBench64)->Args({kNumItems, 99}); BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 0}); BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 99}); +BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 0}); +BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 99}); + +BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 0}); +BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 99}); + +BENCHMARK(CaseWhenBenchStringContiguous)->Args({kFewItems, 0}); +BENCHMARK(CaseWhenBenchStringContiguous)->Args({kFewItems, 99}); + BENCHMARK(CoalesceBench64)->Args({kNumItems, 0}); BENCHMARK(CoalesceBench64)->Args({kNumItems, 99}); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 8a6ccd69865..b3b0f26cead 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -531,6 +531,10 @@ TYPED_TEST(TestCaseWhenNumeric, FixedSize) { CheckScalar("case_when", {MakeStruct({}), values1}, values1); CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); @@ -603,6 +607,23 @@ TYPED_TEST(TestCaseWhenNumeric, FixedSize) { {Datum(*MakeArrayOfNull(struct_({field("", boolean())}), 4)), Datum(values1)})); } +TYPED_TEST(TestCaseWhenNumeric, ListOfType) { + // More minimal test to check type coverage + auto type = list(default_type_instance()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([[1, 2], null, [3, 4, 5], [6, null]])"); + auto values2 = ArrayFromJSON(type, R"([[8, 9, 10], [11], null, [12]])"); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[1, 2], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[1, 2], null, null, [6, null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [6, null]])")); +} + TEST(TestCaseWhen, Null) { auto cond_true = ScalarFromJSON(boolean(), "true"); auto cond_false = ScalarFromJSON(boolean(), "false"); @@ -632,6 +653,10 @@ TEST(TestCaseWhen, Boolean) { CheckScalar("case_when", {MakeStruct({}), values1}, values1); CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); @@ -685,6 +710,10 @@ TEST(TestCaseWhen, DayTimeInterval) { CheckScalar("case_when", {MakeStruct({}), values1}, values1); CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); @@ -739,6 +768,10 @@ TEST(TestCaseWhen, Decimal) { CheckScalar("case_when", {MakeStruct({}), values1}, values1); CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); @@ -794,6 +827,10 @@ TEST(TestCaseWhen, FixedSizeBinary) { CheckScalar("case_when", {MakeStruct({}), values1}, values1); CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); @@ -830,6 +867,606 @@ TEST(TestCaseWhen, FixedSizeBinary) { ArrayFromJSON(type, R"([null, null, null, "efg"])")); } +template +class TestCaseWhenBinary : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenBinary, BinaryArrowTypes); + +TYPED_TEST(TestCaseWhenBinary, Basics) { + auto type = default_type_instance(); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"("aBxYz")"); + auto scalar2 = ScalarFromJSON(type, R"("b")"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"(["cDE", null, "degfhi", "efg"])"); + auto values2 = ArrayFromJSON(type, R"(["fghijk", "ghi", null, "hi"])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"(["aBxYz", "aBxYz", "b", null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, "aBxYz", "aBxYz"])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, R"(["aBxYz", "aBxYz", "b", "aBxYz"])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"(["cDE", null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"(["cDE", null, null, "efg"])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, "efg"])")); +} + +template +class TestCaseWhenList : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenList, ListArrowTypes); + +TYPED_TEST(TestCaseWhenList, ListOfString) { + auto type = std::make_shared(utf8()); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"(["aB", "xYz"])"); + auto scalar2 = ScalarFromJSON(type, R"(["b", null])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, R"([["cD", "E"], null, ["de", "gf", "hi"], ["ef", "g"]])"); + auto values2 = ArrayFromJSON(type, R"([["f", "ghi", "jk"], ["ghi"], null, ["hi"]])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar( + "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, ["aB", "xYz"], ["aB", "xYz"]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON( + type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], ["aB", "xYz"]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([["cD", "E"], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([["cD", "E"], null, null, ["ef", "g"]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])")); +} + +// More minimal tests to check type coverage +TYPED_TEST(TestCaseWhenList, ListOfBool) { + auto type = std::make_shared(boolean()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([[true], null, [false], [false, null]])"); + auto values2 = ArrayFromJSON(type, R"([[false], [false], null, [true]])"); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[true], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[true], null, null, [false, null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [false, null]])")); +} + +TYPED_TEST(TestCaseWhenList, ListOfInt) { + auto type = std::make_shared(int64()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([[1, 2], null, [3, 4, 5], [6, null]])"); + auto values2 = ArrayFromJSON(type, R"([[8, 9, 10], [11], null, [12]])"); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[1, 2], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[1, 2], null, null, [6, null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [6, null]])")); +} + +TYPED_TEST(TestCaseWhenList, ListOfDayTimeInterval) { + auto type = std::make_shared(day_time_interval()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, R"([[[1, 2]], null, [[3, 4], [5, 0]], [[6, 7], null]])"); + auto values2 = ArrayFromJSON(type, R"([[[8, 9], null], [[11, 12]], null, [[12, 1]]])"); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[[1, 2]], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[[1, 2]], null, null, [[6, 7], null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [[6, 7], null]])")); +} + +TYPED_TEST(TestCaseWhenList, ListOfDecimal) { + for (const auto& decimal_ty : + std::vector>{decimal128(3, 2), decimal256(3, 2)}) { + auto type = std::make_shared(decimal_ty); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON( + type, R"([["1.23", "2.34"], null, ["3.45", "4.56", "5.67"], ["6.78", null]])"); + auto values2 = + ArrayFromJSON(type, R"([["8.90", "9.01", "1.02"], ["1.12"], null, ["1.23"]])"); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, ["6.78", null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, ["6.78", null]])")); + } +} + +TYPED_TEST(TestCaseWhenList, ListOfFixedSizeBinary) { + auto type = std::make_shared(fixed_size_binary(4)); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON( + type, R"([["1.23", "2.34"], null, ["3.45", "4.56", "5.67"], ["6.78", null]])"); + auto values2 = + ArrayFromJSON(type, R"([["8.90", "9.01", "1.02"], ["1.12"], null, ["1.23"]])"); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([["1.23", "2.34"], null, null, ["6.78", null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, ["6.78", null]])")); +} + +TYPED_TEST(TestCaseWhenList, ListOfListOfInt) { + auto type = std::make_shared(list(int64())); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, R"([[[1, 2], []], null, [[3, 4, 5]], [[6, null], null]])"); + auto values2 = ArrayFromJSON(type, R"([[[8, 9, 10]], [[11]], null, [[12]]])"); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[[1, 2], []], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[[1, 2], []], null, null, [[6, null], null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [[6, null], null]])")); +} + +TEST(TestCaseWhen, Map) { + auto type = map(int64(), utf8()); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([[1, "abc"], [2, "de"]])"); + auto scalar2 = ScalarFromJSON(type, R"([[3, "fghi"]])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, R"([[[4, "kl"]], null, [[5, "mn"]], [[6, "o"], [7, "pq"]]])"); + auto values2 = ArrayFromJSON(type, R"([[[8, "r"], [9, "st"]], [[10, "u"]], null, []])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar( + "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON( + type, + R"([[[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]], [[3, "fghi"]], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar( + "case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, + R"([null, null, [[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]]])")); + CheckScalar( + "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON( + type, + R"([[[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]], [[3, "fghi"]], [[1, "abc"], [2, "de"]]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[[4, "kl"]], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[[4, "kl"]], null, null, [[6, "o"], [7, "pq"]]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [[6, "o"], [7, "pq"]]])")); +} + +TEST(TestCaseWhen, FixedSizeListOfInt) { + auto type = fixed_size_list(int64(), 2); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([1, 2])"); + auto scalar2 = ScalarFromJSON(type, R"([3, null])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([[4, 5], null, [6, 7], [8, 9]])"); + auto values2 = ArrayFromJSON(type, R"([[10, 11], [12, null], null, [null, 13]])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"([[1, 2], [1, 2], [3, null], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, [1, 2], [1, 2]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, R"([[1, 2], [1, 2], [3, null], [1, 2]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[4, 5], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[4, 5], null, null, [8, 9]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [8, 9]])")); +} + +TEST(TestCaseWhen, FixedSizeListOfString) { + auto type = fixed_size_list(utf8(), 2); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"(["aB", "xYz"])"); + auto scalar2 = ScalarFromJSON(type, R"(["b", null])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, R"([["cD", "E"], null, ["de", "gfhi"], ["ef", "g"]])"); + auto values2 = + ArrayFromJSON(type, R"([["fghi", "jk"], ["ghi", null], null, [null, "hi"]])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar( + "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, ["aB", "xYz"], ["aB", "xYz"]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON( + type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], ["aB", "xYz"]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([["cD", "E"], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([["cD", "E"], null, null, ["ef", "g"]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])")); +} + +TEST(TestCaseWhen, StructOfInt) { + auto type = struct_({field("a", uint32()), field("b", int64())}); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([1, -2])"); + auto scalar2 = ScalarFromJSON(type, R"([null, 3])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([[4, null], null, [5, -6], [7, -8]])"); + auto values2 = ArrayFromJSON(type, R"([[9, 10], [11, -12], null, [null, null]])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"([[1, -2], [1, -2], [null, 3], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, [1, -2], [1, -2]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, R"([[1, -2], [1, -2], [null, 3], [1, -2]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[4, null], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[4, null], null, null, [7, -8]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [7, -8]])")); +} + +TEST(TestCaseWhen, StructOfString) { + // More minimal test to check type coverage + auto type = struct_({field("a", utf8()), field("b", large_utf8())}); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"(["a", "bc"])"); + auto scalar2 = ScalarFromJSON(type, R"([null, "d"])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, R"([["efg", null], null, [null, null], [null, "hi"]])"); + auto values2 = + ArrayFromJSON(type, R"([["j", "k"], [null, "lmnop"], null, ["qr", "stu"]])"); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"([["a", "bc"], ["a", "bc"], [null, "d"], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, ["a", "bc"], ["a", "bc"]])")); + CheckScalar( + "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, R"([["a", "bc"], ["a", "bc"], [null, "d"], ["a", "bc"]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([["efg", null], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([["efg", null], null, null, [null, "hi"]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [null, "hi"]])")); +} + +TEST(TestCaseWhen, StructOfListOfInt) { + // More minimal test to check type coverage + auto type = struct_({field("a", utf8()), field("b", list(int64()))}); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([null, [1, null]])"); + auto scalar2 = ScalarFromJSON(type, R"(["b", null])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = + ArrayFromJSON(type, R"([["efg", null], null, [null, null], [null, [null, 1]]])"); + auto values2 = + ArrayFromJSON(type, R"([["j", [2, 3]], [null, [4, 5, 6]], null, ["qr", [7]]])"); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON( + type, R"([[null, [1, null]], [null, [1, null]], ["b", null], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar( + "case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, [null, [1, null]], [null, [1, null]]])")); + CheckScalar( + "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON( + type, + R"([[null, [1, null]], [null, [1, null]], ["b", null], [null, [1, null]]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([["efg", null], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([["efg", null], null, null, [null, [null, 1]]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [null, [null, 1]]])")); +} + +TEST(TestCaseWhen, UnionBoolString) { + for (const auto& type : std::vector>{ + sparse_union({field("a", boolean()), field("b", utf8())}, {2, 7}), + dense_union({field("a", boolean()), field("b", utf8())}, {2, 7})}) { + ARROW_SCOPED_TRACE(type->ToString()); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"([2, null])"); + auto scalar2 = ScalarFromJSON(type, R"([7, "foo"])"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"([[2, true], null, [7, "bar"], [7, "baz"]])"); + auto values2 = ArrayFromJSON(type, R"([[7, "spam"], [2, null], null, [7, null]])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1}, + *MakeArrayFromScalar(*scalar1, 4)); + CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, + values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"([[2, null], [2, null], [7, "foo"], null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, [2, null], [2, null]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, R"([[2, null], [2, null], [7, "foo"], [2, null]])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"([[2, true], null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"([[2, true], null, null, [7, "baz"]])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, [7, "baz"]])")); + } +} + TEST(TestCaseWhen, DispatchBest) { CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()}, {struct_({field("", boolean())}), int64(), int64()}); diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 3f9408ecdcb..a58064e4261 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -170,6 +170,8 @@ using BinaryArrowTypes = using StringArrowTypes = ::testing::Types; +using ListArrowTypes = ::testing::Types; + using UnionArrowTypes = ::testing::Types; class Array; diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 1e04821c019..f94ba513e6a 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -1005,6 +1005,17 @@ static inline bool is_nested(Type::type type_id) { return false; } +static inline bool is_union(Type::type type_id) { + switch (type_id) { + case Type::SPARSE_UNION: + case Type::DENSE_UNION: + return true; + default: + break; + } + return false; +} + static inline int offset_bit_width(Type::type type_id) { switch (type_id) { case Type::STRING: diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index ade2cdaa7d5..6de5a6ed9a7 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -902,7 +902,7 @@ Structural transforms +--------------------------+------------+---------------------------------------------------+---------------------+---------+ | Function name | Arity | Input types | Output type | Notes | +==========================+============+===================================================+=====================+=========+ -| case_when | Varargs | Struct of Boolean (Arg 0), Any fixed-width (rest) | Input type | \(1) | +| case_when | Varargs | Struct of Boolean (Arg 0), Any (rest) | Input type | \(1) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ | choose | Varargs | Integral (Arg 0); Fixed-width/Binary-like (rest) | Input type | \(2) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ @@ -936,6 +936,9 @@ Structural transforms the first value datum for which the corresponding Boolean is true, or the corresponding value from the 'default' input, or null otherwise. + Note that currently, while all types are supported, dictionaries will be + unpacked. + * \(2) The first input must be an integral type. The rest of the arguments can be any type, but must all be the same type or promotable to a common type. Each value of the first input (the 'index') is used as a zero-based index into the