diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 88a92d8c2c9..d2f80ce7213 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -385,6 +385,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_arithmetic.cc compute/kernels/scalar_boolean.cc compute/kernels/scalar_cast_boolean.cc + compute/kernels/scalar_cast_dictionary.cc compute/kernels/scalar_cast_internal.cc compute/kernels/scalar_cast_nested.cc compute/kernels/scalar_cast_numeric.cc diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 521f217213d..4de68ba8d90 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -61,6 +61,7 @@ void InitCastTable() { AddCastFunctions(GetNestedCasts()); AddCastFunctions(GetNumericCasts()); AddCastFunctions(GetTemporalCasts()); + AddCastFunctions(GetDictionaryCasts()); } void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } diff --git a/cpp/src/arrow/compute/cast_internal.h b/cpp/src/arrow/compute/cast_internal.h index c152d10bd86..0105d08a573 100644 --- a/cpp/src/arrow/compute/cast_internal.h +++ b/cpp/src/arrow/compute/cast_internal.h @@ -36,6 +36,7 @@ std::vector> GetNumericCasts(); std::vector> GetTemporalCasts(); std::vector> GetBinaryLikeCasts(); std::vector> GetNestedCasts(); +std::vector> GetDictionaryCasts(); } // namespace internal } // namespace compute diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 86909f4eb64..b59f8762818 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -939,17 +939,16 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null, is_valid(null_literal(utf8()))); - ASSERT_OK_AND_ASSIGN(auto expr, field_ref("dict_str").Bind(*kBoringSchema)); Datum dict_i32{ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(int32(), R"([3])"))}; - // 
Unsupported cast dictionary(int32(), int32()) -> dictionary(int32(), utf8()) - ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( - KnownFieldValues{{{"dict_str", dict_i32}}}, expr)); - // Unsupported cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8()) - dict_str = Datum{ - DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(utf8(), R"(["a"])"))}; - ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( - KnownFieldValues{{{"dict_str", dict_str}}}, expr)); + // cast dictionary(int32(), int32()) -> dictionary(int32(), utf8()) + ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_i32}}, literal(dict_str)); + + // cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8()) + auto dict_int8_str = Datum{ + DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(utf8(), R"(["3"])"))}; + ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_int8_str}}, + literal(dict_str)); } struct { diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc new file mode 100644 index 00000000000..b1e1164fd34 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Implementation of casting to dictionary type
+
+// NOTE(review): the two angle-bracket include targets were missing from this copy of
+// the patch; they are restored from the names used below (CopyBitmap, checked_cast)
+// -- confirm against the upstream file.
+#include <arrow/util/bitmap_ops.h>
+#include <arrow/util/checked_cast.h>
+
+#include "arrow/array/builder_primitive.h"
+#include "arrow/compute/cast_internal.h"
+#include "arrow/compute/kernels/scalar_cast_internal.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+using internal::CopyBitmap;
+
+namespace compute {
+namespace internal {
+
+// Cast a dictionary-typed input (scalar or array) to another dictionary type.
+//
+// The target type is read from `out->type()`. Indices and dictionary values are
+// handled independently: a side whose type already matches the target is shared
+// with the input, the other side goes through a nested Cast() that uses the
+// caller's CastOptions. A failure of a nested cast is returned from here.
+// When the input type equals the target type the input is handed back as-is.
+Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const CastOptions& options = CastState::Get(ctx);
+  auto out_type = std::static_pointer_cast<DictionaryType>(out->type());
+
+  // if out type is same as in type, return input
+  if (out_type->Equals(batch[0].type())) {
+    *out = batch[0];
+    return Status::OK();
+  }
+
+  if (batch[0].is_scalar()) {  // if input is scalar
+    // Bind by reference: a plain `auto` would drop the reference and copy the scalar.
+    const auto& in_scalar = checked_cast<const DictionaryScalar&>(*batch[0].scalar());
+
+    // if invalid scalar, return null scalar
+    if (!in_scalar.is_valid) {
+      *out = MakeNullScalar(out_type);
+      return Status::OK();
+    }
+
+    Datum casted_index;
+    if (in_scalar.value.index->type->Equals(out_type->index_type())) {
+      casted_index = in_scalar.value.index;
+    } else {
+      ARROW_ASSIGN_OR_RAISE(casted_index,
+                            Cast(in_scalar.value.index, out_type->index_type(), options,
+                                 ctx->exec_context()));
+    }
+
+    Datum casted_dict;
+    if (in_scalar.value.dictionary->type()->Equals(out_type->value_type())) {
+      casted_dict = in_scalar.value.dictionary;
+    } else {
+      ARROW_ASSIGN_OR_RAISE(
+          casted_dict, Cast(in_scalar.value.dictionary, out_type->value_type(), options,
+                            ctx->exec_context()));
+    }
+
+    *out = std::static_pointer_cast<Scalar>(
+        DictionaryScalar::Make(casted_index.scalar(), casted_dict.make_array()));
+
+    return Status::OK();
+  }
+
+  // if input is array
+  const std::shared_ptr<ArrayData>& in_array = batch[0].array();
+  const auto& in_type = checked_cast<const DictionaryType&>(*in_array->type);
+
+  ArrayData* out_array = out->mutable_array();
+
+  if (in_type.index_type()->Equals(out_type->index_type())) {
+    // Same index type: share the input's validity and index buffers.
+    out_array->buffers[0] = in_array->buffers[0];
+    out_array->buffers[1] = in_array->buffers[1];
+    out_array->null_count = in_array->GetNullCount();
+    out_array->offset = in_array->offset;
+  } else {
+    // Indices need converting: describe the input's buffers as a plain index array
+    // (same length, null count and offset) and run it through the regular cast.
+    const std::shared_ptr<ArrayData> indices_arr =
+        ArrayData::Make(in_type.index_type(), in_array->length, in_array->buffers,
+                        in_array->GetNullCount(), in_array->offset);
+    ARROW_ASSIGN_OR_RAISE(auto casted_indices, Cast(indices_arr, out_type->index_type(),
+                                                    options, ctx->exec_context()));
+    const std::shared_ptr<ArrayData>& casted = casted_indices.array();
+    // The buffers come from the cast result, so the offset and null count must
+    // describe that result as well. Read them before the buffers are moved out.
+    out_array->null_count = casted->GetNullCount();
+    out_array->offset = casted->offset;
+    out_array->buffers[0] = std::move(casted->buffers[0]);
+    out_array->buffers[1] = std::move(casted->buffers[1]);
+  }
+
+  // data (dict)
+  if (in_type.value_type()->Equals(out_type->value_type())) {
+    out_array->dictionary = in_array->dictionary;
+  } else {
+    const std::shared_ptr<Array> dict_arr = MakeArray(in_array->dictionary);
+    ARROW_ASSIGN_OR_RAISE(auto casted_data, Cast(dict_arr, out_type->value_type(),
+                                                 options, ctx->exec_context()));
+    out_array->dictionary = casted_data.array();
+  }
+  return Status::OK();
+}
+
+// Build the "cast_dictionary" cast function: the common casts plus a kernel that
+// accepts any DICTIONARY input and dispatches to CastDictionary.
+//
+// Preallocation is switched off for both validity and data because CastDictionary
+// assigns the output buffers itself.
+std::vector<std::shared_ptr<CastFunction>> GetDictionaryCasts() {
+  auto func = std::make_shared<CastFunction>("cast_dictionary", Type::DICTIONARY);
+
+  AddCommonCasts(Type::DICTIONARY, kOutputTargetType, func.get());
+  ScalarKernel kernel({InputType(Type::DICTIONARY)}, kOutputTargetType, CastDictionary);
+  kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+  kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+
+  DCHECK_OK(func->AddKernel(Type::DICTIONARY, std::move(kernel)));
+
+  return {func};
+}
+
+}  // namespace internal
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
index 494b15dfbc8..fce8518dd3b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -1911,5 +1911,39 @@
TEST(Cast, ExtensionTypeToIntDowncast) {
   }
 }
 
+TEST(Cast, DictTypeToAnotherDict) {
+  auto check_cast = [&](const std::shared_ptr<DataType>& in_type,
+                        const std::shared_ptr<DataType>& out_type,
+                        const std::string& json_str,
+                        const CastOptions& options = CastOptions()) {
+    auto arr = ArrayFromJSON(in_type, json_str);
+    auto exp = in_type->Equals(out_type) ? arr : ArrayFromJSON(out_type, json_str);
+    // this checks for scalars as well
+    CheckCast(arr, exp, options);
+  };
+
+  // identical in/out types first, then casts that change the index and/or value type
+  check_cast(dictionary(int8(), int16()), dictionary(int8(), int16()),
+             "[1, 2, 3, 1, null, 3]");
+  check_cast(dictionary(int8(), int16()), dictionary(int32(), int64()),
+             "[1, 2, 3, 1, null, 3]");
+  check_cast(dictionary(int8(), int16()), dictionary(int32(), float64()),
+             "[1, 2, 3, 1, null, 3]");
+  check_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()),
+             R"(["a", "b", "a", null])");
+
+  auto arr = ArrayFromJSON(dictionary(int32(), int32()), "[1, 1000]");
+  // check casting unsafe values (checking for unsafe indices is unnecessary, because it
+  // would create an invalid index array which results in a ValidateOutput failure)
+  ASSERT_OK_AND_ASSIGN(auto casted,
+                       Cast(arr, dictionary(int8(), int8()), CastOptions::Unsafe()));
+  ValidateOutput(casted);
+
+  // check safe casting values
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("Integer value 1000 not in range"),
+      Cast(arr, dictionary(int8(), int8()), CastOptions::Safe()));
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index 57224ef4ebe..0f137383378 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -1524,12 +1524,26 @@ def test_cast_string_to_number_roundtrip():
 
 
 def test_cast_dictionary():
-    arr = pa.DictionaryArray.from_arrays(
-        pa.array([0, 1, None], type=pa.int32()),
-        pa.array(["foo", "bar"]))
-    assert arr.cast(pa.string()).equals(pa.array(["foo", "bar", None]))
+    # cast to the value type
+    arr = pa.array(
+        ["foo", "bar", None],
+        type=pa.dictionary(pa.int64(), pa.string())
+    )
+    expected = pa.array(["foo", "bar", None])
+    assert arr.type == pa.dictionary(pa.int64(), pa.string())
+    assert arr.cast(pa.string()) == expected
+
+    # cast to a different key type
+    for key_type in [pa.int8(), pa.int16(), pa.int32()]:
+        typ = pa.dictionary(key_type, pa.string())
+        expected = pa.array(
+            ["foo", "bar", None],
+            type=pa.dictionary(key_type, pa.string())
+        )
+        assert arr.cast(typ) == expected
+
+    # shouldn't crash (ARROW-7077)
     with pytest.raises(pa.ArrowInvalid):
-        # Shouldn't crash (ARROW-7077)
         arr.cast(pa.int32())