From 523c3b8f3b3a72ff9989075252dfa3b12d35cfbb Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 14 Jul 2021 17:28:23 -0400 Subject: [PATCH 01/14] add initial impl --- cpp/src/arrow/CMakeLists.txt | 3 +- cpp/src/arrow/compute/cast.cc | 1 + cpp/src/arrow/compute/cast_internal.h | 1 + .../compute/kernels/scalar_cast_dictionary.cc | 89 +++++++++++++++++++ .../arrow/compute/kernels/scalar_cast_test.cc | 11 +++ 5 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 88a92d8c2c9..d8650734d9d 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -385,7 +385,8 @@ if(ARROW_COMPUTE) compute/kernels/scalar_arithmetic.cc compute/kernels/scalar_boolean.cc compute/kernels/scalar_cast_boolean.cc - compute/kernels/scalar_cast_internal.cc + compute/kernels/scalar_cast_dictionary.cc + compute/kernels/scalar_cast_internal.cc compute/kernels/scalar_cast_nested.cc compute/kernels/scalar_cast_numeric.cc compute/kernels/scalar_cast_string.cc diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 521f217213d..4de68ba8d90 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -61,6 +61,7 @@ void InitCastTable() { AddCastFunctions(GetNestedCasts()); AddCastFunctions(GetNumericCasts()); AddCastFunctions(GetTemporalCasts()); + AddCastFunctions(GetDictionaryCasts()); } void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } diff --git a/cpp/src/arrow/compute/cast_internal.h b/cpp/src/arrow/compute/cast_internal.h index c152d10bd86..0105d08a573 100644 --- a/cpp/src/arrow/compute/cast_internal.h +++ b/cpp/src/arrow/compute/cast_internal.h @@ -36,6 +36,7 @@ std::vector> GetNumericCasts(); std::vector> GetTemporalCasts(); std::vector> GetBinaryLikeCasts(); std::vector> GetNestedCasts(); +std::vector> GetDictionaryCasts(); } // namespace internal } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc new file mode 100644 index 00000000000..db6997d6e8d --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Implementation of casting to dictionary type + +#include + +#include "arrow/array/builder_primitive.h" +#include "arrow/compute/cast_internal.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/util/int_util.h" + +namespace arrow { +using internal::CopyBitmap; + +namespace compute { +namespace internal { + +Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const CastOptions& options = CastState::Get(ctx); + auto out_type = std::static_pointer_cast(out->type()); + + if (batch[0].is_scalar()) { + auto in_scalar = std::static_pointer_cast(batch[0].scalar()); + + ARROW_ASSIGN_OR_RAISE(auto casted_index, + Cast(in_scalar->value.index, out_type->index_type(), options, + ctx->exec_context())); + ARROW_ASSIGN_OR_RAISE(auto casted_dict, + Cast(in_scalar->value.dictionary, out_type->value_type(), + options, ctx->exec_context())); + + *out = std::static_pointer_cast( + DictionaryScalar::Make(casted_index.scalar(), casted_dict.make_array())); + } else { + const std::shared_ptr& in_array = batch[0].array(); + auto in_type = std::static_pointer_cast(in_array->type); + + ArrayData* out_array = out->mutable_array(); + + // for indices, create a dummy ArrayData with index_type() + const std::shared_ptr& indices_arr = + ArrayData::Make(in_type->index_type(), in_array->length, in_array->buffers, + in_array->GetNullCount(), in_array->offset); + ARROW_ASSIGN_OR_RAISE(auto casted_indices, Cast(indices_arr, out_type->index_type(), + options, ctx->exec_context())); + out_array->buffers[0] = std::move(casted_indices.array()->buffers[0]); + out_array->buffers[1] = std::move(casted_indices.array()->buffers[1]); + + // data + const std::shared_ptr& dict_arr = MakeArray(in_array->dictionary); + ARROW_ASSIGN_OR_RAISE(auto casted_data, Cast(dict_arr, out_type->value_type(), + options, ctx->exec_context())); + out_array->dictionary = casted_data.array(); + } + return Status::OK(); +} + +std::vector> GetDictionaryCasts() { + auto func = std::make_shared("cast_dictionary", Type::DICTIONARY); + + AddCommonCasts(Type::DICTIONARY, kOutputTargetType, func.get()); + ScalarKernel kernel({InputType(Type::DICTIONARY)}, kOutputTargetType, CastDictionary); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + + DCHECK_OK(func->AddKernel(Type::DICTIONARY, std::move(kernel))); + + return {func}; +} + +} // namespace internal +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 494b15dfbc8..fea9e41f687 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1911,5 +1911,16 @@ TEST(Cast, ExtensionTypeToIntDowncast) { } } +TEST(Cast, DictTypeToAnotherDict) { + const std::shared_ptr& arr = + ArrayFromJSON(dictionary(int8(), int8()), "[1, 2, 3, 1, null, 3]"); + + ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, dictionary(int32(), int32()))); + + std::cout << "input: " << arr->ToString() << std::endl; + std::cout << "res: " << casted.make_array()->ToString() << std::endl; + ValidateOutput(casted); +} + } // namespace compute } // namespace arrow From d450567157a0a939b7a9f4c8633b04954ec76701 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 14 Jul 2021 17:35:11 -0400 Subject: [PATCH 02/14] fixing non-int index_type --- cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc | 4 ++++ cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc index db6997d6e8d..f5ba6785bd9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -35,6 +35,10 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const CastOptions& options = CastState::Get(ctx); auto out_type = std::static_pointer_cast(out->type()); + if (is_integer(out_type->index_type()->id())){ + return Status::Invalid("non-integer type used for DictionaryType::index_type"); + } + if (batch[0].is_scalar()) { auto in_scalar = std::static_pointer_cast(batch[0].scalar()); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index fea9e41f687..73b282f9c1a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1917,8 +1917,8 @@ TEST(Cast, DictTypeToAnotherDict) { ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, dictionary(int32(), int32()))); - std::cout << "input: " << arr->ToString() << std::endl; - std::cout << "res: " << casted.make_array()->ToString() << std::endl; +// std::cout << "input: " << arr->ToString() << std::endl; +// std::cout << "res: " << casted.make_array()->ToString() << std::endl; ValidateOutput(casted); } From 35ab8e8279939bc65075036f117afb9b1e0d37a8 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 14 Jul 2021 18:01:24 -0400 Subject: [PATCH 03/14] cast only when types differ --- .../compute/kernels/scalar_cast_dictionary.cc | 54 ++++++++++++++----- .../arrow/compute/kernels/scalar_cast_test.cc | 4 +- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc index f5ba6785bd9..d68e97711df 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -35,28 +35,52 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const CastOptions& options = CastState::Get(ctx); auto out_type = std::static_pointer_cast(out->type()); - if (is_integer(out_type->index_type()->id())){ + if (!is_integer(out_type->index_type()->id())) { return Status::Invalid("non-integer type used for DictionaryType::index_type"); } - if (batch[0].is_scalar()) { + // if out type is same as in type, return input + if (out_type->Equals(batch[0].type())) { + *out = batch[0]; + return Status::OK(); + } + + if (batch[0].is_scalar()) { // if input is scalar auto in_scalar = std::static_pointer_cast(batch[0].scalar()); - ARROW_ASSIGN_OR_RAISE(auto casted_index, - Cast(in_scalar->value.index, out_type->index_type(), options, - ctx->exec_context())); - ARROW_ASSIGN_OR_RAISE(auto casted_dict, - Cast(in_scalar->value.dictionary, out_type->value_type(), - options, ctx->exec_context())); + Datum casted_index, casted_dict; + if (in_scalar->value.index->type->Equals(out_type->index_type())) { + casted_index = in_scalar->value.index; + } else { + ARROW_ASSIGN_OR_RAISE(casted_index, + Cast(in_scalar->value.index, out_type->index_type(), options, + ctx->exec_context())); + } + + if (in_scalar->value.dictionary->type()->Equals(out_type->value_type())) { + casted_dict = in_scalar->value.dictionary; + } else { + ARROW_ASSIGN_OR_RAISE( + casted_dict, Cast(in_scalar->value.dictionary, out_type->value_type(), options, + ctx->exec_context())); + } *out = std::static_pointer_cast( DictionaryScalar::Make(casted_index.scalar(), casted_dict.make_array())); - } else { - const std::shared_ptr& in_array = batch[0].array(); - auto in_type = std::static_pointer_cast(in_array->type); - ArrayData* out_array = out->mutable_array(); + return Status::OK(); + } + + // if input is array + const std::shared_ptr& in_array = batch[0].array(); + auto in_type = std::static_pointer_cast(in_array->type); + ArrayData* out_array = out->mutable_array(); + + if (in_type->index_type()->Equals(out_type->index_type())) { + out_array->buffers[0] = in_array->buffers[0]; + out_array->buffers[1] = in_array->buffers[1]; + } else { // for indices, create a dummy ArrayData with index_type() const std::shared_ptr& indices_arr = ArrayData::Make(in_type->index_type(), in_array->length, in_array->buffers, @@ -65,8 +89,12 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { options, ctx->exec_context())); out_array->buffers[0] = std::move(casted_indices.array()->buffers[0]); out_array->buffers[1] = std::move(casted_indices.array()->buffers[1]); + } - // data + // data (dict) + if (in_type->index_type()->Equals(out_type->index_type())) { + out_array->dictionary = in_array->dictionary; + } else { const std::shared_ptr& dict_arr = MakeArray(in_array->dictionary); ARROW_ASSIGN_OR_RAISE(auto casted_data, Cast(dict_arr, out_type->value_type(), options, ctx->exec_context())); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 73b282f9c1a..2593bd7a358 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1913,9 +1913,9 @@ TEST(Cast, ExtensionTypeToIntDowncast) { TEST(Cast, DictTypeToAnotherDict) { const std::shared_ptr& arr = - ArrayFromJSON(dictionary(int8(), int8()), "[1, 2, 3, 1, null, 3]"); + ArrayFromJSON(dictionary(int8(), int16()), "[1, 2, 3, 1, null, 3]"); - ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, dictionary(int32(), int32()))); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, dictionary(int32(), int64()))); // std::cout << "input: " << arr->ToString() << std::endl; // std::cout << "res: " << casted.make_array()->ToString() << std::endl; From 110e0d8b5b016bd609b1b3dd23d3513aec7f4e7c Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 14 Jul 2021 18:53:56 -0400 Subject: [PATCH 04/14] extending test cases --- .../compute/kernels/scalar_cast_dictionary.cc | 4 --- .../arrow/compute/kernels/scalar_cast_test.cc | 32 +++++++++++++++---- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc index d68e97711df..55c7bb0a3f4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -35,10 +35,6 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const CastOptions& options = CastState::Get(ctx); auto out_type = std::static_pointer_cast(out->type()); - if (!is_integer(out_type->index_type()->id())) { - return Status::Invalid("non-integer type used for DictionaryType::index_type"); - } - // if out type is same as in type, return input if (out_type->Equals(batch[0].type())) { *out = batch[0]; diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 2593bd7a358..9c3c53e5691 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1912,14 +1912,34 @@ TEST(Cast, ExtensionTypeToIntDowncast) { } TEST(Cast, DictTypeToAnotherDict) { - const std::shared_ptr& arr = - ArrayFromJSON(dictionary(int8(), int16()), "[1, 2, 3, 1, null, 3]"); - - ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, dictionary(int32(), int64()))); + auto check_cast = [&](const std::shared_ptr& in_type, + const std::shared_ptr& out_type, + const std::string& json_str, + const CastOptions& options = CastOptions()) { + auto arr = ArrayFromJSON(in_type, json_str); + auto exp = in_type->Equals(out_type) ? arr : ArrayFromJSON(out_type, json_str); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, out_type, options)); + ValidateOutput(casted); + AssertArraysEqual(*exp, *casted.make_array(), /*verbose=*/true); + }; -// std::cout << "input: " << arr->ToString() << std::endl; -// std::cout << "res: " << casted.make_array()->ToString() << std::endl; + // check same type passed on to casting + check_cast(dictionary(int8(), int16()), dictionary(int8(), int16()), + "[1, 2, 3, 1, null, 3]"); + check_cast(dictionary(int8(), int16()), dictionary(int32(), int64()), + "[1, 2, 3, 1, null, 3]"); + check_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()), + R"(["a", "b", "a", null])"); + + auto arr = ArrayFromJSON(dictionary(int32(), int32()), "[1, 1000]"); + // check unsafe + ASSERT_OK_AND_ASSIGN(auto casted, + Cast(arr, dictionary(int8(), int8()), CastOptions::Unsafe())); ValidateOutput(casted); + // check safe + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Integer value 1000 not in range"), + Cast(arr, dictionary(int8(), int8()), CastOptions::Safe())); } } // namespace compute From 0a6255c760038385732fd6c0d2dd14bdd7a26dc3 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 14 Jul 2021 19:26:42 -0400 Subject: [PATCH 05/14] fixing null scalar issue --- cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc | 6 ++++++ cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 7 +++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc index 55c7bb0a3f4..38a9536b3f8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -44,6 +44,12 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (batch[0].is_scalar()) { // if input is scalar auto in_scalar = std::static_pointer_cast(batch[0].scalar()); + // if invalid scalar, return null scalar + if (!in_scalar->is_valid) { + *out = MakeNullScalar(out_type); + return Status::OK(); + } + Datum casted_index, casted_dict; if (in_scalar->value.index->type->Equals(out_type->index_type())) { casted_index = in_scalar->value.index; diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 9c3c53e5691..c93d347be2c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1918,12 +1918,11 @@ TEST(Cast, DictTypeToAnotherDict) { const CastOptions& options = CastOptions()) { auto arr = ArrayFromJSON(in_type, json_str); auto exp = in_type->Equals(out_type) ? arr : ArrayFromJSON(out_type, json_str); - ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, out_type, options)); - ValidateOutput(casted); - AssertArraysEqual(*exp, *casted.make_array(), /*verbose=*/true); + // this checks for scalars as well + CheckCast(arr, exp, options); }; - // check same type passed on to casting + // check same type passed on to casting check_cast(dictionary(int8(), int16()), dictionary(int8(), int16()), "[1, 2, 3, 1, null, 3]"); check_cast(dictionary(int8(), int16()), dictionary(int32(), int64()), From 37a9bfe8b136c4dd9b5d127ad5ae68ffe3e31cbe Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 15 Jul 2021 09:52:46 -0400 Subject: [PATCH 06/14] adding PR comments --- cpp/src/arrow/CMakeLists.txt | 4 +-- .../compute/kernels/scalar_cast_dictionary.cc | 27 ++++++++++--------- .../arrow/compute/kernels/scalar_cast_test.cc | 19 +++++++++++-- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index d8650734d9d..d2f80ce7213 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -385,8 +385,8 @@ if(ARROW_COMPUTE) compute/kernels/scalar_arithmetic.cc compute/kernels/scalar_boolean.cc compute/kernels/scalar_cast_boolean.cc - compute/kernels/scalar_cast_dictionary.cc - compute/kernels/scalar_cast_internal.cc + compute/kernels/scalar_cast_dictionary.cc + compute/kernels/scalar_cast_internal.cc compute/kernels/scalar_cast_nested.cc compute/kernels/scalar_cast_numeric.cc compute/kernels/scalar_cast_string.cc diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc index 38a9536b3f8..d1a4312062c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -18,6 +18,7 @@ // Implementation of casting to dictionary type #include +#include #include "arrow/array/builder_primitive.h" #include "arrow/compute/cast_internal.h" @@ -42,28 +43,28 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { } if (batch[0].is_scalar()) { // if input is scalar - auto in_scalar = std::static_pointer_cast(batch[0].scalar()); + auto in_scalar = checked_cast(*batch[0].scalar()); // if invalid scalar, return null scalar - if (!in_scalar->is_valid) { + if (!in_scalar.is_valid) { *out = MakeNullScalar(out_type); return Status::OK(); } Datum casted_index, casted_dict; - if (in_scalar->value.index->type->Equals(out_type->index_type())) { - casted_index = in_scalar->value.index; + if (in_scalar.value.index->type->Equals(out_type->index_type())) { + casted_index = in_scalar.value.index; } else { ARROW_ASSIGN_OR_RAISE(casted_index, - Cast(in_scalar->value.index, out_type->index_type(), options, + Cast(in_scalar.value.index, out_type->index_type(), options, ctx->exec_context())); } - if (in_scalar->value.dictionary->type()->Equals(out_type->value_type())) { - casted_dict = in_scalar->value.dictionary; + if (in_scalar.value.dictionary->type()->Equals(out_type->value_type())) { + casted_dict = in_scalar.value.dictionary; } else { ARROW_ASSIGN_OR_RAISE( - casted_dict, Cast(in_scalar->value.dictionary, out_type->value_type(), options, + casted_dict, Cast(in_scalar.value.dictionary, out_type->value_type(), options, ctx->exec_context())); } @@ -75,17 +76,19 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { // if input is array const std::shared_ptr& in_array = batch[0].array(); - auto in_type = std::static_pointer_cast(in_array->type); + const auto& in_type = checked_cast(*in_array->type); ArrayData* out_array = out->mutable_array(); - if (in_type->index_type()->Equals(out_type->index_type())) { + if (in_type.index_type()->Equals(out_type->index_type())) { out_array->buffers[0] = in_array->buffers[0]; out_array->buffers[1] = in_array->buffers[1]; + out_array->null_count = in_array->GetNullCount(); + out_array->offset = in_array->offset; } else { // for indices, create a dummy ArrayData with index_type() const std::shared_ptr& indices_arr = - ArrayData::Make(in_type->index_type(), in_array->length, in_array->buffers, + ArrayData::Make(in_type.index_type(), in_array->length, in_array->buffers, in_array->GetNullCount(), in_array->offset); ARROW_ASSIGN_OR_RAISE(auto casted_indices, Cast(indices_arr, out_type->index_type(), options, ctx->exec_context())); @@ -94,7 +97,7 @@ Status CastDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { } // data (dict) - if (in_type->index_type()->Equals(out_type->index_type())) { + if (in_type.value_type()->Equals(out_type->value_type())) { out_array->dictionary = in_array->dictionary; } else { const std::shared_ptr& dict_arr = MakeArray(in_array->dictionary); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index c93d347be2c..8864e1d6590 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1931,14 +1931,29 @@ TEST(Cast, DictTypeToAnotherDict) { R"(["a", "b", "a", null])"); auto arr = ArrayFromJSON(dictionary(int32(), int32()), "[1, 1000]"); - // check unsafe + // check unsafe values ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, dictionary(int8(), int8()), CastOptions::Unsafe())); ValidateOutput(casted); - // check safe + // check safe values EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, testing::HasSubstr("Integer value 1000 not in range"), Cast(arr, dictionary(int8(), int8()), CastOptions::Safe())); + + // check unsafe indices + random::RandomArrayGenerator rand(/*seed=*/0); + int64_t len = 1000; + auto val_arr = rand.ArrayOf(int32(), len, /*null_probability=*/0.01); + ASSERT_OK_AND_ASSIGN(auto arr2, DictionaryEncode(val_arr)); + // check unsafe indices. Cannot validate this array because ValidateOutput throws an + // out of bounds error + ASSERT_OK_AND_ASSIGN(auto casted2, Cast(arr2.make_array(), dictionary(int8(), int8()), + CastOptions::Unsafe())); + + // check safe indices + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("not in range"), + Cast(arr, dictionary(int8(), int8()), CastOptions::Safe())); } } // namespace compute From ee9ddb3a424ff06dcd1d1a8ed18cf6baffac9faa Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 15 Jul 2021 11:20:09 -0400 Subject: [PATCH 07/14] lint fix --- cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc index d1a4312062c..b1e1164fd34 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_dictionary.cc @@ -123,4 +123,4 @@ std::vector> GetDictionaryCasts() { } // namespace internal } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow From 4d1a26b3dac34fc99a8e46232b019e6d26e7c87d Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 15 Jul 2021 14:29:08 -0400 Subject: [PATCH 08/14] test fix --- cpp/src/arrow/compute/exec/expression_test.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 86909f4eb64..5784a7cd98c 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -942,14 +942,14 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { ASSERT_OK_AND_ASSIGN(auto expr, field_ref("dict_str").Bind(*kBoringSchema)); Datum dict_i32{ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(int32(), R"([3])"))}; - // Unsupported cast dictionary(int32(), int32()) -> dictionary(int32(), utf8()) - ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( - KnownFieldValues{{{"dict_str", dict_i32}}}, expr)); - // Unsupported cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8()) - dict_str = Datum{ - DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(utf8(), R"(["a"])"))}; - ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( - KnownFieldValues{{{"dict_str", dict_str}}}, expr)); + // cast dictionary(int32(), int32()) -> dictionary(int32(), utf8()) + ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_i32}}, literal(dict_str)); + + // cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8()) + auto dict_int8_str = Datum{ + DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(utf8(), R"(["3"])"))}; + ExpectReplacesTo(field_ref("dict_str"), {{"dict_str", dict_int8_str}}, + literal(dict_str)); } struct { From 3f10dbaba896d8569d09afd2f49bd779d62b51da Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 15 Jul 2021 17:43:34 -0400 Subject: [PATCH 09/14] removing unused var --- cpp/src/arrow/compute/exec/expression_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 5784a7cd98c..b59f8762818 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -939,7 +939,6 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null, is_valid(null_literal(utf8()))); - ASSERT_OK_AND_ASSIGN(auto expr, field_ref("dict_str").Bind(*kBoringSchema)); Datum dict_i32{ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(int32(), R"([3])"))}; // cast dictionary(int32(), int32()) -> dictionary(int32(), utf8()) From ec58508d7b80610b23791fb9edb407cbafec1953 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Sat, 17 Jul 2021 23:03:33 -0400 Subject: [PATCH 10/14] adding float test case --- .../arrow/compute/kernels/scalar_cast_test.cc | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 8864e1d6590..1c1a413f53d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1930,29 +1930,23 @@ TEST(Cast, DictTypeToAnotherDict) { check_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()), R"(["a", "b", "a", null])"); + // check float types (NOTE: ArrayFromJSON doesnt work for float value dictionary types) + auto arr_int8_int16 = + ArrayFromJSON(dictionary(int8(), int16()), "[1, 2, 3, 1, null, 3]"); + auto arr_float64 = ArrayFromJSON(float64(), "[1, 2, 3, 1, null, 3]"); + ASSERT_OK_AND_ASSIGN(auto arr_int32_float64, DictionaryEncode(arr_float64)); + CheckCast(arr_int8_int16, arr_int32_float64.make_array(), CastOptions::Safe()); + auto arr = ArrayFromJSON(dictionary(int32(), int32()), "[1, 1000]"); - // check unsafe values + // check casting unsafe values (checking for unsafe indices is unnecessary, because it + // would create an invalid index array which results in a ValidateOutput failure) ASSERT_OK_AND_ASSIGN(auto casted, Cast(arr, dictionary(int8(), int8()), CastOptions::Unsafe())); ValidateOutput(casted); - // check safe values - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, testing::HasSubstr("Integer value 1000 not in range"), - Cast(arr, dictionary(int8(), int8()), CastOptions::Safe())); - // check unsafe indices - random::RandomArrayGenerator rand(/*seed=*/0); - int64_t len = 1000; - auto val_arr = rand.ArrayOf(int32(), len, /*null_probability=*/0.01); - ASSERT_OK_AND_ASSIGN(auto arr2, DictionaryEncode(val_arr)); - // check unsafe indices. Cannot validate this array because ValidateOutput throws an - // out of bounds error - ASSERT_OK_AND_ASSIGN(auto casted2, Cast(arr2.make_array(), dictionary(int8(), int8()), - CastOptions::Unsafe())); - - // check safe indices + // check safe casting values EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, testing::HasSubstr("not in range"), + Invalid, testing::HasSubstr("Integer value 1000 not in range"), Cast(arr, dictionary(int8(), int8()), CastOptions::Safe())); } From 45c26a650700372ed3a72aefcecfb8a7a53e7340 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 19 Jul 2021 10:38:43 -0400 Subject: [PATCH 11/14] updating with JIRA --- cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 1c1a413f53d..fc79338367c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1930,7 +1930,7 @@ TEST(Cast, DictTypeToAnotherDict) { check_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()), R"(["a", "b", "a", null])"); - // check float types (NOTE: ArrayFromJSON doesnt work for float value dictionary types) + // check float types (TODO: ARROW-13381 ArrayFromJSON doesnt work for float value dictionary types) auto arr_int8_int16 = ArrayFromJSON(dictionary(int8(), int16()), "[1, 2, 3, 1, null, 3]"); auto arr_float64 = ArrayFromJSON(float64(), "[1, 2, 3, 1, null, 3]"); From 93c018df30a88e0bf7599dad9a8a90282e875d13 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 19 Jul 2021 11:55:04 -0400 Subject: [PATCH 12/14] lint fix --- cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index fc79338367c..9ad08150896 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1930,7 +1930,8 @@ TEST(Cast, DictTypeToAnotherDict) { check_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()), R"(["a", "b", "a", null])"); - // check float types (TODO: ARROW-13381 ArrayFromJSON doesnt work for float value dictionary types) + // check float types + // TODO(ARROW-13381): ArrayFromJSON doesnt work for float value dictionary types auto arr_int8_int16 = ArrayFromJSON(dictionary(int8(), int16()), "[1, 2, 3, 1, null, 3]"); auto arr_float64 = ArrayFromJSON(float64(), "[1, 2, 3, 1, null, 3]"); From 26ffdb3558841c23a65b90ea076a7304ec9cb5bd Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 19 Jul 2021 17:22:18 -0400 Subject: [PATCH 13/14] rebasing and removing todo --- cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 9ad08150896..fce8518dd3b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -1927,17 +1927,11 @@ TEST(Cast, DictTypeToAnotherDict) { "[1, 2, 3, 1, null, 3]"); check_cast(dictionary(int8(), int16()), dictionary(int32(), int64()), "[1, 2, 3, 1, null, 3]"); + check_cast(dictionary(int8(), int16()), dictionary(int32(), float64()), + "[1, 2, 3, 1, null, 3]"); check_cast(dictionary(int32(), utf8()), dictionary(int8(), utf8()), R"(["a", "b", "a", null])"); - // check float types - // TODO(ARROW-13381): ArrayFromJSON doesnt work for float value dictionary types - auto arr_int8_int16 = - ArrayFromJSON(dictionary(int8(), int16()), "[1, 2, 3, 1, null, 3]"); - auto arr_float64 = ArrayFromJSON(float64(), "[1, 2, 3, 1, null, 3]"); - ASSERT_OK_AND_ASSIGN(auto arr_int32_float64, DictionaryEncode(arr_float64)); - CheckCast(arr_int8_int16, arr_int32_float64.make_array(), CastOptions::Safe()); - auto arr = ArrayFromJSON(dictionary(int32(), int32()), "[1, 1000]"); // check casting unsafe values (checking for unsafe indices is unnecessary, because it // would create an invalid index array which results in a ValidateOutput failure) From 1bad25f7edf9c59f027f3e04bf2ff84a6212b4a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 20 Jul 2021 16:20:04 +0200 Subject: [PATCH 14/14] Python test case --- python/pyarrow/tests/test_array.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 57224ef4ebe..0f137383378 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -1524,12 +1524,26 @@ def test_cast_string_to_number_roundtrip(): def test_cast_dictionary(): - arr = pa.DictionaryArray.from_arrays( - pa.array([0, 1, None], type=pa.int32()), - pa.array(["foo", "bar"])) - assert arr.cast(pa.string()).equals(pa.array(["foo", "bar", None])) + # cast to the value type + arr = pa.array( + ["foo", "bar", None], + type=pa.dictionary(pa.int64(), pa.string()) + ) + expected = pa.array(["foo", "bar", None]) + assert arr.type == pa.dictionary(pa.int64(), pa.string()) + assert arr.cast(pa.string()) == expected + + # cast to a different key type + for key_type in [pa.int8(), pa.int16(), pa.int32()]: + typ = pa.dictionary(key_type, pa.string()) + expected = pa.array( + ["foo", "bar", None], + type=pa.dictionary(key_type, pa.string()) + ) + assert arr.cast(typ) == expected + + # shouldn't crash (ARROW-7077) with pytest.raises(pa.ArrowInvalid): - # Shouldn't crash (ARROW-7077) arr.cast(pa.int32())