diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 602a468fafb..555e40b7b30 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -59,6 +59,7 @@ #include "arrow/util/bitmap_builders.h" #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/macros.h" #include "arrow/util/range.h" #include "arrow/visit_data_inline.h" @@ -366,13 +367,12 @@ TEST_F(TestArray, BuildLargeInMemoryArray) { ASSERT_EQ(length, result->length()); } -TEST_F(TestArray, TestMakeArrayOfNull) { +static std::vector> TestArrayUtilitiesAgainstTheseTypes() { FieldVector union_fields1({field("a", utf8()), field("b", int32())}); FieldVector union_fields2({field("a", null()), field("b", list(large_utf8()))}); std::vector union_type_codes{7, 42}; - std::shared_ptr types[] = { - // clang-format off + return { null(), boolean(), int8(), @@ -387,7 +387,7 @@ TEST_F(TestArray, TestMakeArrayOfNull) { utf8(), large_utf8(), list(utf8()), - list(int64()), // ARROW-9071 + list(int64()), // NOTE: Regression case for ARROW-9071/MakeArrayOfNull large_list(large_utf8()), fixed_size_list(utf8(), 3), fixed_size_list(int64(), 4), @@ -397,13 +397,15 @@ TEST_F(TestArray, TestMakeArrayOfNull) { sparse_union(union_fields2, union_type_codes), dense_union(union_fields1, union_type_codes), dense_union(union_fields2, union_type_codes), - smallint(), // extension type - list_extension_type(), // nested extension type - // clang-format on + smallint(), // extension type + list_extension_type(), // nested extension type + run_end_encoded(int16(), utf8()), }; +} +TEST_F(TestArray, TestMakeArrayOfNull) { for (int64_t length : {0, 1, 16, 133}) { - for (auto type : types) { + for (auto type : TestArrayUtilitiesAgainstTheseTypes()) { ARROW_SCOPED_TRACE("type = ", type->ToString()); ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(type, length)); ASSERT_EQ(array->type(), type); @@ -716,36 +718,7 @@ void CheckSpanRoundTrip(const Array& array) { } TEST_F(TestArray, TestMakeEmptyArray) { - FieldVector union_fields1({field("a", utf8()), field("b", int32())}); - FieldVector union_fields2({field("a", null()), field("b", list(large_utf8()))}); - std::vector union_type_codes{7, 42}; - - std::shared_ptr types[] = {null(), - boolean(), - int8(), - uint16(), - int32(), - uint64(), - float64(), - binary(), - large_binary(), - fixed_size_binary(3), - decimal(16, 4), - utf8(), - large_utf8(), - list(utf8()), - list(int64()), - large_list(large_utf8()), - fixed_size_list(utf8(), 3), - fixed_size_list(int64(), 4), - dictionary(int32(), utf8()), - struct_({field("a", utf8()), field("b", int32())}), - sparse_union(union_fields1, union_type_codes), - sparse_union(union_fields2, union_type_codes), - dense_union(union_fields1, union_type_codes), - dense_union(union_fields2, union_type_codes)}; - - for (auto type : types) { + for (auto type : TestArrayUtilitiesAgainstTheseTypes()) { ARROW_SCOPED_TRACE("type = ", type->ToString()); ASSERT_OK_AND_ASSIGN(auto array, MakeEmptyArray(type)); ASSERT_OK(array->ValidateFull()); @@ -754,6 +727,29 @@ TEST_F(TestArray, TestMakeEmptyArray) { } } +TEST_F(TestArray, TestFillFromScalar) { + for (auto type : TestArrayUtilitiesAgainstTheseTypes()) { + ARROW_SCOPED_TRACE("type = ", type->ToString()); + for (auto seed : {0u, 0xdeadbeef, 42u}) { + ARROW_SCOPED_TRACE("seed = ", seed); + + Field field("", type, /*nullable=*/true, + key_value_metadata({{"extension_allow_random_storage", "true"}})); + auto array = random::GenerateArray(field, 1, seed); + + ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(0)); + + ArraySpan span(*scalar); + auto roundtripped_array = span.ToArray(); + AssertArraysEqual(*array, *roundtripped_array); + + ASSERT_OK(roundtripped_array->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto roundtripped_scalar, roundtripped_array->GetScalar(0)); + AssertScalarsEqual(*scalar, *roundtripped_scalar); + } + } +} + TEST_F(TestArray, ExtensionSpanRoundTrip) { // Other types are checked in MakeEmptyArray but MakeEmptyArray doesn't // work for extension types so we check that here diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 8764e9c354c..79595ab7c7c 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -130,7 +130,8 @@ std::shared_ptr ArrayData::Make(std::shared_ptr type, int64 } std::shared_ptr ArrayData::Slice(int64_t off, int64_t len) const { - ARROW_CHECK_LE(off, length) << "Slice offset greater than array length"; + ARROW_CHECK_LE(off, length) << "Slice offset (" << off + << ") greater than array length (" << length << ")"; len = std::min(length - off, len); off += offset; @@ -228,12 +229,11 @@ void ArraySpan::SetMembers(const ArrayData& data) { namespace { template -void SetOffsetsForScalar(ArraySpan* span, offset_type* buffer, int64_t value_size, - int buffer_index = 1) { - buffer[0] = 0; - buffer[1] = static_cast(value_size); - span->buffers[buffer_index].data = reinterpret_cast(buffer); - span->buffers[buffer_index].size = 2 * sizeof(offset_type); +BufferSpan OffsetsForScalar(uint8_t* scratch_space, offset_type value_size) { + auto* offsets = reinterpret_cast(scratch_space); + offsets[0] = 0; + offsets[1] = static_cast(value_size); + return {scratch_space, sizeof(offset_type) * 2}; } int GetNumBuffers(const DataType& type) { @@ -241,9 +241,8 @@ int GetNumBuffers(const DataType& type) { case Type::NA: case Type::STRUCT: case Type::FIXED_SIZE_LIST: - return 1; case Type::RUN_END_ENCODED: - return 0; + return 1; case Type::BINARY: case Type::LARGE_BINARY: case Type::STRING: @@ -265,16 +264,19 @@ int GetNumBuffers(const DataType& type) { namespace internal { void FillZeroLengthArray(const DataType* type, ArraySpan* span) { - memset(span->scratch_space, 0x00, sizeof(span->scratch_space)); - span->type = type; span->length = 0; int num_buffers = GetNumBuffers(*type); for (int i = 0; i < num_buffers; ++i) { - span->buffers[i].data = reinterpret_cast(span->scratch_space); + alignas(int64_t) static std::array kZeros{0}; + span->buffers[i].data = kZeros.data(); span->buffers[i].size = 0; } + if (!HasValidityBitmap(type->id())) { + span->buffers[0] = {}; + } + for (int i = num_buffers; i < 3; ++i) { span->buffers[i] = {}; } @@ -304,9 +306,13 @@ void ArraySpan::FillFromScalar(const Scalar& value) { Type::type type_id = value.type->id(); - // Populate null count and validity bitmap (only for non-union/null types) - this->null_count = value.is_valid ? 0 : 1; - if (!is_union(type_id) && type_id != Type::NA) { + if (type_id == Type::NA) { + this->null_count = 1; + } else if (!internal::HasValidityBitmap(type_id)) { + this->null_count = 0; + } else { + // Populate null count and validity bitmap + this->null_count = value.is_valid ? 0 : 1; this->buffers[0].data = value.is_valid ? &kTrueBit : &kFalseBit; this->buffers[0].size = 1; } @@ -329,7 +335,7 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } } else if (is_base_binary_like(type_id)) { const auto& scalar = checked_cast(value); - this->buffers[1].data = reinterpret_cast(this->scratch_space); + const uint8_t* data_buffer = nullptr; int64_t data_size = 0; if (scalar.is_valid) { @@ -337,12 +343,11 @@ void ArraySpan::FillFromScalar(const Scalar& value) { data_size = scalar.value->size(); } if (is_binary_like(type_id)) { - SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), - data_size); + this->buffers[1] = + OffsetsForScalar(scalar.scratch_space_, static_cast(data_size)); } else { // is_large_binary_like - SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), - data_size); + this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, data_size); } this->buffers[2].data = const_cast(data_buffer); this->buffers[2].size = data_size; @@ -367,11 +372,10 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } if (type_id == Type::LIST || type_id == Type::MAP) { - SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), - value_length); + this->buffers[1] = + OffsetsForScalar(scalar.scratch_space_, static_cast(value_length)); } else if (type_id == Type::LARGE_LIST) { - SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), - value_length); + this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, value_length); } else { // FIXED_SIZE_LIST: does not have a second buffer this->buffers[1] = {}; @@ -384,26 +388,31 @@ void ArraySpan::FillFromScalar(const Scalar& value) { this->child_data[i].FillFromScalar(*scalar.value[i]); } } else if (is_union(type_id)) { + // Dense union needs scratch space to store both offsets and a type code + struct UnionScratchSpace { + alignas(int64_t) int8_t type_code; + alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; + }; + static_assert(sizeof(UnionScratchSpace) <= sizeof(UnionScalar::scratch_space_)); + auto* union_scratch_space = reinterpret_cast( + &checked_cast(value).scratch_space_); + // First buffer is kept null since unions have no validity vector this->buffers[0] = {}; - this->buffers[1].data = reinterpret_cast(this->scratch_space); + union_scratch_space->type_code = checked_cast(value).type_code; + this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); this->buffers[1].size = 1; - int8_t* type_codes = reinterpret_cast(this->scratch_space); - type_codes[0] = checked_cast(value).type_code; this->child_data.resize(this->type->num_fields()); if (type_id == Type::DENSE_UNION) { const auto& scalar = checked_cast(value); - // Has offset; start 4 bytes in so it's aligned to a 32-bit boundaries - SetOffsetsForScalar(this, - reinterpret_cast(this->scratch_space) + 1, 1, - /*buffer_index=*/2); + this->buffers[2] = + OffsetsForScalar(union_scratch_space->offsets, static_cast(1)); // We can't "see" the other arrays in the union, but we put the "active" // union array in the right place and fill zero-length arrays for the // others - const std::vector& child_ids = - checked_cast(this->type)->child_ids(); + const auto& child_ids = checked_cast(this->type)->child_ids(); DCHECK_GE(scalar.type_code, 0); DCHECK_LT(scalar.type_code, static_cast(child_ids.size())); for (int i = 0; i < static_cast(this->child_data.size()); ++i) { @@ -429,6 +438,32 @@ void ArraySpan::FillFromScalar(const Scalar& value) { // Restore the extension type this->type = value.type.get(); + } else if (type_id == Type::RUN_END_ENCODED) { + const auto& scalar = checked_cast(value); + this->child_data.resize(2); + + auto set_run_end = [&](auto run_end) { + auto& e = this->child_data[0]; + e.type = scalar.run_end_type().get(); + e.length = 1; + e.null_count = 0; + e.buffers[1].data = scalar.scratch_space_; + e.buffers[1].size = sizeof(run_end); + reinterpret_cast(scalar.scratch_space_)[0] = run_end; + }; + + switch (scalar.run_end_type()->id()) { + case Type::INT16: + set_run_end(static_cast(1)); + break; + case Type::INT32: + set_run_end(static_cast(1)); + break; + default: + DCHECK_EQ(scalar.run_end_type()->id(), Type::INT64); + set_run_end(static_cast(1)); + } + this->child_data[1].FillFromScalar(*scalar.value); } else { DCHECK_EQ(Type::NA, type_id) << "should be unreachable: " << *value.type; } diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index 82a6e733727..715c233fcf2 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -372,11 +372,6 @@ struct ARROW_EXPORT ArraySpan { int64_t offset = 0; BufferSpan buffers[3]; - // 16 bytes of scratch space to enable this ArraySpan to be a view onto - // scalar values including binary scalars (where we need to create a buffer - // that looks like two 32-bit or 64-bit offsets) - uint64_t scratch_space[2]; - ArraySpan() = default; explicit ArraySpan(const DataType* type, int64_t length) : type(type), length(length) {} diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index d7a8783d442..e84ab404ad6 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -554,13 +554,18 @@ class NullArrayFactory { } Status Visit(const RunEndEncodedType& type) { - ARROW_ASSIGN_OR_RAISE(auto values, MakeArrayOfNull(type.value_type(), 1, pool_)); - ARROW_ASSIGN_OR_RAISE(auto run_end_scalar, - MakeScalarForRunEndValue(*type.run_end_type(), length_)); - ARROW_ASSIGN_OR_RAISE(auto run_ends, MakeArrayFromScalar(*run_end_scalar, 1, pool_)); - ARROW_ASSIGN_OR_RAISE(auto ree_array, - RunEndEncodedArray::Make(length_, run_ends, values)); - out_ = ree_array->data(); + std::shared_ptr run_ends, values; + if (length_ == 0) { + ARROW_ASSIGN_OR_RAISE(run_ends, MakeEmptyArray(type.run_end_type(), pool_)); + ARROW_ASSIGN_OR_RAISE(values, MakeEmptyArray(type.value_type(), pool_)); + } else { + ARROW_ASSIGN_OR_RAISE(auto run_end_scalar, + MakeScalarForRunEndValue(*type.run_end_type(), length_)); + ARROW_ASSIGN_OR_RAISE(run_ends, MakeArrayFromScalar(*run_end_scalar, 1, pool_)); + ARROW_ASSIGN_OR_RAISE(values, MakeArrayOfNull(type.value_type(), 1, pool_)); + } + out_->child_data[0] = run_ends->data(); + out_->child_data[1] = values->data(); return Status::OK(); } @@ -582,7 +587,7 @@ class NullArrayFactory { } MemoryPool* pool_; - std::shared_ptr type_; + const std::shared_ptr& type_; int64_t length_; std::shared_ptr out_; std::shared_ptr buffer_; @@ -859,6 +864,13 @@ Result> MakeArrayFromScalar(const Scalar& scalar, int64_t Result> MakeEmptyArray(std::shared_ptr type, MemoryPool* memory_pool) { + if (type->id() == Type::EXTENSION) { + const auto& ext_type = checked_cast(*type); + ARROW_ASSIGN_OR_RAISE(auto storage, + MakeEmptyArray(ext_type.storage_type(), memory_pool)); + storage->data()->type = std::move(type); + return ext_type.MakeArray(storage->data()); + } std::unique_ptr builder; RETURN_NOT_OK(MakeBuilder(memory_pool, type, &builder)); RETURN_NOT_OK(builder->Resize(0)); diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index d57604583e8..f90e01a2f81 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -31,7 +31,9 @@ #include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +using testing::Eq; using testing::HasSubstr; using testing::UnorderedElementsAreArray; @@ -78,27 +80,11 @@ Expression add(Expression l, Expression r) { return call("add", {std::move(l), std::move(r)}); } -template -void ExpectResultsEqual(Actual&& actual, Expected&& expected) { - using MaybeActual = typename EnsureResult::type>::type; - using MaybeExpected = typename EnsureResult::type>::type; - - MaybeActual maybe_actual(std::forward(actual)); - MaybeExpected maybe_expected(std::forward(expected)); - - if (maybe_expected.ok()) { - EXPECT_EQ(maybe_actual, maybe_expected); - } else { - EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT( - expected.status().code(), HasSubstr(expected.status().message()), maybe_actual); - } -} - const auto no_change = std::nullopt; TEST(ExpressionUtils, Comparison) { - auto Expect = [](Result expected, Datum l, Datum r) { - ExpectResultsEqual(Comparison::Execute(l, r).Map(Comparison::GetName), expected); + auto cmp_name = [](Datum l, Datum r) { + return Comparison::Execute(l, r).Map(Comparison::GetName); }; Datum zero(0), one(1), two(2), null(std::make_shared()); @@ -106,27 +92,28 @@ TEST(ExpressionUtils, Comparison) { Datum dict_str(DictionaryScalar::Make(std::make_shared(0), ArrayFromJSON(utf8(), R"(["a", "b", "c"])"))); - Status not_impl = Status::NotImplemented("no kernel matching input types"); + auto RaisesNotImpl = + Raises(StatusCode::NotImplemented, HasSubstr("no kernel matching input types")); - Expect("equal", one, one); - Expect("less", one, two); - Expect("greater", one, zero); + EXPECT_THAT(cmp_name(one, one), ResultWith(Eq("equal"))); + EXPECT_THAT(cmp_name(one, two), ResultWith(Eq("less"))); + EXPECT_THAT(cmp_name(one, zero), ResultWith(Eq("greater"))); - Expect("na", one, null); - Expect("na", null, one); + EXPECT_THAT(cmp_name(one, null), ResultWith(Eq("na"))); + EXPECT_THAT(cmp_name(null, one), ResultWith(Eq("na"))); // strings and ints are not comparable without explicit casts - Expect(not_impl, str, one); - Expect(not_impl, one, str); - Expect(not_impl, str, null); // not even null ints + EXPECT_THAT(cmp_name(str, one), RaisesNotImpl); + EXPECT_THAT(cmp_name(one, str), RaisesNotImpl); + EXPECT_THAT(cmp_name(str, null), RaisesNotImpl); // not even null ints // string -> binary implicit cast allowed - Expect("equal", str, bin); - Expect("equal", bin, str); + EXPECT_THAT(cmp_name(str, bin), ResultWith(Eq("equal"))); + EXPECT_THAT(cmp_name(bin, str), ResultWith(Eq("equal"))); // dict_str -> string, implicit casts allowed - Expect("less", dict_str, str); - Expect("less", dict_str, bin); + EXPECT_THAT(cmp_name(dict_str, str), ResultWith(Eq("less"))); + EXPECT_THAT(cmp_name(dict_str, bin), ResultWith(Eq("less"))); } TEST(ExpressionUtils, StripOrderPreservingCasts) { diff --git a/cpp/src/arrow/compute/kernels/vector_nested_test.cc b/cpp/src/arrow/compute/kernels/vector_nested_test.cc index 277f8169bd2..eef1b6835ff 100644 --- a/cpp/src/arrow/compute/kernels/vector_nested_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_nested_test.cc @@ -51,10 +51,12 @@ TEST(TestVectorNested, ListFlattenNulls) { TEST(TestVectorNested, ListFlattenChunkedArray) { for (auto ty : {list(int16()), large_list(int16())}) { + ARROW_SCOPED_TRACE(ty->ToString()); auto input = ChunkedArrayFromJSON(ty, {"[[0, null, 1], null]", "[[2, 3], []]"}); auto expected = ChunkedArrayFromJSON(int16(), {"[0, null, 1]", "[2, 3]"}); CheckVectorUnary("list_flatten", input, expected); + ARROW_SCOPED_TRACE("empty"); input = ChunkedArrayFromJSON(ty, {}); expected = ChunkedArrayFromJSON(int16(), {}); CheckVectorUnary("list_flatten", input, expected); diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 0797306a674..1d1ce4aa729 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -131,6 +131,13 @@ struct ARROW_EXPORT NullScalar : public Scalar { namespace internal { +struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { + // 16 bytes of scratch space to enable ArraySpan to be a view onto any + // Scalar- including binary scalars where we need to create a buffer + // that looks like two 32-bit or 64-bit offsets. + alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2]; +}; + struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { explicit PrimitiveScalarBase(std::shared_ptr type) : Scalar(std::move(type), false) {} @@ -238,7 +245,9 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { using NumericScalar::NumericScalar; }; -struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { +struct ARROW_EXPORT BaseBinaryScalar + : public internal::PrimitiveScalarBase, + private internal::ArraySpanFillFromScalarScratchSpace { using internal::PrimitiveScalarBase::PrimitiveScalarBase; using ValueType = std::shared_ptr; @@ -257,6 +266,8 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { protected: BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type) : internal::PrimitiveScalarBase{std::move(type), true}, value(std::move(value)) {} + + friend ArraySpan; }; struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { @@ -464,7 +475,9 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar; @@ -472,6 +485,9 @@ struct ARROW_EXPORT BaseListScalar : public Scalar { bool is_valid = true); std::shared_ptr value; + + private: + friend struct ArraySpan; }; struct ARROW_EXPORT ListScalar : public BaseListScalar { @@ -519,7 +535,8 @@ struct ARROW_EXPORT StructScalar : public Scalar { std::vector field_names); }; -struct ARROW_EXPORT UnionScalar : public Scalar { +struct ARROW_EXPORT UnionScalar : public Scalar, + private internal::ArraySpanFillFromScalarScratchSpace { int8_t type_code; virtual const std::shared_ptr& child_value() const = 0; @@ -527,6 +544,8 @@ struct ARROW_EXPORT UnionScalar : public Scalar { protected: UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid) : Scalar(std::move(type), is_valid), type_code(type_code) {} + + friend struct ArraySpan; }; struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { @@ -568,7 +587,9 @@ struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { value(std::move(value)) {} }; -struct ARROW_EXPORT RunEndEncodedScalar : public Scalar { +struct ARROW_EXPORT RunEndEncodedScalar + : public Scalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = RunEndEncodedType; using ValueType = std::shared_ptr; @@ -589,6 +610,8 @@ struct ARROW_EXPORT RunEndEncodedScalar : public Scalar { private: const TypeClass& ree_type() const { return internal::checked_cast(*type); } + + friend ArraySpan; }; /// \brief A Scalar value for DictionaryType diff --git a/cpp/src/arrow/testing/random.cc b/cpp/src/arrow/testing/random.cc index b8ea247a437..b74c41f75e4 100644 --- a/cpp/src/arrow/testing/random.cc +++ b/cpp/src/arrow/testing/random.cc @@ -33,6 +33,7 @@ #include "arrow/array/builder_decimal.h" #include "arrow/array/builder_primitive.h" #include "arrow/buffer.h" +#include "arrow/extension_type.h" #include "arrow/record_batch.h" #include "arrow/testing/gtest_util.h" #include "arrow/type.h" @@ -935,14 +936,27 @@ std::shared_ptr RandomArrayGenerator::ArrayOf(const Field& field, int64_t case Type::type::SPARSE_UNION: case Type::type::DENSE_UNION: { ArrayVector child_arrays(field.type()->num_fields()); - for (int i = 0; i < field.type()->num_fields(); i++) { + for (int i = 0; i < field.type()->num_fields(); ++i) { const auto& child_field = field.type()->field(i); child_arrays[i] = ArrayOf(*child_field, length, alignment, memory_pool); } auto array = field.type()->id() == Type::type::SPARSE_UNION ? SparseUnion(child_arrays, length, alignment, memory_pool) : DenseUnion(child_arrays, length, alignment, memory_pool); - return *array->View(field.type()); + + const auto& type_codes = checked_cast(*field.type()).type_codes(); + const auto& default_type_codes = + checked_cast(*array->type()).type_codes(); + + if (type_codes != default_type_codes) { + // map to the type ids specified by the UnionType + auto* type_ids = + reinterpret_cast(array->data()->buffers[1]->mutable_data()); + for (int64_t i = 0; i != array->length(); ++i) { + type_ids[i] = type_codes[type_ids[i]]; + } + } + return *array->View(field.type()); // view gets the field names right for us } case Type::type::DICTIONARY: { @@ -982,8 +996,15 @@ std::shared_ptr RandomArrayGenerator::ArrayOf(const Field& field, int64_t } case Type::type::EXTENSION: - // Could be supported by generating the storage type (though any extension - // invariants wouldn't be preserved) + if (GetMetadata(field.metadata().get(), "extension_allow_random_storage", + false)) { + const auto& ext_type = checked_cast(*field.type()); + auto storage = ArrayOf(*field.WithType(ext_type.storage_type()), length, + alignment, memory_pool); + return ExtensionType::WrapArray(field.type(), storage); + } + // We don't have explicit permission to generate random storage; bail rather than + // silently risk breaking extension invariants break; case Type::type::FIXED_SIZE_LIST: { diff --git a/cpp/src/arrow/testing/random.h b/cpp/src/arrow/testing/random.h index 1bd189c39c2..de9ea6d0564 100644 --- a/cpp/src/arrow/testing/random.h +++ b/cpp/src/arrow/testing/random.h @@ -563,6 +563,13 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator { /// For MapType: /// - values (int32_t): the number of key-value pairs to generate, which will be /// partitioned among the array values. + /// + /// For extension types: + /// - extension_allow_random_storage (bool): in general an extension array may have + /// invariants on its storage beyond those already imposed by the arrow format, + /// which may result in an invalid array if we just wrap randomly generated + /// storage. Set this flag to explicitly allow wrapping of randomly generated + /// storage. std::shared_ptr BatchOf( const FieldVector& fields, int64_t size, int64_t alignment = kDefaultBufferAlignment, @@ -575,7 +582,7 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator { std::default_random_engine seed_rng_; }; -/// Generate an array with random data. See RandomArrayGenerator::BatchOf. +/// Generate a batch with random data. See RandomArrayGenerator::BatchOf. ARROW_TESTING_EXPORT std::shared_ptr GenerateBatch( const FieldVector& fields, int64_t size, SeedType seed, diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 657abbaecc4..e10a3f33da3 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -68,6 +68,7 @@ using FieldVector = std::vector>; class Array; struct ArrayData; +struct ArraySpan; class ArrayBuilder; struct Scalar;