diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 086f45d6fee..0f19f7351c3 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -238,7 +238,7 @@ jobs: name: AMD64 Windows MinGW ${{ matrix.mingw-n-bits }} C++ runs-on: windows-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 45 + timeout-minutes: 60 strategy: fail-fast: false matrix: diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index e160ba8128a..3886eafee94 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -53,7 +53,7 @@ jobs: name: AMD64 Ubuntu ${{ matrix.ubuntu }} R ${{ matrix.r }} runs-on: ubuntu-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 60 + timeout-minutes: 75 strategy: fail-fast: false matrix: diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 56d70d83daf..246b679129a 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -80,9 +80,13 @@ build() { export LIBS="-L${MINGW_PREFIX}/libs" export ARROW_S3=OFF export ARROW_WITH_RE2=OFF + # Without this, some dataset functionality segfaults + export CMAKE_UNITY_BUILD=ON else export ARROW_S3=ON export ARROW_WITH_RE2=ON + # Without this, some compute functionality segfaults in tests + export CMAKE_UNITY_BUILD=OFF fi MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \ @@ -115,7 +119,7 @@ build() { -DARROW_CXXFLAGS="${CPPFLAGS}" \ -DCMAKE_BUILD_TYPE="release" \ -DCMAKE_INSTALL_PREFIX=${MINGW_PREFIX} \ - -DCMAKE_UNITY_BUILD=ON \ + -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ -DCMAKE_VERBOSE_MAKEFILE=ON make -j3 diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index d9617c4e603..2e3d4057094 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -456,7 +456,7 @@ TEST_F(TestArray, TestValidateNullCount) { void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr& scalar) { std::unique_ptr builder; auto null_scalar = MakeNullScalar(scalar->type); - ASSERT_OK(MakeBuilder(pool, scalar->type, &builder)); + ASSERT_OK(MakeBuilderExactIndex(pool, scalar->type, &builder)); ASSERT_OK(builder->AppendScalar(*scalar)); ASSERT_OK(builder->AppendScalar(*scalar)); ASSERT_OK(builder->AppendScalar(*null_scalar)); @@ -471,15 +471,18 @@ void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr& scalar) ASSERT_EQ(out->length(), 9); const bool can_check_nulls = internal::HasValidityBitmap(out->type()->id()); + // For a dictionary builder, the output dictionary won't necessarily be the same + const bool can_check_values = !is_dictionary(out->type()->id()); if (can_check_nulls) { ASSERT_EQ(out->null_count(), 4); } + for (const auto index : {0, 1, 3, 5, 6}) { ASSERT_FALSE(out->IsNull(index)); ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index)); ASSERT_OK(scalar_i->ValidateFull()); - AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true); + if (can_check_values) AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true); } for (const auto index : {2, 4, 7, 8}) { ASSERT_EQ(out->IsNull(index), can_check_nulls); @@ -575,8 +578,6 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { } for (auto scalar : scalars) { - // TODO(ARROW-13197): appending dictionary scalars not implemented - if (is_dictionary(scalar->type->id())) continue; AssertAppendScalar(pool_, scalar); } } @@ -634,9 +635,6 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) { TEST_F(TestArray, TestAppendArraySlice) { auto scalars = GetScalars(); for (const auto& scalar : scalars) { - // TODO(ARROW-13573): appending dictionary arrays not implemented - if (is_dictionary(scalar->type->id())) continue; - ARROW_SCOPED_TRACE(*scalar->type); ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16)); ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16)); diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index 2f4e63b546d..117b9d37632 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -22,6 +22,7 @@ #include #include "arrow/array/array_base.h" +#include "arrow/array/builder_dict.h" #include "arrow/array/data.h" #include "arrow/array/util.h" #include "arrow/buffer.h" @@ -268,15 +269,6 @@ struct AppendScalarImpl { } // namespace -Status ArrayBuilder::AppendScalar(const Scalar& scalar) { - if (!scalar.type->Equals(type())) { - return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), - " to builder for type ", type()->ToString()); - } - std::shared_ptr shared{const_cast(&scalar), [](Scalar*) {}}; - return AppendScalarImpl{&shared, &shared + 1, /*n_repeats=*/1, this}.Convert(); -} - Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) { if (!scalar.type->Equals(type())) { return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h index 67203e79071..87e39c3fe9f 100644 --- a/cpp/src/arrow/array/builder_base.h +++ b/cpp/src/arrow/array/builder_base.h @@ -119,9 +119,9 @@ class ARROW_EXPORT ArrayBuilder { virtual Status AppendEmptyValues(int64_t length) = 0; /// \brief Append a value from a scalar - Status AppendScalar(const Scalar& scalar); - Status AppendScalar(const Scalar& scalar, int64_t n_repeats); - Status AppendScalars(const ScalarVector& scalars); + Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); } + virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats); + virtual Status AppendScalars(const ScalarVector& scalars); /// \brief Append a range of values from an array. /// @@ -282,6 +282,13 @@ ARROW_EXPORT Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out); +/// \brief Construct an empty ArrayBuilder corresponding to the data +/// type, where any top-level or nested dictionary builders return the +/// exact index type specified by the type. +ARROW_EXPORT +Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out); + /// \brief Construct an empty DictionaryBuilder initialized optionally /// with a pre-existing dictionary /// \param[in] pool the MemoryPool to use for allocations diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index b13f6a2db34..d247316999d 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -159,23 +159,32 @@ DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool, DictionaryMemoTable::~DictionaryMemoTable() = default; -#define GET_OR_INSERT(C_TYPE) \ - Status DictionaryMemoTable::GetOrInsert( \ - const typename CTypeTraits::ArrowType*, C_TYPE value, int32_t* out) { \ - return impl_->GetOrInsert::ArrowType>(value, out); \ +#define GET_OR_INSERT(ARROW_TYPE) \ + Status DictionaryMemoTable::GetOrInsert( \ + const ARROW_TYPE*, typename ARROW_TYPE::c_type value, int32_t* out) { \ + return impl_->GetOrInsert(value, out); \ } -GET_OR_INSERT(bool) -GET_OR_INSERT(int8_t) -GET_OR_INSERT(int16_t) -GET_OR_INSERT(int32_t) -GET_OR_INSERT(int64_t) -GET_OR_INSERT(uint8_t) -GET_OR_INSERT(uint16_t) -GET_OR_INSERT(uint32_t) -GET_OR_INSERT(uint64_t) -GET_OR_INSERT(float) -GET_OR_INSERT(double) +GET_OR_INSERT(BooleanType) +GET_OR_INSERT(Int8Type) +GET_OR_INSERT(Int16Type) +GET_OR_INSERT(Int32Type) +GET_OR_INSERT(Int64Type) +GET_OR_INSERT(UInt8Type) +GET_OR_INSERT(UInt16Type) +GET_OR_INSERT(UInt32Type) +GET_OR_INSERT(UInt64Type) +GET_OR_INSERT(FloatType) +GET_OR_INSERT(DoubleType) +GET_OR_INSERT(DurationType); +GET_OR_INSERT(TimestampType); +GET_OR_INSERT(Date32Type); +GET_OR_INSERT(Date64Type); +GET_OR_INSERT(Time32Type); +GET_OR_INSERT(Time64Type); +GET_OR_INSERT(MonthDayNanoIntervalType); +GET_OR_INSERT(DayTimeIntervalType); +GET_OR_INSERT(MonthIntervalType); #undef GET_OR_INSERT diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 455cb3df7b1..0637c9722a8 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -37,6 +37,7 @@ #include "arrow/util/decimal.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -97,6 +98,17 @@ class ARROW_EXPORT DictionaryMemoTable { Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out); Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out); Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out); + Status GetOrInsert(const DurationType*, int64_t value, int32_t* out); + Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out); + Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const MonthDayNanoIntervalType*, + MonthDayNanoIntervalType::MonthDayNanos value, int32_t* out); + Status GetOrInsert(const DayTimeIntervalType*, + DayTimeIntervalType::DayMilliseconds value, int32_t* out); + Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out); Status GetOrInsert(const FloatType*, float value, int32_t* out); Status GetOrInsert(const DoubleType*, double value, int32_t* out); @@ -282,6 +294,73 @@ class DictionaryBuilderBase : public ArrayBuilder { return indices_builder_.AppendEmptyValues(length); } + Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override { + if (!scalar.is_valid) return AppendNulls(n_repeats); + + const auto& dict_ty = internal::checked_cast(*scalar.type); + const DictionaryScalar& dict_scalar = + internal::checked_cast(scalar); + const auto& dict = internal::checked_cast::ArrayType&>( + *dict_scalar.value.dictionary); + ARROW_RETURN_NOT_OK(Reserve(n_repeats)); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + + Status AppendScalars(const ScalarVector& scalars) override { + for (const auto& scalar : scalars) { + ARROW_RETURN_NOT_OK(AppendScalar(*scalar, /*n_repeats=*/1)); + } + return Status::OK(); + } + + Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final { + // Visit the indices and insert the unpacked values. + const auto& dict_ty = internal::checked_cast(*array.type); + const typename TypeTraits::ArrayType dict(array.dictionary); + ARROW_RETURN_NOT_OK(Reserve(length)); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT64: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT64: + return AppendArraySliceImpl(dict, array, offset, length); + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + /// \brief Insert values into the dictionary's memo, but do not append any /// indices. Can be used to initialize a new builder with known dictionary /// values @@ -376,6 +455,37 @@ class DictionaryBuilderBase : public ArrayBuilder { } protected: + template + Status AppendArraySliceImpl(const typename TypeTraits::ArrayType& dict, + const ArrayData& array, int64_t offset, int64_t length) { + const c_type* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, length, + [&](const int64_t position) { + const int64_t index = static_cast(values[position]); + if (dict.IsValid(index)) { + return Append(dict.GetView(index)); + } + return AppendNull(); + }, + [&]() { return AppendNull(); }); + } + + template + Status AppendScalarImpl(const typename TypeTraits::ArrayType& dict, + const Scalar& index_scalar, int64_t n_repeats) { + using ScalarType = typename TypeTraits::ScalarType; + const auto index = internal::checked_cast(index_scalar).value; + if (index_scalar.is_valid && dict.IsValid(index)) { + const auto& value = dict.GetView(index); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + return Status::OK(); + } + return AppendNulls(n_repeats); + } + Status FinishInternal(std::shared_ptr* out) override { std::shared_ptr dictionary; ARROW_RETURN_NOT_OK(FinishWithDictOffset(/*offset=*/0, out, &dictionary)); diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 37cc9e07ad4..115a97e9389 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -41,14 +41,10 @@ struct DictionaryBuilderCase { } Status Visit(const NullType&) { return CreateFor(); } - Status Visit(const BinaryType&) { return Create(); } - Status Visit(const StringType&) { return Create(); } - Status Visit(const LargeBinaryType&) { - return Create>(); - } - Status Visit(const LargeStringType&) { - return Create>(); - } + Status Visit(const BinaryType&) { return CreateFor(); } + Status Visit(const StringType&) { return CreateFor(); } + Status Visit(const LargeBinaryType&) { return CreateFor(); } + Status Visit(const LargeStringType&) { return CreateFor(); } Status Visit(const FixedSizeBinaryType&) { return CreateFor(); } Status Visit(const Decimal128Type&) { return CreateFor(); } Status Visit(const Decimal256Type&) { return CreateFor(); } @@ -63,19 +59,50 @@ struct DictionaryBuilderCase { template Status CreateFor() { - return Create>(); - } - - template - Status Create() { - BuilderType* builder; + using AdaptiveBuilderType = DictionaryBuilder; if (dictionary != nullptr) { - builder = new BuilderType(dictionary, pool); + out->reset(new AdaptiveBuilderType(dictionary, pool)); + } else if (exact_index_type) { + switch (index_type->id()) { + case Type::UINT8: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT8: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT16: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT16: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT32: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT32: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT64: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT64: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + default: + return Status::TypeError("MakeBuilder: invalid index type ", *index_type); + } } else { auto start_int_size = internal::GetByteWidth(*index_type); - builder = new BuilderType(start_int_size, value_type, pool); + out->reset(new AdaptiveBuilderType(start_int_size, value_type, pool)); } - out->reset(builder); return Status::OK(); } @@ -85,138 +112,130 @@ struct DictionaryBuilderCase { const std::shared_ptr& index_type; const std::shared_ptr& value_type; const std::shared_ptr& dictionary; + bool exact_index_type; std::unique_ptr* out; }; -#define BUILDER_CASE(TYPE_CLASS) \ - case TYPE_CLASS##Type::type_id: \ - out->reset(new TYPE_CLASS##Builder(type, pool)); \ +struct MakeBuilderImpl { + template + enable_if_not_nested Visit(const T&) { + out.reset(new typename TypeTraits::BuilderType(type, pool)); return Status::OK(); + } -Result>> FieldBuilders(const DataType& type, - MemoryPool* pool) { - std::vector> field_builders; + Status Visit(const DictionaryType& dict_type) { + DictionaryBuilderCase visitor = {pool, + dict_type.index_type(), + dict_type.value_type(), + /*dictionary=*/nullptr, + exact_index_type, + &out}; + return visitor.Make(); + } - for (const auto& field : type.fields()) { - std::unique_ptr builder; - RETURN_NOT_OK(MakeBuilder(pool, field->type(), &builder)); - field_builders.emplace_back(std::move(builder)); + Status Visit(const ListType& list_type) { + std::shared_ptr value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new ListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); } - return field_builders; -} + Status Visit(const LargeListType& list_type) { + std::shared_ptr value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new LargeListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); + } -Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, - std::unique_ptr* out) { - switch (type->id()) { - case Type::NA: { - out->reset(new NullBuilder(pool)); - return Status::OK(); - } - BUILDER_CASE(UInt8); - BUILDER_CASE(Int8); - BUILDER_CASE(UInt16); - BUILDER_CASE(Int16); - BUILDER_CASE(UInt32); - BUILDER_CASE(Int32); - BUILDER_CASE(UInt64); - BUILDER_CASE(Int64); - BUILDER_CASE(Date32); - BUILDER_CASE(Date64); - BUILDER_CASE(Duration); - BUILDER_CASE(Time32); - BUILDER_CASE(Time64); - BUILDER_CASE(Timestamp); - BUILDER_CASE(MonthInterval); - BUILDER_CASE(DayTimeInterval); - BUILDER_CASE(MonthDayNanoInterval); - BUILDER_CASE(Boolean); - BUILDER_CASE(HalfFloat); - BUILDER_CASE(Float); - BUILDER_CASE(Double); - BUILDER_CASE(String); - BUILDER_CASE(Binary); - BUILDER_CASE(LargeString); - BUILDER_CASE(LargeBinary); - BUILDER_CASE(FixedSizeBinary); - BUILDER_CASE(Decimal128); - BUILDER_CASE(Decimal256); - - case Type::DICTIONARY: { - const auto& dict_type = static_cast(*type); - DictionaryBuilderCase visitor = {pool, dict_type.index_type(), - dict_type.value_type(), nullptr, out}; - return visitor.Make(); - } + Status Visit(const MapType& map_type) { + ARROW_ASSIGN_OR_RAISE(auto key_builder, ChildBuilder(map_type.key_type())); + ARROW_ASSIGN_OR_RAISE(auto item_builder, ChildBuilder(map_type.item_type())); + out.reset( + new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type)); + return Status::OK(); + } - case Type::LIST: { - std::unique_ptr value_builder; - std::shared_ptr value_type = - internal::checked_cast(*type).value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new ListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const FixedSizeListType& list_type) { + auto value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new FixedSizeListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); + } - case Type::LARGE_LIST: { - std::unique_ptr value_builder; - std::shared_ptr value_type = - internal::checked_cast(*type).value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new LargeListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const StructType& struct_type) { + ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); + out.reset(new StructBuilder(type, pool, std::move(field_builders))); + return Status::OK(); + } - case Type::MAP: { - const auto& map_type = internal::checked_cast(*type); - std::unique_ptr key_builder, item_builder; - RETURN_NOT_OK(MakeBuilder(pool, map_type.key_type(), &key_builder)); - RETURN_NOT_OK(MakeBuilder(pool, map_type.item_type(), &item_builder)); - out->reset( - new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type)); - return Status::OK(); - } + Status Visit(const SparseUnionType&) { + ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); + out.reset(new SparseUnionBuilder(pool, std::move(field_builders), type)); + return Status::OK(); + } - case Type::FIXED_SIZE_LIST: { - const auto& list_type = internal::checked_cast(*type); - std::unique_ptr value_builder; - auto value_type = list_type.value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new FixedSizeListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const DenseUnionType&) { + ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); + out.reset(new DenseUnionBuilder(pool, std::move(field_builders), type)); + return Status::OK(); + } - case Type::STRUCT: { - ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); - out->reset(new StructBuilder(type, pool, std::move(field_builders))); - return Status::OK(); - } + Status Visit(const ExtensionType&) { return NotImplemented(); } + Status Visit(const DataType&) { return NotImplemented(); } - case Type::SPARSE_UNION: { - ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); - out->reset(new SparseUnionBuilder(pool, std::move(field_builders), type)); - return Status::OK(); - } + Status NotImplemented() { + return Status::NotImplemented("MakeBuilder: cannot construct builder for type ", + type->ToString()); + } - case Type::DENSE_UNION: { - ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); - out->reset(new DenseUnionBuilder(pool, std::move(field_builders), type)); - return Status::OK(); - } + Result> ChildBuilder( + const std::shared_ptr& type) { + MakeBuilderImpl impl{pool, type, exact_index_type, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*type, &impl)); + return std::move(impl.out); + } - default: - break; + Result>> FieldBuilders(const DataType& type, + MemoryPool* pool) { + std::vector> field_builders; + for (const auto& field : type.fields()) { + std::unique_ptr builder; + MakeBuilderImpl impl{pool, field->type(), exact_index_type, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*field->type(), &impl)); + field_builders.emplace_back(std::move(impl.out)); + } + return field_builders; } - return Status::NotImplemented("MakeBuilder: cannot construct builder for type ", - type->ToString()); + + MemoryPool* pool; + const std::shared_ptr& type; + bool exact_index_type; + std::unique_ptr out; +}; + +Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out) { + MakeBuilderImpl impl{pool, type, /*exact_index_type=*/false, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*type, &impl)); + *out = std::move(impl.out); + return Status::OK(); +} + +Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out) { + MakeBuilderImpl impl{pool, type, /*exact_index_type=*/true, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*type, &impl)); + *out = std::move(impl.out); + return Status::OK(); } Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr& type, const std::shared_ptr& dictionary, std::unique_ptr* out) { const auto& dict_type = static_cast(*type); - DictionaryBuilderCase visitor = {pool, dict_type.index_type(), dict_type.value_type(), - dictionary, out}; + DictionaryBuilderCase visitor = { + pool, dict_type.index_type(), dict_type.value_type(), + dictionary, /*exact_index_type=*/false, out}; return visitor.Make(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 4de04da7a81..35bb6248f23 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1222,7 +1222,6 @@ struct CaseWhenFunction : ScalarFunction { // The first function is a struct of booleans, where the number of fields in the // struct is either equal to the number of other arguments or is one less. RETURN_NOT_OK(CheckArity(*values)); - EnsureDictionaryDecoded(values); auto first_type = (*values)[0].type; if (first_type->id() != Type::STRUCT) { return Status::TypeError("case_when: first argument must be STRUCT, not ", @@ -1243,6 +1242,9 @@ struct CaseWhenFunction : ScalarFunction { } } + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + EnsureDictionaryDecoded(values); if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) { for (auto it = values->begin() + 1; it != values->end(); it++) { it->type = type; @@ -1279,6 +1281,15 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out return Status::OK(); } ArrayData* output = out->mutable_array(); + if (is_dictionary_type::value) { + const Datum& dict_from = result.is_value() ? result : batch[1]; + if (dict_from.is_scalar()) { + output->dictionary = checked_cast(*dict_from.scalar()) + .value.dictionary->data(); + } else { + output->dictionary = dict_from.array()->dictionary; + } + } if (!result.is_value()) { // All conditions false, no 'else' argument result = MakeNullScalar(out->type()); @@ -1304,6 +1315,7 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) static_cast(conds_array.type->num_fields()) < num_value_args; uint8_t* out_valid = output->buffers[0]->mutable_data(); uint8_t* out_values = output->buffers[1]->mutable_data(); + if (have_else_arg) { // Copy 'else' value into output CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid, @@ -1472,7 +1484,7 @@ static Status ExecVarWidthArrayCaseWhenImpl( const bool have_else_arg = static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1); std::unique_ptr raw_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder)); RETURN_NOT_OK(raw_builder->Reserve(batch.length)); RETURN_NOT_OK(reserve_data(raw_builder.get())); @@ -1701,6 +1713,24 @@ struct CaseWhenFunctor> { } }; +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + std::function reserve_data = ReserveNoData; + return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data)); + } +}; + struct CoalesceFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -2446,7 +2476,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { } { auto func = std::make_shared( - "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc); + "case_when", Arity::VarArgs(/*min_args=*/2), &case_when_doc); AddPrimitiveCaseWhenKernels(func, NumericTypes()); AddPrimitiveCaseWhenKernels(func, TemporalTypes()); AddPrimitiveCaseWhenKernels(func, IntervalTypes()); @@ -2464,6 +2494,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor::Exec); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index b3b0f26cead..8793cac7619 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -624,6 +624,187 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) { ArrayFromJSON(type, R"([null, null, null, [6, null]])")); } +template +class TestCaseWhenDict : public ::testing::Test {}; + +struct JsonDict { + std::shared_ptr type; + std::string value; +}; + +TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes); + +TYPED_TEST(TestCaseWhenDict, Simple) { + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + for (const auto& dict : + {JsonDict{utf8(), R"(["a", null, "bc", "def"])"}, + JsonDict{int64(), "[1, null, 2, 3]"}, + JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) { + auto type = dictionary(default_type_instance(), dict.type); + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value); + auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value); + auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value); + + // Easy case: all arguments have the same dictionary + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}); + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values_null, values2, values1}); + } +} + +TYPED_TEST(TestCaseWhenDict, Mixed) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); + auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); + auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])"); + auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict); + auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])"); + + // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values1_dict, values2_decoded}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values1_decoded, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); +} + +TYPED_TEST(TestCaseWhenDict, NestedSimple) { + auto make_list = [](const std::shared_ptr& indices, + const std::shared_ptr& backing_array) { + EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array)); + return result; + }; + auto index_type = default_type_instance(); + auto inner_type = dictionary(index_type, utf8()); + auto type = list(inner_type); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), + DictArrayFromJSON(inner_type, "[]", dict)); + auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); + auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict); + auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); + auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); + + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1}, + /*result_is_encoded=*/false); + + CheckDictionary("case_when", + { + Datum(MakeStruct({cond1, cond2})), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[0, 1]", dict))), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[2, 3]", dict))), + }, + /*result_is_encoded=*/false); + + CheckDictionary("case_when", + {MakeStruct({Datum(true), Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(true)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", {MakeStruct({Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); +} + +TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]"); + auto dict1 = R"(["a", null, "bc", "def"])"; + auto dict2 = R"(["bc", "foo", null, "a"])"; + auto dict3 = R"(["def", null, "a", "bc"])"; + auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); + auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); + auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1); + auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2); + auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3); + + CheckDictionary("case_when", + {MakeStruct({Datum(true), Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(true)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values2, values1}); + + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}); + + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, false, false, true]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[true, false, true, false]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}); + CheckDictionary( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), + DictScalarFromJSON(type, "0", dict1), + DictScalarFromJSON(type, "0", dict2), + }); + CheckDictionary( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[false, false, true, true]")}), + DictScalarFromJSON(type, "0", dict1), + DictScalarFromJSON(type, "0", dict2), + }); + CheckDictionary( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[false, false, true, true]")}), + DictScalarFromJSON(type, "null", dict1), + DictScalarFromJSON(type, "0", dict2), + }); +} + TEST(TestCaseWhen, Null) { auto cond_true = ScalarFromJSON(boolean(), "true"); auto cond_false = ScalarFromJSON(boolean(), "false"); @@ -1489,6 +1670,18 @@ TEST(TestCaseWhen, DispatchBest) { CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")}), ArrayFromJSON(int64(), "[]"), ArrayFromJSON(utf8(), "[]")})); + + // Do not dictionary-decode when we have only dictionary values + CheckDispatchBest("case_when", + {struct_({field("", boolean())}), dictionary(int64(), utf8()), + dictionary(int64(), utf8())}, + {struct_({field("", boolean())}), dictionary(int64(), utf8()), + dictionary(int64(), utf8())}); + + // Dictionary-decode if we have a mix + CheckDispatchBest( + "case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()}, + {struct_({field("", boolean())}), utf8(), utf8()}); } template diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 4a9215101b1..cedc03698a1 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -24,6 +24,7 @@ #include "arrow/array.h" #include "arrow/array/validate.h" #include "arrow/chunked_array.h" +#include "arrow/compute/cast.h" #include "arrow/compute/exec.h" #include "arrow/compute/function.h" #include "arrow/compute/registry.h" @@ -46,13 +47,6 @@ DatumVector GetDatums(const std::vector& inputs) { return datums; } -void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, - const Datum& expected, const FunctionOptions* options) { - ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); - ValidateOutput(out); - AssertDatumsEqual(expected, out, /*verbose=*/true); -} - template DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) { DatumVector sliced; @@ -80,6 +74,13 @@ ScalarVector GetScalars(const DatumVector& inputs, int64_t index) { } // namespace +void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, + const Datum& expected, const FunctionOptions* options) { + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); + ValidateOutput(out); + AssertDatumsEqual(expected, out, /*verbose=*/true); +} + void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options) { ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options)); @@ -170,6 +171,83 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte } } +Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args, + bool result_is_encoded) { + EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args)); + ValidateOutput(actual); + + DatumVector decoded_args; + decoded_args.reserve(args.size()); + for (const auto& arg : args) { + if (arg.type()->id() == Type::DICTIONARY) { + const auto& to_type = checked_cast(*arg.type()).value_type(); + EXPECT_OK_AND_ASSIGN(auto decoded, Cast(arg, to_type)); + decoded_args.push_back(decoded); + } else { + decoded_args.push_back(arg); + } + } + EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args)); + + if (result_is_encoded) { + EXPECT_EQ(Type::DICTIONARY, actual.type()->id()) + << "Result should have been dictionary-encoded"; + // Decode before comparison - we care about equivalent not identical results + const auto& to_type = + checked_cast(*actual.type()).value_type(); + EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type)); + AssertDatumsApproxEqual(expected, decoded, /*verbose=*/true); + } else { + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } + return actual; +} + +void CheckDictionary(const std::string& func_name, const DatumVector& args, + bool result_is_encoded) { + auto actual = CheckDictionaryNonRecursive(func_name, args, result_is_encoded); + + if (actual.is_scalar()) return; + ASSERT_TRUE(actual.is_array()); + ASSERT_GE(actual.length(), 0); + + // Check all scalars + for (int64_t i = 0; i < actual.length(); i++) { + CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)), + result_is_encoded); + } + + // Check slices of the input + const auto slice_length = actual.length() / 3; + if (slice_length > 0) { + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length), + result_is_encoded); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length), + result_is_encoded); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length), + result_is_encoded); + } + + // Check empty slice + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0), result_is_encoded); + + // Check chunked arrays + if (slice_length > 0) { + DatumVector chunked_args; + chunked_args.reserve(args.size()); + for (const auto& arg : args) { + if (arg.is_array()) { + auto arr = arg.make_array(); + ArrayVector chunks{arr->Slice(0, slice_length), arr->Slice(slice_length)}; + chunked_args.push_back(std::make_shared(std::move(chunks))); + } else { + chunked_args.push_back(arg); + } + } + CheckDictionaryNonRecursive(func_name, chunked_args, result_is_encoded); + } +} + void CheckScalarUnary(std::string func_name, Datum input, Datum expected, const FunctionOptions* options) { std::vector input_vector = {std::move(input)}; diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index 79745b05552..25ea577a423 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -67,6 +67,8 @@ inline std::string CompareOperatorToFunctionName(CompareOperator op) { return function_names[op]; } +// Call the function with the given arguments, as well as slices of +// the arguments and scalars extracted from the arguments. void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options = nullptr); @@ -74,6 +76,19 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs, void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected, const FunctionOptions* options = nullptr); +// Like CheckScalar, but gets the expected result by +// dictionary-decoding arguments and calling the function again. +// +// result_is_encoded controls whether the result is expected to be a +// dictionary or not. +void CheckDictionary(const std::string& func_name, const DatumVector& args, + bool result_is_encoded = true); + +// Just call the function with the given arguments. +void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, + const Datum& expected, + const FunctionOptions* options = nullptr); + void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, std::string json_input, std::shared_ptr out_ty, std::string json_expected, diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index 34b0f3fba59..8347b871b1f 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -969,6 +969,25 @@ Status ScalarFromJSON(const std::shared_ptr& type, return Status::OK(); } +Status DictScalarFromJSON(const std::shared_ptr& type, + util::string_view index_json, util::string_view dictionary_json, + std::shared_ptr* out) { + if (type->id() != Type::DICTIONARY) { + return Status::TypeError("DictScalarFromJSON requires dictionary type, got ", *type); + } + + const auto& dictionary_type = checked_cast(*type); + + std::shared_ptr index; + std::shared_ptr dictionary; + RETURN_NOT_OK(ScalarFromJSON(dictionary_type.index_type(), index_json, &index)); + RETURN_NOT_OK( + ArrayFromJSON(dictionary_type.value_type(), dictionary_json, &dictionary)); + + *out = DictionaryScalar::Make(std::move(index), std::move(dictionary)); + return Status::OK(); +} + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/ipc/json_simple.h b/cpp/src/arrow/ipc/json_simple.h index 4dd3a664aa6..8269bd65326 100644 --- a/cpp/src/arrow/ipc/json_simple.h +++ b/cpp/src/arrow/ipc/json_simple.h @@ -55,6 +55,11 @@ ARROW_EXPORT Status ScalarFromJSON(const std::shared_ptr&, util::string_view json, std::shared_ptr* out); +ARROW_EXPORT +Status DictScalarFromJSON(const std::shared_ptr&, util::string_view index_json, + util::string_view dictionary_json, + std::shared_ptr* out); + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index ce2c37b7957..34c300faa95 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -1385,6 +1385,30 @@ TEST(TestScalarFromJSON, Errors) { ASSERT_RAISES(Invalid, ScalarFromJSON(boolean(), "\"true\"", &scalar)); } +TEST(TestDictScalarFromJSON, Basics) { + auto type = dictionary(int32(), utf8()); + auto dict = R"(["whiskey", "tango", "foxtrot"])"; + auto expected_dictionary = ArrayFromJSON(utf8(), dict); + + for (auto index : {"null", "2", "1", "0"}) { + auto scalar = DictScalarFromJSON(type, index, dict); + auto expected_index = ScalarFromJSON(int32(), index); + AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary), + *scalar, /*verbose=*/true); + ASSERT_OK(scalar->ValidateFull()); + } +} + +TEST(TestDictScalarFromJSON, Errors) { + auto type = dictionary(int32(), utf8()); + std::shared_ptr scalar; + + ASSERT_RAISES(Invalid, + DictScalarFromJSON(type, "\"not a valid index\"", "[\"\"]", &scalar)); + ASSERT_RAISES(Invalid, DictScalarFromJSON(type, "0", "[1]", + &scalar)); // dict value isn't string +} + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 60ba54f82cc..adfc50182cb 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -599,8 +599,9 @@ Result> DictionaryScalar::GetEncodedValue() const { std::shared_ptr DictionaryScalar::Make(std::shared_ptr index, std::shared_ptr dict) { auto type = dictionary(index->type, dict->type()); + auto is_valid = index->is_valid; return std::make_shared(ValueType{std::move(index), std::move(dict)}, - std::move(type)); + std::move(type), is_valid); } namespace { diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 587154c1f30..24f5edcc6cb 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -446,6 +446,15 @@ std::shared_ptr ScalarFromJSON(const std::shared_ptr& type, return out; } +std::shared_ptr DictScalarFromJSON(const std::shared_ptr& type, + util::string_view index_json, + util::string_view dictionary_json) { + std::shared_ptr out; + ABORT_NOT_OK( + ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out)); + return out; +} + std::shared_ptr TableFromJSON(const std::shared_ptr& schema, const std::vector& json) { std::vector> batches; diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index f0021e05603..65ab33c5d1f 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -338,6 +338,11 @@ ARROW_TESTING_EXPORT std::shared_ptr ScalarFromJSON(const std::shared_ptr&, util::string_view json); +ARROW_TESTING_EXPORT +std::shared_ptr DictScalarFromJSON(const std::shared_ptr&, + util::string_view index_json, + util::string_view dictionary_json); + ARROW_TESTING_EXPORT std::shared_ptr
TableFromJSON(const std::shared_ptr&, const std::vector& json);