Skip to content
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_arithmetic.cc
compute/kernels/scalar_boolean.cc
compute/kernels/scalar_cast_boolean.cc
compute/kernels/scalar_cast_dictionary.cc
compute/kernels/scalar_cast_internal.cc
compute/kernels/scalar_cast_nested.cc
compute/kernels/scalar_cast_numeric.cc
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ void InitCastTable() {
AddCastFunctions(GetNestedCasts());
AddCastFunctions(GetNumericCasts());
AddCastFunctions(GetTemporalCasts());
AddCastFunctions(GetDictionaryCasts());
}

// Runs InitCastTable through std::call_once keyed on cast_table_initialized, so the
// table is populated on the first call and later calls do not run it again.
void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); }
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/cast_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ std::vector<std::shared_ptr<CastFunction>> GetNumericCasts();
std::vector<std::shared_ptr<CastFunction>> GetTemporalCasts();
std::vector<std::shared_ptr<CastFunction>> GetBinaryLikeCasts();
std::vector<std::shared_ptr<CastFunction>> GetNestedCasts();
std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts();

} // namespace internal
} // namespace compute
Expand Down
17 changes: 8 additions & 9 deletions cpp/src/arrow/compute/exec/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -939,17 +939,16 @@ TEST(Expression, ReplaceFieldsWithKnownValues) {
ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null,
is_valid(null_literal(utf8())));

ASSERT_OK_AND_ASSIGN(auto expr, field_ref("dict_str").Bind(*kBoringSchema));
Datum dict_i32{
DictionaryScalar::Make(MakeScalar<int32_t>(0), ArrayFromJSON(int32(), R"([3])"))};
// Unsupported cast dictionary(int32(), int32()) -> dictionary(int32(), utf8())
ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues(
KnownFieldValues{{{"dict_str", dict_i32}}}, expr));
// Unsupported cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8())
dict_str = Datum{
DictionaryScalar::Make(MakeScalar<int8_t>(0), ArrayFromJSON(utf8(), R"(["a"])"))};
ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues(
KnownFieldValues{{{"dict_str", dict_str}}}, expr));
// cast dictionary(int32(), int32()) -> dictionary(int32(), utf8())
ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_i32}}, literal(dict_str));

// cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8())
auto dict_int8_str = Datum{
DictionaryScalar::Make(MakeScalar<int8_t>(0), ArrayFromJSON(utf8(), R"(["3"])"))};
ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_int8_str}},
literal(dict_str));
}

struct {
Expand Down
126 changes: 126 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Implementation of casting to dictionary type

#include <arrow/util/bitmap_ops.h>
#include <arrow/util/checked_cast.h>

#include "arrow/array/builder_primitive.h"
#include "arrow/compute/cast_internal.h"
#include "arrow/compute/kernels/scalar_cast_internal.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/util/int_util.h"

namespace arrow {
using internal::CopyBitmap;

namespace compute {
namespace internal {

/// \brief Cast kernel converting one dictionary type into another dictionary type.
///
/// Indices and dictionary values are cast independently, and each only when its
/// type differs from the requested output type; a component whose type already
/// matches is shared with the input instead of being cast.
///
/// \param[in] ctx kernel context; supplies the CastOptions (through CastState) and
///     the ExecContext used for the nested Cast calls
/// \param[in] batch batch whose first value is the dictionary scalar or array to cast
/// \param[in,out] out holds the target dictionary type on entry and receives the
///     cast scalar or array
/// \return Status::OK() on success, otherwise the error of the failing nested Cast
Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
  const CastOptions& options = CastState::Get(ctx);
  auto out_type = std::static_pointer_cast<DictionaryType>(out->type());

  // Same input and output type: return the input unchanged.
  if (out_type->Equals(batch[0].type())) {
    *out = batch[0];
    return Status::OK();
  }

  if (batch[0].is_scalar()) {
    // Bind by reference; a plain `auto` here would copy the whole DictionaryScalar.
    const auto& in_scalar = checked_cast<const DictionaryScalar&>(*batch[0].scalar());

    // An invalid (null) scalar maps to a null scalar of the output type.
    if (!in_scalar.is_valid) {
      *out = MakeNullScalar(out_type);
      return Status::OK();
    }

    Datum casted_index, casted_dict;
    if (in_scalar.value.index->type->Equals(out_type->index_type())) {
      casted_index = in_scalar.value.index;
    } else {
      ARROW_ASSIGN_OR_RAISE(casted_index,
                            Cast(in_scalar.value.index, out_type->index_type(), options,
                                 ctx->exec_context()));
    }

    if (in_scalar.value.dictionary->type()->Equals(out_type->value_type())) {
      casted_dict = in_scalar.value.dictionary;
    } else {
      ARROW_ASSIGN_OR_RAISE(
          casted_dict, Cast(in_scalar.value.dictionary, out_type->value_type(), options,
                            ctx->exec_context()));
    }

    *out = std::static_pointer_cast<Scalar>(
        DictionaryScalar::Make(casted_index.scalar(), casted_dict.make_array()));

    return Status::OK();
  }

  // Array input.
  const std::shared_ptr<ArrayData>& in_array = batch[0].array();
  const auto& in_type = checked_cast<const DictionaryType&>(*in_array->type);

  ArrayData* out_array = out->mutable_array();

  // Indices.
  if (in_type.index_type()->Equals(out_type->index_type())) {
    out_array->buffers[0] = in_array->buffers[0];
    out_array->buffers[1] = in_array->buffers[1];
    out_array->null_count = in_array->GetNullCount();
    out_array->offset = in_array->offset;
  } else {
    // View the input's buffers as a plain array of index_type() so that the
    // regular cast machinery can convert the indices.
    std::shared_ptr<ArrayData> indices_arr =
        ArrayData::Make(in_type.index_type(), in_array->length, in_array->buffers,
                        in_array->GetNullCount(), in_array->offset);
    ARROW_ASSIGN_OR_RAISE(auto casted_indices, Cast(indices_arr, out_type->index_type(),
                                                    options, ctx->exec_context()));
    const std::shared_ptr<ArrayData>& casted_index_data = casted_indices.array();
    out_array->buffers[0] = casted_index_data->buffers[0];
    out_array->buffers[1] = casted_index_data->buffers[1];
    // The buffers are only meaningful together with the offset and null count of
    // the array they came from (the cast result may carry its own offset), so
    // take those from the cast result as well.
    out_array->null_count = casted_index_data->GetNullCount();
    out_array->offset = casted_index_data->offset;
  }

  // Dictionary values.
  if (in_type.value_type()->Equals(out_type->value_type())) {
    out_array->dictionary = in_array->dictionary;
  } else {
    std::shared_ptr<Array> dict_arr = MakeArray(in_array->dictionary);
    ARROW_ASSIGN_OR_RAISE(auto casted_dict_values,
                          Cast(dict_arr, out_type->value_type(), options,
                               ctx->exec_context()));
    out_array->dictionary = casted_dict_values.array();
  }
  return Status::OK();
}

std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts() {
auto func = std::make_shared<CastFunction>("cast_dictionary", Type::DICTIONARY);

AddCommonCasts(Type::DICTIONARY, kOutputTargetType, func.get());
ScalarKernel kernel({InputType(Type::DICTIONARY)}, kOutputTargetType, CastDictionary);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;

DCHECK_OK(func->AddKernel(Type::DICTIONARY, std::move(kernel)));

return {func};
}

} // namespace internal
} // namespace compute
} // namespace arrow
34 changes: 34 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1911,5 +1911,39 @@ TEST(Cast, ExtensionTypeToIntDowncast) {
}
}

TEST(Cast, DictTypeToAnotherDict) {
  // Parses `json` as `from_type` and expects casting it to `to_type` to give the
  // same JSON parsed as `to_type` (or the input itself when the types are equal).
  auto expect_cast = [](const std::shared_ptr<DataType>& from_type,
                        const std::shared_ptr<DataType>& to_type,
                        const std::string& json,
                        const CastOptions& cast_options = CastOptions()) {
    auto input = ArrayFromJSON(from_type, json);
    auto expected = input;
    if (!from_type->Equals(to_type)) {
      expected = ArrayFromJSON(to_type, json);
    }
    // CheckCast covers the scalar path as well
    CheckCast(input, expected, cast_options);
  };

  const std::string int_values = "[1, 2, 3, 1, null, 3]";

  // identical input and output type
  expect_cast(dictionary(int8(), int16()), dictionary(int8(), int16()), int_values);
  // both index and value type differ
  expect_cast(dictionary(int8(), int16()), dictionary(int32(), int64()), int_values);
  expect_cast(dictionary(int8(), int16()), dictionary(int32(), float64()), int_values);
  // only the index type differs
  expect_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()),
              R"(["a", "b", "a", null])");

  auto wide_values = ArrayFromJSON(dictionary(int32(), int32()), "[1, 1000]");
  auto narrow_type = dictionary(int8(), int8());

  // Unsafe cast of out-of-range values. Out-of-range indices are not tested: they
  // would produce an invalid index array, which ValidateOutput rejects.
  ASSERT_OK_AND_ASSIGN(auto unsafe_result,
                       Cast(wide_values, narrow_type, CastOptions::Unsafe()));
  ValidateOutput(unsafe_result);

  // Safe cast must reject the out-of-range value
  EXPECT_RAISES_WITH_MESSAGE_THAT(
      Invalid, testing::HasSubstr("Integer value 1000 not in range"),
      Cast(wide_values, narrow_type, CastOptions::Safe()));
}

} // namespace compute
} // namespace arrow
24 changes: 19 additions & 5 deletions python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,12 +1524,26 @@ def test_cast_string_to_number_roundtrip():


def test_cast_dictionary():
arr = pa.DictionaryArray.from_arrays(
pa.array([0, 1, None], type=pa.int32()),
pa.array(["foo", "bar"]))
assert arr.cast(pa.string()).equals(pa.array(["foo", "bar", None]))
# cast to the value type
arr = pa.array(
["foo", "bar", None],
type=pa.dictionary(pa.int64(), pa.string())
)
expected = pa.array(["foo", "bar", None])
assert arr.type == pa.dictionary(pa.int64(), pa.string())
assert arr.cast(pa.string()) == expected

# cast to a different key type
for key_type in [pa.int8(), pa.int16(), pa.int32()]:
typ = pa.dictionary(key_type, pa.string())
expected = pa.array(
["foo", "bar", None],
type=pa.dictionary(key_type, pa.string())
)
assert arr.cast(typ) == expected

# shouldn't crash (ARROW-7077)
with pytest.raises(pa.ArrowInvalid):
# Shouldn't crash (ARROW-7077)
arr.cast(pa.int32())


Expand Down