Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,22 @@ Status CastToDictionary(KernelContext* ctx, const ExecSpan& batch, ExecResult* o
const CastOptions& options = CastState::Get(ctx);
const auto& out_type = checked_cast<const DictionaryType&>(*out->type());

std::shared_ptr<ArrayData> in_array = batch[0].array.ToArrayData();

// if out type is same as in type, return input
if (out_type.Equals(*batch[0].type())) {
/// XXX: This is the wrong place to do a zero-copy optimization
out->value = batch[0].array.ToArrayData();
out->value = in_array;
return Status::OK();
}

std::shared_ptr<ArrayData> in_array = batch[0].array.ToArrayData();
// If the input type is STRING, it is first encoded as a dictionary to facilitate
// processing. This approach allows the subsequent code to uniformly handle STRING
// inputs as if they were originally provided in dictionary format. Encoding as a
// dictionary helps in reusing the same logic for dictionary operations.
if (batch[0].type()->id() == Type::STRING) {
in_array = DictionaryEncode(in_array)->array();
}
const auto& in_type = checked_cast<const DictionaryType&>(*in_array->type);

ArrayData* out_array = out->array_data().get();
Expand Down Expand Up @@ -77,17 +85,21 @@ Status CastToDictionary(KernelContext* ctx, const ExecSpan& batch, ExecResult* o
return Status::OK();
}

std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts() {
auto func = std::make_shared<CastFunction>("cast_dictionary", Type::DICTIONARY);

AddCommonCasts(Type::DICTIONARY, kOutputTargetType, func.get());
ScalarKernel kernel({InputType(Type::DICTIONARY)}, kOutputTargetType, CastToDictionary);
template <typename SrcType>
void AddDictionaryCast(CastFunction* func) {
ScalarKernel kernel({InputType(SrcType::type_id)}, kOutputTargetType, CastToDictionary);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
DCHECK_OK(func->AddKernel(SrcType::type_id, std::move(kernel)));
}

DCHECK_OK(func->AddKernel(Type::DICTIONARY, std::move(kernel)));
std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts() {
auto cast_dict = std::make_shared<CastFunction>("cast_dictionary", Type::DICTIONARY);
AddCommonCasts(Type::DICTIONARY, kOutputTargetType, cast_dict.get());
AddDictionaryCast<DictionaryType>(cast_dict.get());
AddDictionaryCast<StringType>(cast_dict.get());

return {func};
return {cast_dict};
}

} // namespace internal
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/dataset/partition_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormatDictionary) {
ArrayVector{dictionary});
written_schema_ = partitioning_->schema();

ASSERT_OK_AND_ASSIGN(auto dict_hello, MakeScalar("hello")->CastTo(DictStr("")->type()));
ASSERT_OK_AND_ASSIGN(auto dict_hello, Cast(MakeScalar("hello"), DictStr("")->type()));
AssertFormat(equal(field_ref("alpha"), literal(dict_hello)), "hello");
}

Expand All @@ -329,7 +329,7 @@ TEST_F(TestPartitioning, DirectoryPartitioningFormatDictionaryCustomIndex) {
schema({field("alpha", dict_type)}), ArrayVector{dictionary});
written_schema_ = partitioning_->schema();

ASSERT_OK_AND_ASSIGN(auto dict_hello, MakeScalar("hello")->CastTo(dict_type));
ASSERT_OK_AND_ASSIGN(auto dict_hello, Cast(MakeScalar("hello"), dict_type));
AssertFormat(equal(field_ref("alpha"), literal(dict_hello)), "hello");
}

Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,8 @@ TEST(TestDictionaryScalar, Cast) {
auto alpha =
dict->IsValid(i) ? MakeScalar(dict->GetString(i)) : MakeNullScalar(utf8());
// Cast string to dict(..., string)
ASSERT_OK_AND_ASSIGN(auto cast_alpha, alpha->CastTo(ty));
ASSERT_OK_AND_ASSIGN(auto cast_alpha_datum, Cast(alpha, ty));
const auto& cast_alpha = cast_alpha_datum.scalar();
ASSERT_OK(cast_alpha->ValidateFull());
ASSERT_OK_AND_ASSIGN(
auto roundtripped_alpha,
Expand Down