From ba98a234f254a850b28b9773e07d9862d8aef162 Mon Sep 17 00:00:00 2001
From: Wes McKinney
Date: Mon, 19 Aug 2019 19:29:16 -0500
Subject: [PATCH] Support view from one dictionary type to another in Array::View

---
 cpp/src/arrow/array.cc           | 28 +++++++++++++++-------------
 cpp/src/arrow/array_view_test.cc | 31 +++++++++++++++++++++++++++----
 2 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 72997030e48..cfcc65a6743 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -1044,15 +1044,6 @@ struct ViewDataImpl {
     return Status::OK();
   }
 
-  Status CheckInputHasNoDictionaries() {
-    for (const auto& layout : in_layouts) {
-      if (layout.has_dictionary) {
-        return InvalidView("input has dictionary");
-      }
-    }
-    return Status::OK();
-  }
-
   Status CheckInputAtZeroOffset() {
     for (const auto& data : in_data) {
       if (data->offset != 0) {
@@ -1062,18 +1053,28 @@ struct ViewDataImpl {
     return Status::OK();
   }
 
+  Status GetDictionaryView(const DataType& out_type, std::shared_ptr<Array>* out) {
+    if (in_data[in_layout_idx]->type->id() != Type::DICTIONARY) {
+      return InvalidView("Cannot get view as dictionary type");
+    }
+    const auto& dict_out_type = static_cast<const DictionaryType&>(out_type);
+    return in_data[in_layout_idx]->dictionary->View(dict_out_type.value_type(), out);
+  }
+
   Status MakeDataView(const std::shared_ptr<Field>& out_field,
                       std::shared_ptr<ArrayData>* out) {
     const auto out_type = out_field->type();
     const auto out_layout = out_type->layout();
-    if (out_layout.has_dictionary) {
-      return InvalidView("view type requires dictionary");
-    }
 
     AdjustInputPointer();
     int64_t out_length = in_data_length;
     int64_t out_null_count;
 
+    std::shared_ptr<Array> dictionary;
+    if (out_type->id() == Type::DICTIONARY) {
+      RETURN_NOT_OK(GetDictionaryView(*out_type, &dictionary));
+    }
+
     // No type has a purely empty layout
     DCHECK_GT(out_layout.bit_widths.size(), 0);
 
@@ -1143,6 +1144,8 @@ struct ViewDataImpl {
     std::shared_ptr<ArrayData> out_data =
         ArrayData::Make(out_type, out_length, std::move(out_buffers), out_null_count);
 
+    out_data->dictionary = dictionary;
+
     // Process children recursively, depth-first
     for (const auto& child_field : out_type->children()) {
       std::shared_ptr<ArrayData> child_data;
@@ -1166,7 +1169,6 @@ Status Array::View(const std::shared_ptr<DataType>& out_type,
   impl.in_data_length = data_->length;
 
   std::shared_ptr<ArrayData> out_data;
-  RETURN_NOT_OK(impl.CheckInputHasNoDictionaries());
   RETURN_NOT_OK(impl.CheckInputAtZeroOffset());
   // Dummy field for output type
   auto out_field = field("", out_type);
diff --git a/cpp/src/arrow/array_view_test.cc b/cpp/src/arrow/array_view_test.cc
index 8d94242d8f3..4d4c5704b0b 100644
--- a/cpp/src/arrow/array_view_test.cc
+++ b/cpp/src/arrow/array_view_test.cc
@@ -355,14 +355,37 @@ TEST(TestArrayView, DecimalRoundTrip) {
 }
 
 TEST(TestArrayView, Dictionaries) {
-  // Can't view dictionaries
+  // ARROW-6049
   auto ty1 = dictionary(int8(), float32());
+  auto ty2 = dictionary(int8(), int32());
+
   auto indices = ArrayFromJSON(int8(), "[0, 2, null, 1]");
   auto values = ArrayFromJSON(float32(), "[0.0, 1.5, -2.5]");
-  std::shared_ptr<Array> arr;
+
+  std::shared_ptr<Array> arr, expected, expected_dict;
+  ASSERT_OK(values->View(int32(), &expected_dict));
   ASSERT_OK(DictionaryArray::FromArrays(ty1, indices, values, &arr));
-  CheckViewFails(arr, int8());
-  CheckViewFails(indices, ty1);
+  ASSERT_OK(DictionaryArray::FromArrays(ty2, indices, expected_dict, &expected));
+
+  CheckView(arr, expected);
+  CheckView(expected, arr);
+
+  // Incompatible index type
+  auto ty3 = dictionary(int16(), int32());
+  CheckViewFails(arr, ty3);
+
+  // Incompatible dictionary type
+  auto ty4 = dictionary(int16(), float64());
+  CheckViewFails(arr, ty4);
+
+  // Check dictionary-encoded child
+  auto offsets = ArrayFromJSON(int32(), "[0, 2, 2, 4]");
+  std::shared_ptr<Array> list_arr, expected_list_arr;
+  ASSERT_OK(ListArray::FromArrays(*offsets, *arr, default_memory_pool(), &list_arr));
+  ASSERT_OK(ListArray::FromArrays(*offsets, *expected, default_memory_pool(),
+                                  &expected_list_arr));
+  CheckView(list_arr, expected_list_arr);
+  CheckView(expected_list_arr, list_arr);
 }
 
 TEST(TestArrayView, ExtensionType) {