diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc index 13c0d599bf9..f13aa26d969 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -36,14 +36,22 @@ Status CastToDictionary(KernelContext* ctx, const ExecSpan& batch, ExecResult* o const CastOptions& options = CastState::Get(ctx); const auto& out_type = checked_cast(*out->type()); + std::shared_ptr in_array = batch[0].array.ToArrayData(); + // if out type is same as in type, return input if (out_type.Equals(*batch[0].type())) { /// XXX: This is the wrong place to do a zero-copy optimization - out->value = batch[0].array.ToArrayData(); + out->value = in_array; return Status::OK(); } - std::shared_ptr in_array = batch[0].array.ToArrayData(); + // If the input type is STRING, it is first encoded as a dictionary to facilitate + // processing. This approach allows the subsequent code to uniformly handle STRING + // inputs as if they were originally provided in dictionary format. Encoding as a + // dictionary helps in reusing the same logic for dictionary operations. + if (batch[0].type()->id() == Type::STRING) { + in_array = DictionaryEncode(in_array)->array(); + } const auto& in_type = checked_cast(*in_array->type); ArrayData* out_array = out->array_data().get(); @@ -77,17 +85,21 @@ Status CastToDictionary(KernelContext* ctx, const ExecSpan& batch, ExecResult* o return Status::OK(); } -std::vector> GetDictionaryCasts() { - auto func = std::make_shared("cast_dictionary", Type::DICTIONARY); - - AddCommonCasts(Type::DICTIONARY, kOutputTargetType, func.get()); - ScalarKernel kernel({InputType(Type::DICTIONARY)}, kOutputTargetType, CastToDictionary); +template +void AddDictionaryCast(CastFunction* func) { + ScalarKernel kernel({InputType(SrcType::type_id)}, kOutputTargetType, CastToDictionary); kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel))); +} - DCHECK_OK(func->AddKernel(Type::DICTIONARY, std::move(kernel))); +std::vector> GetDictionaryCasts() { + auto cast_dict = std::make_shared("cast_dictionary", Type::DICTIONARY); + AddCommonCasts(Type::DICTIONARY, kOutputTargetType, cast_dict.get()); + AddDictionaryCast(cast_dict.get()); + AddDictionaryCast(cast_dict.get()); - return {func}; + return {cast_dict}; } } // namespace internal diff --git a/cpp/src/arrow/dataset/partition_test.cc b/cpp/src/arrow/dataset/partition_test.cc index 7ec96929a93..1b71be15d19 100644 --- a/cpp/src/arrow/dataset/partition_test.cc +++ b/cpp/src/arrow/dataset/partition_test.cc @@ -316,7 +316,7 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormatDictionary) { ArrayVector{dictionary}); written_schema_ = partitioning_->schema(); - ASSERT_OK_AND_ASSIGN(auto dict_hello, MakeScalar("hello")->CastTo(DictStr("")->type())); + ASSERT_OK_AND_ASSIGN(auto dict_hello, Cast(MakeScalar("hello"), DictStr("")->type())); AssertFormat(equal(field_ref("alpha"), literal(dict_hello)), "hello"); } @@ -329,7 +329,7 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormatDictionaryCustomIndex) { schema({field("alpha", dict_type)}), ArrayVector{dictionary}); written_schema_ = partitioning_->schema(); - ASSERT_OK_AND_ASSIGN(auto dict_hello, MakeScalar("hello")->CastTo(dict_type)); + ASSERT_OK_AND_ASSIGN(auto dict_hello, Cast(MakeScalar("hello"), dict_type)); AssertFormat(equal(field_ref("alpha"), literal(dict_hello)), "hello"); } diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index ac740f92c85..c6b81eb46b4 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -1479,7 +1479,8 @@ TEST(TestDictionaryScalar, Cast) { auto alpha = dict->IsValid(i) ? MakeScalar(dict->GetString(i)) : MakeNullScalar(utf8()); // Cast string to dict(..., string) - ASSERT_OK_AND_ASSIGN(auto cast_alpha, alpha->CastTo(ty)); + ASSERT_OK_AND_ASSIGN(auto cast_alpha_datum, Cast(alpha, ty)); + const auto& cast_alpha = cast_alpha_datum.scalar(); ASSERT_OK(cast_alpha->ValidateFull()); ASSERT_OK_AND_ASSIGN( auto roundtripped_alpha,