From 06b428dbe0fa9e03a11af0d467b57145240ddabe Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 26 Aug 2021 09:33:41 -0400
Subject: [PATCH 01/21] ARROW-13573: [C++] Add DictScalarFromJSON
---
cpp/src/arrow/ipc/json_simple.cc | 19 +++++++++++++++++++
cpp/src/arrow/ipc/json_simple.h | 5 +++++
cpp/src/arrow/ipc/json_simple_test.cc | 23 +++++++++++++++++++++++
cpp/src/arrow/testing/gtest_util.cc | 9 +++++++++
cpp/src/arrow/testing/gtest_util.h | 5 +++++
5 files changed, 61 insertions(+)
diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc
index 34b0f3fba59..8347b871b1f 100644
--- a/cpp/src/arrow/ipc/json_simple.cc
+++ b/cpp/src/arrow/ipc/json_simple.cc
@@ -969,6 +969,25 @@ Status ScalarFromJSON(const std::shared_ptr<DataType>& type,
return Status::OK();
}
+Status DictScalarFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view index_json, util::string_view dictionary_json,
+ std::shared_ptr<Scalar>* out) {
+ if (type->id() != Type::DICTIONARY) {
+ return Status::TypeError("DictScalarFromJSON requires dictionary type, got ", *type);
+ }
+
+ const auto& dictionary_type = checked_cast<const DictionaryType&>(*type);
+
+ std::shared_ptr<Scalar> index;
+ std::shared_ptr<Array> dictionary;
+ RETURN_NOT_OK(ScalarFromJSON(dictionary_type.index_type(), index_json, &index));
+ RETURN_NOT_OK(
+ ArrayFromJSON(dictionary_type.value_type(), dictionary_json, &dictionary));
+
+ *out = DictionaryScalar::Make(std::move(index), std::move(dictionary));
+ return Status::OK();
+}
+
} // namespace json
} // namespace internal
} // namespace ipc
diff --git a/cpp/src/arrow/ipc/json_simple.h b/cpp/src/arrow/ipc/json_simple.h
index 4dd3a664aa6..8269bd65326 100644
--- a/cpp/src/arrow/ipc/json_simple.h
+++ b/cpp/src/arrow/ipc/json_simple.h
@@ -55,6 +55,11 @@ ARROW_EXPORT
Status ScalarFromJSON(const std::shared_ptr<DataType>&, util::string_view json,
std::shared_ptr<Scalar>* out);
+ARROW_EXPORT
+Status DictScalarFromJSON(const std::shared_ptr<DataType>&, util::string_view index_json,
+ util::string_view dictionary_json,
+ std::shared_ptr<Scalar>* out);
+
} // namespace json
} // namespace internal
} // namespace ipc
diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc
index ce2c37b7957..372f6bf1d72 100644
--- a/cpp/src/arrow/ipc/json_simple_test.cc
+++ b/cpp/src/arrow/ipc/json_simple_test.cc
@@ -1385,6 +1385,29 @@ TEST(TestScalarFromJSON, Errors) {
ASSERT_RAISES(Invalid, ScalarFromJSON(boolean(), "\"true\"", &scalar));
}
+TEST(TestDictScalarFromJSON, Basics) {
+ auto type = dictionary(int32(), utf8());
+ auto dict = R"(["whiskey", "tango", "foxtrot"])";
+ auto expected_dictionary = ArrayFromJSON(utf8(), dict);
+
+ for (auto index : {"null", "2", "1", "0"}) {
+ auto scalar = DictScalarFromJSON(type, index, dict);
+ auto expected_index = ScalarFromJSON(int32(), index);
+ AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary),
+ *scalar);
+ }
+}
+
+TEST(TestDictScalarFromJSON, Errors) {
+ auto type = dictionary(int32(), utf8());
+ std::shared_ptr<Scalar> scalar;
+
+ ASSERT_RAISES(Invalid,
+ DictScalarFromJSON(type, "\"not a valid index\"", "[\"\"]", &scalar));
+ ASSERT_RAISES(Invalid, DictScalarFromJSON(type, "0", "[1]",
+ &scalar)); // dict value isn't string
+}
+
} // namespace json
} // namespace internal
} // namespace ipc
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 587154c1f30..24f5edcc6cb 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -446,6 +446,15 @@ std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>& type,
return out;
}
+std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>& type,
+ util::string_view index_json,
+ util::string_view dictionary_json) {
+ std::shared_ptr<Scalar> out;
+ ABORT_NOT_OK(
+ ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out));
+ return out;
+}
+
std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
const std::vector<std::string>& json) {
std::vector<std::shared_ptr<RecordBatch>> batches;
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index f0021e05603..65ab33c5d1f 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -338,6 +338,11 @@ ARROW_TESTING_EXPORT
std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>&,
util::string_view json);
+ARROW_TESTING_EXPORT
+std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>&,
+ util::string_view index_json,
+ util::string_view dictionary_json);
+
ARROW_TESTING_EXPORT
std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>&,
const std::vector<std::string>& json);
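[Editor's note] For orientation, a minimal sketch of how the helper added in this patch is meant to be used from test code, assuming only the APIs declared above (ExampleDictScalarFromJSON is illustrative, not part of the patch):

    #include "arrow/scalar.h"
    #include "arrow/testing/gtest_util.h"

    void ExampleDictScalarFromJSON() {
      // Index 1 into the dictionary ["whiskey", "tango", "foxtrot"], i.e. "tango".
      auto ty = arrow::dictionary(arrow::int32(), arrow::utf8());
      std::shared_ptr<arrow::Scalar> s =
          arrow::DictScalarFromJSON(ty, "1", R"(["whiskey", "tango", "foxtrot"])");
      // On success, s is a DictionaryScalar whose value.index is an Int32Scalar and
      // whose value.dictionary is the three-element utf8 array; the testing helper
      // aborts if either JSON string fails to parse.
    }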
From 2ec113219f2d9733d0534de70eb0dd05f6124b75 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 26 Aug 2021 10:59:54 -0400
Subject: [PATCH 02/21] ARROW-13573: [C++] Check that dictionary array has
dictionary
---
cpp/src/arrow/array/validate.cc | 3 +++
1 file changed, 3 insertions(+)
diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc
index c66c4f53b9d..1715863014c 100644
--- a/cpp/src/arrow/array/validate.cc
+++ b/cpp/src/arrow/array/validate.cc
@@ -568,6 +568,9 @@ struct ValidateArrayFullImpl {
}
Status Visit(const DictionaryType& type) {
+ if (!data.dictionary) {
+ return Status::Invalid("Dictionary array has no dictionary");
+ }
const Status indices_status =
CheckBounds(*type.index_type(), 0, data.dictionary->length - 1);
if (!indices_status.ok()) {
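[Editor's note] A sketch of the situation the new check guards against, using the test helpers from the previous patch (ExampleMissingDictionary is illustrative only):

    #include "arrow/array.h"
    #include "arrow/testing/gtest_util.h"

    void ExampleMissingDictionary() {
      auto ok = arrow::DictArrayFromJSON(arrow::dictionary(arrow::int8(), arrow::utf8()),
                                         "[0, 1]", R"(["a", "b"])");
      auto data = ok->data()->Copy();
      data->dictionary = nullptr;  // simulate a kernel that forgot to attach a dictionary
      auto broken = arrow::MakeArray(data);
      // Without the check above, full validation would dereference a null dictionary;
      // with it, ValidateFull() returns Invalid("Dictionary array has no dictionary").
      arrow::Status st = broken->ValidateFull();
    }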
From e3b7f93f892bb01b735f3066bf814779b7385350 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 26 Aug 2021 11:00:32 -0400
Subject: [PATCH 03/21] ARROW-13573: [C++] Handle simple dictionary cases
---
.../arrow/compute/kernels/scalar_if_else.cc | 85 ++++++++++++++++++-
.../compute/kernels/scalar_if_else_test.cc | 81 ++++++++++++++++++
2 files changed, 164 insertions(+), 2 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 4de04da7a81..fc69c0e2fd0 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1092,6 +1092,63 @@ struct CopyFixedWidth> {
}
};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_dictionary<Type>> {
+ // TODO: how are we going to deal with passing down a mapping?
+ static void CopyScalar(const Scalar& scalar, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ const auto& index = *checked_cast<const DictionaryScalar&>(scalar).value.index;
+ switch (index.type->id()) {
+ case arrow::Type::INT8:
+ case arrow::Type::UINT8:
+ CopyFixedWidth<UInt8Type>::CopyScalar(index, length, raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT16:
+ case arrow::Type::UINT16:
+ CopyFixedWidth<UInt16Type>::CopyScalar(index, length, raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT32:
+ case arrow::Type::UINT32:
+ CopyFixedWidth<UInt32Type>::CopyScalar(index, length, raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT64:
+ case arrow::Type::UINT64:
+ CopyFixedWidth<UInt64Type>::CopyScalar(index, length, raw_out_values, out_offset);
+ break;
+ default:
+ ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type;
+ }
+ }
+ static void CopyArray(const DataType& type, const uint8_t* in_values,
+ const int64_t in_offset, const int64_t length,
+ uint8_t* raw_out_values, const int64_t out_offset) {
+ const auto& index_type = *checked_cast<const DictionaryType&>(type).index_type();
+ switch (index_type.id()) {
+ case arrow::Type::INT8:
+ case arrow::Type::UINT8:
+ CopyFixedWidth<UInt8Type>::CopyArray(index_type, in_values, in_offset, length,
+ raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT16:
+ case arrow::Type::UINT16:
+ CopyFixedWidth<UInt16Type>::CopyArray(index_type, in_values, in_offset, length,
+ raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT32:
+ case arrow::Type::UINT32:
+ CopyFixedWidth<UInt32Type>::CopyArray(index_type, in_values, in_offset, length,
+ raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT64:
+ case arrow::Type::UINT64:
+ CopyFixedWidth<UInt64Type>::CopyArray(index_type, in_values, in_offset, length,
+ raw_out_values, out_offset);
+ break;
+ default:
+ ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type;
+ }
+ }
+};
template
struct CopyFixedWidth> {
static void CopyScalar(const Scalar& values, const int64_t length,
@@ -1222,7 +1279,6 @@ struct CaseWhenFunction : ScalarFunction {
// The first function is a struct of booleans, where the number of fields in the
// struct is either equal to the number of other arguments or is one less.
RETURN_NOT_OK(CheckArity(*values));
- EnsureDictionaryDecoded(values);
auto first_type = (*values)[0].type;
if (first_type->id() != Type::STRUCT) {
return Status::TypeError("case_when: first argument must be STRUCT, not ",
@@ -1243,6 +1299,9 @@ struct CaseWhenFunction : ScalarFunction {
}
}
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) {
for (auto it = values->begin() + 1; it != values->end(); it++) {
it->type = type;
@@ -1275,10 +1334,20 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out
}
}
if (out->is_scalar()) {
+ // TODO: makenullscalar here should give the expected dictionary
*out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
return Status::OK();
}
ArrayData* output = out->mutable_array();
+ if (is_dictionary_type<Type>::value) {
+ const Datum& dict_from = result.is_value() ? result : batch[1];
+ if (dict_from.is_scalar()) {
+ output->dictionary = checked_cast<const DictionaryScalar&>(*dict_from.scalar())
+ .value.dictionary->data();
+ } else {
+ output->dictionary = dict_from.array()->dictionary;
+ }
+ }
if (!result.is_value()) {
// All conditions false, no 'else' argument
result = MakeNullScalar(out->type());
@@ -1313,6 +1382,17 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false);
}
+ if (is_dictionary_type<Type>::value) {
+ // We always use the dictionary of the first argument
+ const Datum& dict_from = batch[1];
+ if (dict_from.is_scalar()) {
+ output->dictionary = checked_cast<const DictionaryScalar&>(*dict_from.scalar())
+ .value.dictionary->data();
+ } else {
+ output->dictionary = dict_from.array()->dictionary;
+ }
+ }
+
// Allocate a temporary bitmap to determine which elements still need setting.
ARROW_ASSIGN_OR_RAISE(auto mask_buffer, ctx->AllocateBitmap(batch.length));
uint8_t* mask = mask_buffer->mutable_data();
@@ -2446,7 +2526,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
}
{
auto func = std::make_shared<CaseWhenFunction>(
- "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc);
+ "case_when", Arity::VarArgs(/*min_args=*/2), &case_when_doc);
AddPrimitiveCaseWhenKernels(func, NumericTypes());
AddPrimitiveCaseWhenKernels(func, TemporalTypes());
AddPrimitiveCaseWhenKernels(func, IntervalTypes());
@@ -2464,6 +2544,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor<StructType>::Exec);
AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor<DenseUnionType>::Exec);
AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor<SparseUnionType>::Exec);
+ AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor<DictionaryType>::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index b3b0f26cead..ca6b688eb3d 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -624,6 +624,87 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) {
ArrayFromJSON(type, R"([null, null, null, [6, null]])"));
}
+template <typename Type>
+class TestCaseWhenInteger : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenInteger, IntegralArrowTypes);
+
+TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingSimple) {
+ auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+ auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+
+ // Easy case: all arguments have the same dictionary
+ // CheckScalar("case_when", {MakeStruct({Datum(false)}), DictScalarFromJSON(type, "1",
+ // dict)},
+ // DictScalarFromJSON(type, "[0, null, null, null]", dict));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ DictArrayFromJSON(type, "[0, null, null, null]", dict));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ DictArrayFromJSON(type, "[0, null, null, 1]", dict));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ DictArrayFromJSON(type, "[null, null, null, 1]", dict));
+}
+
+TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingMixed) {
+ auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+ auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+ auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
+ auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+ auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
+
+ // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_dict, values2_decoded},
+ ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_decoded, values2_dict},
+ ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+ CheckScalar("case_when",
+ {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded},
+ ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+ CheckScalar("case_when",
+ {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded},
+ ArrayFromJSON(utf8(), R"([null, null, null, null])"));
+}
+
+TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingDifferentDictionaries) {
+ auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict1 = R"(["a", null, "bc", "def"])";
+ auto dict2 = R"(["bc", "foo", null, "a"])";
+ auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1);
+ auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2);
+ auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict1);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict2);
+
+ // For scalar conditions, we borrow the dictionary of the chosen output (or the first
+ // input when outputting null)
+ CheckScalar("case_when", {MakeStruct({Datum(true), Datum(false)}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({Datum(false), Datum(true)}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values1, values2},
+ values1_null);
+ CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values2, values1},
+ values2_null);
+
+ // For array conditions, we always borrow the dictionary of the first input
+
+ // When mixing dictionaries, we try to map other dictionaries onto the first one
+
+ // If we can't map values from a dictionary, then raise an error
+
+ // ...or optionally, raise an error
+}
+
TEST(TestCaseWhen, Null) {
auto cond_true = ScalarFromJSON(boolean(), "true");
auto cond_false = ScalarFromJSON(boolean(), "false");
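[Editor's note] With the kernel registered above, case_when accepts dictionary arguments directly. A rough usage sketch for this stage of the series, where all value arguments are assumed to share one dictionary (ExampleDictCaseWhen and the expected indices are illustrative):

    #include "arrow/compute/api.h"
    #include "arrow/result.h"
    #include "arrow/testing/gtest_util.h"

    arrow::Result<arrow::Datum> ExampleDictCaseWhen() {
      auto ty = arrow::dictionary(arrow::int32(), arrow::utf8());
      auto dict = R"(["a", "bc", "def"])";
      auto cond = arrow::ArrayFromJSON(arrow::boolean(), "[true, false, true]");
      auto lhs = arrow::DictArrayFromJSON(ty, "[0, 1, 2]", dict);  // chosen where cond is true
      auto rhs = arrow::DictArrayFromJSON(ty, "[2, 2, 2]", dict);  // 'else' values
      ARROW_ASSIGN_OR_RAISE(arrow::Datum conds,
                            arrow::compute::CallFunction("make_struct", {cond}));
      // Expected output shares the inputs' dictionary: indices [0, 2, 2] -> "a", "def", "def".
      return arrow::compute::CallFunction("case_when", {conds, lhs, rhs});
    }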
From 7a57c91420222185890f1f74ec46f86388da9cd9 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 27 Aug 2021 11:41:38 -0400
Subject: [PATCH 04/21] ARROW-13573: [C++] Transpose dictionaries in case_when
---
.../arrow/compute/kernels/scalar_if_else.cc | 377 +++++++++++++++++-
.../compute/kernels/scalar_if_else_test.cc | 148 ++++++-
cpp/src/arrow/compute/kernels/test_util.cc | 14 +-
cpp/src/arrow/compute/kernels/test_util.h | 7 +
4 files changed, 498 insertions(+), 48 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index fc69c0e2fd0..6a2ca01c439 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -15,10 +15,13 @@
// specific language governing permissions and limitations
// under the License.
+#include
+
#include "arrow/array/builder_nested.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/array/builder_time.h"
#include "arrow/array/builder_union.h"
+#include "arrow/array/dict_internal.h"
#include "arrow/compute/api.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/util_internal.h"
@@ -27,6 +30,7 @@
#include "arrow/util/bitmap.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/int_util.h"
namespace arrow {
@@ -1058,6 +1062,109 @@ void AddFSBinaryIfElseKernel(const std::shared_ptr& scalar_funct
DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}
+// Given a reference dictionary, computes indices to map dictionary values from a
+// comparison dictionary to the reference.
+class DictionaryRemapper {
+ public:
+ virtual ~DictionaryRemapper() = default;
+ virtual Status Init(const Array& dictionary) = 0;
+ virtual Result<std::unique_ptr<Int32Array>> Remap(const Array& dictionary) = 0;
+ Result<std::unique_ptr<Int32Array>> Remap(const Datum& has_dictionary) {
+ DCHECK_EQ(has_dictionary.type()->id(), Type::DICTIONARY);
+ if (has_dictionary.is_scalar()) {
+ return Remap(*checked_cast<const DictionaryScalar&>(*has_dictionary.scalar())
+ .value.dictionary);
+ } else {
+ return Remap(*MakeArray(has_dictionary.array()->dictionary));
+ }
+ }
+};
+
+template <typename T>
+class DictionaryRemapperImpl : public DictionaryRemapper {
+ public:
+ using ArrayType = typename TypeTraits<T>::ArrayType;
+ using MemoTableType = typename arrow::internal::HashTraits<T>::MemoTableType;
+
+ explicit DictionaryRemapperImpl(MemoryPool* pool) : pool_(pool), memo_table_(pool) {}
+
+ Status Init(const Array& dictionary) override {
+ const ArrayType& values = checked_cast<const ArrayType&>(dictionary);
+ if (values.length() > std::numeric_limits<int32_t>::max()) {
+ return Status::CapacityError("Cannot remap dictionary with more than ",
+ std::numeric_limits<int32_t>::max(),
+ " elements, have: ", values.length());
+ }
+ for (int32_t i = 0; i < values.length(); ++i) {
+ if (values.IsNull(i)) {
+ memo_table_.GetOrInsertNull();
+ continue;
+ }
+ int32_t unused_memo_index = 0;
+ RETURN_NOT_OK(memo_table_.GetOrInsert(values.GetView(i), &unused_memo_index));
+ }
+ return Status::OK();
+ }
+
+ Result<std::unique_ptr<Int32Array>> Remap(const Array& dictionary) override {
+ const ArrayType& values = checked_cast<const ArrayType&>(dictionary);
+ std::shared_ptr<Buffer> valid_buffer;
+ ARROW_ASSIGN_OR_RAISE(auto indices_buffer,
+ AllocateBuffer(dictionary.length() * sizeof(int32_t), pool_));
+ int32_t* indices = reinterpret_cast<int32_t*>(indices_buffer->mutable_data());
+ int64_t null_count = 0;
+ for (int64_t i = 0; i < values.length(); ++i) {
+ int32_t index = -1;
+ index =
+ values.IsNull(i) ? memo_table_.GetNull() : memo_table_.Get(values.GetView(i));
+ indices[i] = std::max(0, index);
+ if (index == arrow::internal::kKeyNotFound && !valid_buffer) {
+ ARROW_ASSIGN_OR_RAISE(
+ valid_buffer,
+ AllocateBuffer(BitUtil::BytesForBits(dictionary.length()), pool_));
+ std::memset(valid_buffer->mutable_data(), 0xFF, valid_buffer->size());
+ }
+ if (index == arrow::internal::kKeyNotFound) {
+ BitUtil::ClearBit(valid_buffer->mutable_data(), i);
+ null_count++;
+ }
+ }
+ return arrow::internal::make_unique<Int32Array>(dictionary.length(),
+ std::move(indices_buffer),
+ std::move(valid_buffer), null_count);
+ }
+
+ private:
+ MemoryPool* pool_;
+ MemoTableType memo_table_;
+};
+
+struct MakeRemapper {
+ template <typename T>
+ enable_if_no_memoize<T, Status> Visit(const T& value_type) {
+ return Status::NotImplemented("Unification of ", value_type,
+ " dictionaries is not implemented");
+ }
+
+ template <typename T>
+ enable_if_memoize<T, Status> Visit(const T&) {
+ result_.reset(new DictionaryRemapperImpl<T>(pool_));
+ return Status::OK();
+ }
+
+ static Result<std::unique_ptr<DictionaryRemapper>> Make(MemoryPool* pool,
+ const Array& dictionary) {
+ const auto& value_type = *dictionary.type();
+ MakeRemapper impl{pool, /*result_=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(value_type, &impl));
+ RETURN_NOT_OK(impl.result_->Init(dictionary));
+ return std::move(impl.result_);
+ }
+
+ MemoryPool* pool_;
+ std::unique_ptr<DictionaryRemapper> result_;
+};
+
// Helper to copy or broadcast fixed-width values between buffers.
template <typename Type, typename Enable = void>
struct CopyFixedWidth {};
@@ -1084,17 +1191,33 @@ struct CopyFixedWidth> {
const CType value = UnboxScalar<Type>::Unbox(scalar);
std::fill(out_values + out_offset, out_values + out_offset + length, value);
}
+ static void CopyScalar(const Scalar& scalar, const Int32Array& transpose_map,
+ const int64_t length, uint8_t* raw_out_values,
+ const int64_t out_offset) {
+ CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+ const CType value = UnboxScalar<Type>::Unbox(scalar);
+ const CType transposed = static_cast<CType>(transpose_map.raw_values()[value]);
+ std::fill(out_values + out_offset, out_values + out_offset + length, transposed);
+ }
static void CopyArray(const DataType&, const uint8_t* in_values,
const int64_t in_offset, const int64_t length,
uint8_t* raw_out_values, const int64_t out_offset) {
std::memcpy(raw_out_values + out_offset * sizeof(CType),
in_values + in_offset * sizeof(CType), length * sizeof(CType));
}
+ static void CopyArray(const DataType&, const Int32Array& transpose_map,
+ const uint8_t* in_values, const int64_t in_offset,
+ const int64_t length, uint8_t* raw_out_values,
+ const int64_t out_offset) {
+ arrow::internal::TransposeInts(
+ reinterpret_cast<const CType*>(in_values) + in_offset,
+ reinterpret_cast<CType*>(raw_out_values) + out_offset, length,
+ transpose_map.raw_values());
+ }
};
template <typename Type>
struct CopyFixedWidth<Type, enable_if_dictionary<Type>> {
- // TODO: how are we going to deal with passing down a mapping?
static void CopyScalar(const Scalar& scalar, const int64_t length,
uint8_t* raw_out_values, const int64_t out_offset) {
const auto& index = *checked_cast<const DictionaryScalar&>(scalar).value.index;
@@ -1119,6 +1242,35 @@ struct CopyFixedWidth> {
ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type;
}
}
+ static void CopyScalar(const Scalar& scalar, const Int32Array& transpose_map,
+ const int64_t length, uint8_t* raw_out_values,
+ const int64_t out_offset) {
+ const auto& index = *checked_cast<const DictionaryScalar&>(scalar).value.index;
+ switch (index.type->id()) {
+ case arrow::Type::INT8:
+ case arrow::Type::UINT8:
+ CopyFixedWidth<UInt8Type>::CopyScalar(index, transpose_map, length,
+ raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT16:
+ case arrow::Type::UINT16:
+ CopyFixedWidth<UInt16Type>::CopyScalar(index, transpose_map, length,
+ raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT32:
+ case arrow::Type::UINT32:
+ CopyFixedWidth<UInt32Type>::CopyScalar(index, transpose_map, length,
+ raw_out_values, out_offset);
+ break;
+ case arrow::Type::INT64:
+ case arrow::Type::UINT64:
+ CopyFixedWidth<UInt64Type>::CopyScalar(index, transpose_map, length,
+ raw_out_values, out_offset);
+ break;
+ default:
+ ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type;
+ }
+ }
static void CopyArray(const DataType& type, const uint8_t* in_values,
const int64_t in_offset, const int64_t length,
uint8_t* raw_out_values, const int64_t out_offset) {
@@ -1148,6 +1300,40 @@ struct CopyFixedWidth> {
ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type;
}
}
+ static void CopyArray(const DataType& type, const Int32Array& transpose_map,
+ const uint8_t* in_values, const int64_t in_offset,
+ const int64_t length, uint8_t* raw_out_values,
+ const int64_t out_offset) {
+ const auto& index_type = *checked_cast<const DictionaryType&>(type).index_type();
+ switch (index_type.id()) {
+ case arrow::Type::INT8:
+ case arrow::Type::UINT8:
+ CopyFixedWidth<UInt8Type>::CopyArray(index_type, transpose_map, in_values,
+ in_offset, length, raw_out_values,
+ out_offset);
+ break;
+ case arrow::Type::INT16:
+ case arrow::Type::UINT16:
+ CopyFixedWidth<UInt16Type>::CopyArray(index_type, transpose_map, in_values,
+ in_offset, length, raw_out_values,
+ out_offset);
+ break;
+ case arrow::Type::INT32:
+ case arrow::Type::UINT32:
+ CopyFixedWidth<UInt32Type>::CopyArray(index_type, transpose_map, in_values,
+ in_offset, length, raw_out_values,
+ out_offset);
+ break;
+ case arrow::Type::INT64:
+ case arrow::Type::UINT64:
+ CopyFixedWidth<UInt64Type>::CopyArray(index_type, transpose_map, in_values,
+ in_offset, length, raw_out_values,
+ out_offset);
+ break;
+ default:
+ ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type;
+ }
+ }
};
template
struct CopyFixedWidth> {
@@ -1203,8 +1389,11 @@ struct CopyFixedWidth> {
// Copy fixed-width values from a scalar/array datum into an output values buffer
template <typename Type>
-void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t length,
- uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) {
+enable_if_t<!is_dictionary_type<Type>::value> CopyValues(
+ const Datum& in_values, const int64_t in_offset, const int64_t length,
+ uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset,
+ const Int32Array* transpose_map = nullptr) {
+ DCHECK(!transpose_map);
if (in_values.is_scalar()) {
const auto& scalar = *in_values.scalar();
if (out_valid) {
@@ -1234,6 +1423,123 @@ void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t l
}
}
+// Copy values, optionally transposing dictionary indices.
+template <typename Type>
+enable_if_dictionary<Type> CopyValues(const Datum& in_values, const int64_t in_offset,
+ const int64_t length, uint8_t* out_valid,
+ uint8_t* out_values, const int64_t out_offset,
+ const Int32Array* transpose_map) {
+ if (in_values.is_scalar()) {
+ const auto& scalar = *in_values.scalar();
+ if (out_valid) {
+ BitUtil::SetBitsTo(out_valid, out_offset, length, scalar.is_valid);
+ }
+ if (transpose_map) {
+ CopyFixedWidth<Type>::CopyScalar(scalar, *transpose_map, length, out_values,
+ out_offset);
+ } else {
+ CopyFixedWidth<Type>::CopyScalar(scalar, length, out_values, out_offset);
+ }
+ } else {
+ const ArrayData& array = *in_values.array();
+ if (out_valid) {
+ if (array.MayHaveNulls()) {
+ if (length == 1) {
+ // CopyBitmap is slow for short runs
+ BitUtil::SetBitTo(
+ out_valid, out_offset,
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + in_offset));
+ } else {
+ arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset,
+ length, out_valid, out_offset);
+ }
+ } else {
+ BitUtil::SetBitsTo(out_valid, out_offset, length, true);
+ }
+ }
+ if (transpose_map) {
+ CopyFixedWidth<Type>::CopyArray(*array.type, *transpose_map,
+ array.buffers[1]->data(), array.offset + in_offset,
+ length, out_values, out_offset);
+ } else {
+ CopyFixedWidth<Type>::CopyArray(*array.type, array.buffers[1]->data(),
+ array.offset + in_offset, length, out_values,
+ out_offset);
+ }
+ }
+}
+
+/// Check that we can actually remap dictionary indices from one dictionary to
+/// another without losing data.
+struct CheckValidTranspositionArrayImpl {
+ template <typename T>
+ enable_if_number<T, Status> Visit(const T&) {
+ using c_type = typename T::c_type;
+ const c_type* values = arr.GetValues<c_type>(1);
+ // TODO: eventually offer the option to zero out the bitmap instead
+ return arrow::internal::VisitSetBitRuns(
+ arr.buffers[0], arr.offset + offset, length,
+ [&](int64_t position, int64_t length) {
+ for (int64_t i = 0; i < length; i++) {
+ const uint64_t idx = static_cast<uint64_t>(values[offset + position + i]);
+ if (!BitUtil::GetBit(transpose_valid, idx)) {
+ return Status::Invalid("Cannot map dictionary index ", idx, " at position ",
+ offset + position + i, " to the common dictionary");
+ }
+ }
+ return Status::OK();
+ });
+ }
+
+ Status Visit(const DataType& ty) {
+ return Status::TypeError("Dictionary cannot have index type ", ty);
+ }
+
+ const ArrayData& arr;
+ const int64_t offset;
+ const int64_t length;
+ const uint8_t* transpose_valid;
+};
+
+struct CheckValidTranspositionScalarImpl {
+ template <typename T>
+ enable_if_number<T, Status> Visit(const T&) {
+ const uint64_t idx = static_cast<uint64_t>(UnboxScalar<T>::Unbox(
+ *checked_cast<const DictionaryScalar&>(scalar).value.index));
+ // TODO: eventually offer the option to zero out the bitmap instead
+ if (!BitUtil::GetBit(transpose_valid, idx)) {
+ return Status::Invalid("Cannot map dictionary index ", idx,
+ " to the common dictionary");
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& ty) {
+ return Status::TypeError("Dictionary cannot have index type ", ty);
+ }
+
+ const Scalar& scalar;
+ const uint8_t* transpose_valid;
+};
+
+Status CheckValidTransposition(const Datum& values, const int64_t offset,
+ const int64_t length, const Int32Array* transpose_map) {
+ // Note we assume the transpose map never has an offset
+ if (!transpose_map || transpose_map->null_count() == 0) return Status::OK();
+ DCHECK_EQ(values.type()->id(), Type::DICTIONARY);
+ if (values.is_scalar()) {
+ const Scalar& scalar = *values.scalar();
+ CheckValidTranspositionScalarImpl impl{scalar, transpose_map->null_bitmap_data()};
+ return VisitTypeInline(
+ *checked_cast<const DictionaryType&>(*scalar.type).index_type(), &impl);
+ }
+ const ArrayData& arr = *values.array();
+ CheckValidTranspositionArrayImpl impl{arr, offset, length,
+ transpose_map->null_bitmap_data()};
+ return VisitTypeInline(*checked_cast<const DictionaryType&>(*arr.type).index_type(),
+ &impl);
+}
+
// Specialized helper to copy a single value from a source array. Allows avoiding
// repeatedly calling MayHaveNulls and Buffer::data() which have internal checks that
// add up when called in a loop.
@@ -1334,7 +1640,6 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out
}
}
if (out->is_scalar()) {
- // TODO: makenullscalar here should give the expected dictionary
*out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
return Status::OK();
}
@@ -1352,9 +1657,9 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out
// All conditions false, no 'else' argument
result = MakeNullScalar(out->type());
}
- CopyValues<Type>(result, /*in_offset=*/0, batch.length,
- output->GetMutableValues<uint8_t>(0, 0),
- output->GetMutableValues<uint8_t>(1, 0), output->offset);
+ CopyValues<Type>(
+ result, /*in_offset=*/0, batch.length, output->GetMutableValues<uint8_t>(0, 0),
+ output->GetMutableValues<uint8_t>(1, 0), output->offset, /*transpose_map=*/nullptr);
return Status::OK();
}
@@ -1373,15 +1678,9 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
static_cast<size_t>(conds_array.type->num_fields()) < num_value_args;
uint8_t* out_valid = output->buffers[0]->mutable_data();
uint8_t* out_values = output->buffers[1]->mutable_data();
- if (have_else_arg) {
- // Copy 'else' value into output
- CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid,
- out_values, out_offset);
- } else {
- // There's no 'else' argument, so we should have an all-null validity bitmap
- BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false);
- }
+ std::unique_ptr<DictionaryRemapper> remapper;
+ std::unique_ptr<Int32Array> transpose_map;
if (is_dictionary_type<Type>::value) {
// We always use the dictionary of the first argument
const Datum& dict_from = batch[1];
@@ -1391,6 +1690,20 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
} else {
output->dictionary = dict_from.array()->dictionary;
}
+ ARROW_ASSIGN_OR_RAISE(
+ remapper, MakeRemapper::Make(ctx->memory_pool(), *MakeArray(output->dictionary)));
+ }
+
+ if (have_else_arg) {
+ // Copy 'else' value into output
+ if (is_dictionary_type<Type>::value) {
+ ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(batch.values.back()));
+ }
+ CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid,
+ out_values, out_offset, transpose_map.get());
+ } else {
+ // There's no 'else' argument, so we should have an all-null validity bitmap
+ BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false);
}
// Allocate a temporary bitmap to determine which elements still need setting.
@@ -1406,6 +1719,10 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
const Datum& values_datum = batch[i + 1];
int64_t offset = 0;
+ if (is_dictionary_type<Type>::value) {
+ ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(values_datum));
+ }
+
if (cond_array.GetNullCount() == 0) {
// If no valid buffer, visit mask & cond bitmap simultaneously
BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, cond_offset,
@@ -1413,15 +1730,19 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
while (offset < batch.length) {
const auto block = counter.NextAndWord();
if (block.AllSet()) {
+ RETURN_NOT_OK(CheckValidTransposition(values_datum, offset, block.length,
+ transpose_map.get()));
CopyValues(values_datum, offset, block.length, out_valid, out_values,
- out_offset + offset);
+ out_offset + offset, transpose_map.get());
BitUtil::SetBitsTo(mask, offset, block.length, false);
} else if (block.popcount) {
for (int64_t j = 0; j < block.length; ++j) {
if (BitUtil::GetBit(mask, offset + j) &&
BitUtil::GetBit(cond_values, cond_offset + offset + j)) {
+ RETURN_NOT_OK(CheckValidTransposition(values_datum, offset + j,
+ /*length=*/1, transpose_map.get()));
CopyValues(values_datum, offset + j, /*length=*/1, out_valid,
- out_values, out_offset + offset + j);
+ out_values, out_offset + offset + j, transpose_map.get());
BitUtil::SetBitTo(mask, offset + j, false);
}
}
@@ -1434,25 +1755,31 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length},
{cond_values, cond_offset, batch.length},
{cond_valid, cond_offset, batch.length}};
+ Status valid_transposition = Status::OK();
Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 3> words) {
const uint64_t word = words[0] & words[1] & words[2];
const int64_t block_length = std::min<int64_t>(64, batch.length - offset);
if (word == std::numeric_limits<uint64_t>::max()) {
+ valid_transposition &= CheckValidTransposition(
+ values_datum, offset, block_length, transpose_map.get());
CopyValues(values_datum, offset, block_length, out_valid, out_values,
- out_offset + offset);
+ out_offset + offset, transpose_map.get());
BitUtil::SetBitsTo(mask, offset, block_length, false);
} else if (word) {
for (int64_t j = 0; j < block_length; ++j) {
if (BitUtil::GetBit(mask, offset + j) &&
BitUtil::GetBit(cond_valid, cond_offset + offset + j) &&
BitUtil::GetBit(cond_values, cond_offset + offset + j)) {
+ valid_transposition &= CheckValidTransposition(
+ values_datum, offset + j, /*length=*/1, transpose_map.get());
CopyValues(values_datum, offset + j, /*length=*/1, out_valid,
- out_values, out_offset + offset + j);
+ out_values, out_offset + offset + j, transpose_map.get());
BitUtil::SetBitTo(mask, offset + j, false);
}
}
}
});
+ RETURN_NOT_OK(valid_transposition);
}
}
if (!have_else_arg) {
@@ -1483,6 +1810,18 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
}
offset += block.length;
}
+ } else if (is_dictionary_type<Type>::value) {
+ // Check that any 'else' slots that were not overwritten are valid transpositions.
+ arrow::internal::SetBitRunReader reader(mask, /*offset=*/0, batch.length);
+ if (is_dictionary_type<Type>::value) {
+ ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(batch.values.back()));
+ }
+ while (true) {
+ const auto run = reader.NextRun();
+ if (run.length == 0) break;
+ RETURN_NOT_OK(CheckValidTransposition(batch.values.back(), run.position, run.length,
+ transpose_map.get()));
+ }
}
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index ca6b688eb3d..c8c22f5ae76 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -625,32 +625,38 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) {
}
template <typename Type>
-class TestCaseWhenInteger : public ::testing::Test {};
+class TestCaseWhenDict : public ::testing::Test {};
-TYPED_TEST_SUITE(TestCaseWhenInteger, IntegralArrowTypes);
+struct JsonDict {
+ std::shared_ptr<DataType> type;
+ std::string value;
+};
-TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingSimple) {
- auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes);
+
+TYPED_TEST(TestCaseWhenDict, Simple) {
auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
- auto dict = R"(["a", null, "bc", "def"])";
- auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
- auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
- auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
-
- // Easy case: all arguments have the same dictionary
- // CheckScalar("case_when", {MakeStruct({Datum(false)}), DictScalarFromJSON(type, "1",
- // dict)},
- // DictScalarFromJSON(type, "[0, null, null, null]", dict));
- CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
- DictArrayFromJSON(type, "[0, null, null, null]", dict));
- CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
- DictArrayFromJSON(type, "[0, null, null, 1]", dict));
- CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
- DictArrayFromJSON(type, "[null, null, null, 1]", dict));
+ for (const auto& dict :
+ {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
+ JsonDict{int64(), "[1, null, 2, 3]"},
+ JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
+ auto type = dictionary(default_type_instance<TypeParam>(), dict.type);
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value);
+ auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value);
+
+ // Easy case: all arguments have the same dictionary
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ DictArrayFromJSON(type, "[0, null, null, null]", dict.value));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ DictArrayFromJSON(type, "[0, null, null, 1]", dict.value));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ DictArrayFromJSON(type, "[null, null, null, 1]", dict.value));
+ }
}
-TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingMixed) {
+TYPED_TEST(TestCaseWhenDict, Mixed) {
auto type = dictionary(default_type_instance<TypeParam>(), utf8());
auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
@@ -674,16 +680,18 @@ TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingMixed) {
ArrayFromJSON(utf8(), R"([null, null, null, null])"));
}
-TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingDifferentDictionaries) {
+TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) {
auto type = dictionary(default_type_instance<TypeParam>(), utf8());
auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
auto dict1 = R"(["a", null, "bc", "def"])";
auto dict2 = R"(["bc", "foo", null, "a"])";
+ auto dict3 = R"(["def", "a", "a", "bc"])";
auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1);
auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2);
auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict1);
auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict2);
+ auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3);
// For scalar conditions, we borrow the dictionary of the chosen output (or the first
// input when outputting null)
@@ -697,12 +705,108 @@ TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingDifferentDictionaries) {
values2_null);
// For array conditions, we always borrow the dictionary of the first input
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ DictArrayFromJSON(type, "[0, null, null, null]", dict1));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ DictArrayFromJSON(type, "[0, null, null, 1]", dict1));
// When mixing dictionaries, we try to map other dictionaries onto the first one
+ // Don't check the scalar cases since we don't remap dictionaries in that case
+ CheckScalarNonRecursive(
+ "case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), values1,
+ values2},
+ DictArrayFromJSON(type, "[0, null, null, 2]", dict1));
+ CheckScalarNonRecursive(
+ "case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(boolean(), "[true, false, true, false]")}),
+ values1, values2},
+ DictArrayFromJSON(type, "[0, null, null, null]", dict1));
+ CheckScalarNonRecursive(
+ "case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"),
+ ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+ values1, values3},
+ DictArrayFromJSON(type, "[3, 0, 0, 2]", dict1));
+ CheckScalarNonRecursive(
+ "case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"),
+ ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+ values1, values3},
+ DictArrayFromJSON(type, "[3, 0, 0, 1]", dict1));
+ CheckScalarNonRecursive(
+ "case_when",
+ {
+ MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}),
+ DictScalarFromJSON(type, "0", dict1),
+ DictScalarFromJSON(type, "0", dict2),
+ },
+ DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1));
+ CheckScalarNonRecursive(
+ "case_when",
+ {
+ MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(boolean(), "[false, false, true, true]")}),
+ DictScalarFromJSON(type, "0", dict1),
+ DictScalarFromJSON(type, "0", dict2),
+ },
+ DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1));
// If we can't map values from a dictionary, then raise an error
+ // Unmappable value is in the else clause
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Cannot map dictionary index 1 at position 1 to the common dictionary"),
+ CallFunction(
+ "case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]")}),
+ values1, values2}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Cannot map dictionary index 1 to the common dictionary"),
+ CallFunction(
+ "case_when",
+ {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]")}),
+ values1, DictScalarFromJSON(type, "1", dict2)}));
+ // Unmappable value is in a branch (test multiple times to ensure coverage of branches
+ // in impl)
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Cannot map dictionary index 1 at position 1 to the common dictionary"),
+ CallFunction("case_when",
+ {MakeStruct({Datum(false),
+ ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+ values1, values2}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Cannot map dictionary index 1 at position 1 to the common dictionary"),
+ CallFunction("case_when",
+ {MakeStruct({Datum(false),
+ ArrayFromJSON(boolean(), "[false, true, false, false]")}),
+ values1, values2}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "Cannot map dictionary index 1 at position 1 to the common dictionary"),
+ CallFunction("case_when",
+ {MakeStruct({Datum(false),
+ ArrayFromJSON(boolean(), "[null, true, null, null]")}),
+ values1, values2}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("Cannot map dictionary index 1 to the common dictionary"),
+ CallFunction("case_when",
+ {MakeStruct({Datum(false),
+ ArrayFromJSON(boolean(), "[true, true, true, null]")}),
+ values1, DictScalarFromJSON(type, "1", dict2)}));
+
+ // ...or optionally, emit null
- // ...or optionally, raise an error
+ // TODO: this is not implemented yet
}
TEST(TestCaseWhen, Null) {
diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc
index 4a9215101b1..9a779e49163 100644
--- a/cpp/src/arrow/compute/kernels/test_util.cc
+++ b/cpp/src/arrow/compute/kernels/test_util.cc
@@ -46,13 +46,6 @@ DatumVector GetDatums(const std::vector& inputs) {
return datums;
}
-void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
- const Datum& expected, const FunctionOptions* options) {
- ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options));
- ValidateOutput(out);
- AssertDatumsEqual(expected, out, /*verbose=*/true);
-}
-
template <typename... SliceArgs>
DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) {
DatumVector sliced;
@@ -80,6 +73,13 @@ ScalarVector GetScalars(const DatumVector& inputs, int64_t index) {
} // namespace
+void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
+ const Datum& expected, const FunctionOptions* options) {
+ ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options));
+ ValidateOutput(out);
+ AssertDatumsEqual(expected, out, /*verbose=*/true);
+}
+
void CheckScalar(std::string func_name, const ScalarVector& inputs,
std::shared_ptr expected, const FunctionOptions* options) {
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h
index 79745b05552..0931f3c77bc 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -67,6 +67,8 @@ inline std::string CompareOperatorToFunctionName(CompareOperator op) {
return function_names[op];
}
+// Call the function with the given arguments, as well as slices of
+// the arguments and scalars extracted from the arguments.
void CheckScalar(std::string func_name, const ScalarVector& inputs,
std::shared_ptr expected,
const FunctionOptions* options = nullptr);
@@ -74,6 +76,11 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs,
void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected,
const FunctionOptions* options = nullptr);
+// Just call the function with the given arguments.
+void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
+ const Datum& expected,
+ const FunctionOptions* options = nullptr);
+
void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty,
std::string json_input, std::shared_ptr out_ty,
std::string json_expected,
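[Editor's note] End-to-end, the remapping machinery added in this patch lets arguments carry different dictionaries: later arguments are transposed onto the first argument's dictionary, and values that cannot be represented there raise Invalid. A small illustrative sketch (ExampleDifferentDictionaries is not part of the patch):

    #include "arrow/compute/api.h"
    #include "arrow/result.h"
    #include "arrow/testing/gtest_util.h"

    arrow::Result<arrow::Datum> ExampleDifferentDictionaries() {
      auto ty = arrow::dictionary(arrow::int32(), arrow::utf8());
      // Same values, but the dictionaries list them in a different order.
      auto values1 = arrow::DictArrayFromJSON(ty, "[0, 1]", R"(["a", "bc"])");
      auto values2 = arrow::DictArrayFromJSON(ty, "[1, 0]", R"(["bc", "a"])");
      ARROW_ASSIGN_OR_RAISE(
          arrow::Datum conds,
          arrow::compute::CallFunction(
              "make_struct", {arrow::ArrayFromJSON(arrow::boolean(), "[true, false]")}));
      // Row 0 takes values1[0] = "a"; row 1 falls through to values2[1] = "a", whose
      // index is transposed onto values1's dictionary, so the result is encoded as
      // indices [0, 0] against ["a", "bc"]. A value missing from values1's dictionary
      // (e.g. "foo") would instead produce Status::Invalid.
      return arrow::compute::CallFunction("case_when", {conds, values1, values2});
    }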
From 17230ee41b6c8080dff25f4a1e125a4120fd4ed0 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 30 Aug 2021 15:15:25 -0400
Subject: [PATCH 05/21] ARROW-13573: [C++] Handle nested dictionaries
---
cpp/src/arrow/array/array_test.cc | 12 +-
cpp/src/arrow/array/builder_base.cc | 10 +-
cpp/src/arrow/array/builder_base.h | 13 +-
cpp/src/arrow/array/builder_dict.cc | 38 ++-
cpp/src/arrow/array/builder_dict.h | 167 +++++++++++
cpp/src/arrow/builder.cc | 269 ++++++++++--------
.../arrow/compute/kernels/scalar_if_else.cc | 2 +-
.../compute/kernels/scalar_if_else_test.cc | 59 ++++
8 files changed, 410 insertions(+), 160 deletions(-)
diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc
index d9617c4e603..2e3d4057094 100644
--- a/cpp/src/arrow/array/array_test.cc
+++ b/cpp/src/arrow/array/array_test.cc
@@ -456,7 +456,7 @@ TEST_F(TestArray, TestValidateNullCount) {
void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr<Scalar>& scalar) {
std::unique_ptr<ArrayBuilder> builder;
auto null_scalar = MakeNullScalar(scalar->type);
- ASSERT_OK(MakeBuilder(pool, scalar->type, &builder));
+ ASSERT_OK(MakeBuilderExactIndex(pool, scalar->type, &builder));
ASSERT_OK(builder->AppendScalar(*scalar));
ASSERT_OK(builder->AppendScalar(*scalar));
ASSERT_OK(builder->AppendScalar(*null_scalar));
@@ -471,15 +471,18 @@ void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr<Scalar>& scalar)
ASSERT_EQ(out->length(), 9);
const bool can_check_nulls = internal::HasValidityBitmap(out->type()->id());
+ // For a dictionary builder, the output dictionary won't necessarily be the same
+ const bool can_check_values = !is_dictionary(out->type()->id());
if (can_check_nulls) {
ASSERT_EQ(out->null_count(), 4);
}
+
for (const auto index : {0, 1, 3, 5, 6}) {
ASSERT_FALSE(out->IsNull(index));
ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index));
ASSERT_OK(scalar_i->ValidateFull());
- AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true);
+ if (can_check_values) AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true);
}
for (const auto index : {2, 4, 7, 8}) {
ASSERT_EQ(out->IsNull(index), can_check_nulls);
@@ -575,8 +578,6 @@ TEST_F(TestArray, TestMakeArrayFromScalar) {
}
for (auto scalar : scalars) {
- // TODO(ARROW-13197): appending dictionary scalars not implemented
- if (is_dictionary(scalar->type->id())) continue;
AssertAppendScalar(pool_, scalar);
}
}
@@ -634,9 +635,6 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) {
TEST_F(TestArray, TestAppendArraySlice) {
auto scalars = GetScalars();
for (const auto& scalar : scalars) {
- // TODO(ARROW-13573): appending dictionary arrays not implemented
- if (is_dictionary(scalar->type->id())) continue;
-
ARROW_SCOPED_TRACE(*scalar->type);
ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16));
ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16));
diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc
index 2f4e63b546d..117b9d37632 100644
--- a/cpp/src/arrow/array/builder_base.cc
+++ b/cpp/src/arrow/array/builder_base.cc
@@ -22,6 +22,7 @@
#include
#include "arrow/array/array_base.h"
+#include "arrow/array/builder_dict.h"
#include "arrow/array/data.h"
#include "arrow/array/util.h"
#include "arrow/buffer.h"
@@ -268,15 +269,6 @@ struct AppendScalarImpl {
} // namespace
-Status ArrayBuilder::AppendScalar(const Scalar& scalar) {
- if (!scalar.type->Equals(type())) {
- return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
- " to builder for type ", type()->ToString());
- }
- std::shared_ptr<Scalar> shared{const_cast<Scalar*>(&scalar), [](Scalar*) {}};
- return AppendScalarImpl{&shared, &shared + 1, /*n_repeats=*/1, this}.Convert();
-}
-
Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) {
if (!scalar.type->Equals(type())) {
return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h
index 67203e79071..87e39c3fe9f 100644
--- a/cpp/src/arrow/array/builder_base.h
+++ b/cpp/src/arrow/array/builder_base.h
@@ -119,9 +119,9 @@ class ARROW_EXPORT ArrayBuilder {
virtual Status AppendEmptyValues(int64_t length) = 0;
/// \brief Append a value from a scalar
- Status AppendScalar(const Scalar& scalar);
- Status AppendScalar(const Scalar& scalar, int64_t n_repeats);
- Status AppendScalars(const ScalarVector& scalars);
+ Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); }
+ virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats);
+ virtual Status AppendScalars(const ScalarVector& scalars);
/// \brief Append a range of values from an array.
///
@@ -282,6 +282,13 @@ ARROW_EXPORT
Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
std::unique_ptr<ArrayBuilder>* out);
+/// \brief Construct an empty ArrayBuilder corresponding to the data
+/// type, where any top-level or nested dictionary builders return the
+/// exact index type specified by the type.
+ARROW_EXPORT
+Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+ std::unique_ptr<ArrayBuilder>* out);
+
/// \brief Construct an empty DictionaryBuilder initialized optionally
/// with a pre-existing dictionary
/// \param[in] pool the MemoryPool to use for allocations
diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc
index b13f6a2db34..b554c1d7099 100644
--- a/cpp/src/arrow/array/builder_dict.cc
+++ b/cpp/src/arrow/array/builder_dict.cc
@@ -159,23 +159,31 @@ DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool,
DictionaryMemoTable::~DictionaryMemoTable() = default;
-#define GET_OR_INSERT(C_TYPE) \
- Status DictionaryMemoTable::GetOrInsert( \
- const typename CTypeTraits<C_TYPE>::ArrowType*, C_TYPE value, int32_t* out) { \
- return impl_->GetOrInsert<typename CTypeTraits<C_TYPE>::ArrowType>(value, out); \
+#define GET_OR_INSERT(ARROW_TYPE) \
+ Status DictionaryMemoTable::GetOrInsert( \
+ const ARROW_TYPE*, typename ARROW_TYPE::c_type value, int32_t* out) { \
+ return impl_->GetOrInsert<ARROW_TYPE>(value, out); \
}
-GET_OR_INSERT(bool)
-GET_OR_INSERT(int8_t)
-GET_OR_INSERT(int16_t)
-GET_OR_INSERT(int32_t)
-GET_OR_INSERT(int64_t)
-GET_OR_INSERT(uint8_t)
-GET_OR_INSERT(uint16_t)
-GET_OR_INSERT(uint32_t)
-GET_OR_INSERT(uint64_t)
-GET_OR_INSERT(float)
-GET_OR_INSERT(double)
+GET_OR_INSERT(BooleanType)
+GET_OR_INSERT(Int8Type)
+GET_OR_INSERT(Int16Type)
+GET_OR_INSERT(Int32Type)
+GET_OR_INSERT(Int64Type)
+GET_OR_INSERT(UInt8Type)
+GET_OR_INSERT(UInt16Type)
+GET_OR_INSERT(UInt32Type)
+GET_OR_INSERT(UInt64Type)
+GET_OR_INSERT(FloatType)
+GET_OR_INSERT(DoubleType)
+GET_OR_INSERT(DurationType);
+GET_OR_INSERT(TimestampType);
+GET_OR_INSERT(Date32Type);
+GET_OR_INSERT(Date64Type);
+GET_OR_INSERT(Time32Type);
+GET_OR_INSERT(Time64Type);
+GET_OR_INSERT(DayTimeIntervalType);
+GET_OR_INSERT(MonthIntervalType);
#undef GET_OR_INSERT
diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h
index 455cb3df7b1..1b97f98f290 100644
--- a/cpp/src/arrow/array/builder_dict.h
+++ b/cpp/src/arrow/array/builder_dict.h
@@ -37,6 +37,7 @@
#include "arrow/util/decimal.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
+#include "arrow/visitor_inline.h"
namespace arrow {
@@ -97,6 +98,15 @@ class ARROW_EXPORT DictionaryMemoTable {
Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out);
Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out);
Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out);
+ Status GetOrInsert(const DurationType*, int64_t value, int32_t* out);
+ Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out);
+ Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out);
+ Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out);
+ Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out);
+ Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out);
+ Status GetOrInsert(const DayTimeIntervalType*,
+ DayTimeIntervalType::DayMilliseconds value, int32_t* out);
+ Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out);
Status GetOrInsert(const FloatType*, float value, int32_t* out);
Status GetOrInsert(const DoubleType*, double value, int32_t* out);
@@ -282,6 +292,163 @@ class DictionaryBuilderBase : public ArrayBuilder {
return indices_builder_.AppendEmptyValues(length);
}
+ Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override {
+ if (!scalar.type->Equals(type())) {
+ return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
+ " to builder for type ", type()->ToString());
+ }
+ if (!scalar.is_valid) return AppendNulls(n_repeats);
+
+ const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*scalar.type);
+ const DictionaryScalar& dict_scalar =
+ internal::checked_cast<const DictionaryScalar&>(scalar);
+ const auto& dict = internal::checked_cast<const typename TypeTraits<T>::ArrayType&>(
+ *dict_scalar.value.dictionary);
+ switch (dict_ty.index_type()->id()) {
+ case Type::UINT8: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const UInt8Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ case Type::INT8: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const Int8Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ case Type::UINT16: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const UInt16Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ case Type::INT16: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const Int16Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ case Type::UINT32: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const UInt32Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ case Type::INT32: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const Int32Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ case Type::UINT64: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const UInt64Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ case Type::INT64: {
+ const auto& value = dict.GetView(
+ internal::checked_cast<const Int64Scalar&>(*dict_scalar.value.index).value);
+ for (int64_t i = 0; i < n_repeats; i++) {
+ ARROW_RETURN_NOT_OK(Append(value));
+ }
+ break;
+ }
+ default:
+ return Status::TypeError("Invalid index type: ", dict_ty);
+ }
+ return Status::OK();
+ }
+
+ Status AppendScalars(const ScalarVector& scalars) override {
+ for (const auto& scalar : scalars) {
+ ARROW_RETURN_NOT_OK(AppendScalar(*scalar, /*n_repeats=*/1));
+ }
+ return Status::OK();
+ }
+
+ Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final {
+ // Visit the indices and insert the unpacked values.
+ const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*array.type);
+ const typename TypeTraits<T>::ArrayType dict(array.dictionary);
+ switch (dict_ty.index_type()->id()) {
+ case Type::UINT8: {
+ const uint8_t* values = array.GetValues<uint8_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ case Type::INT8: {
+ const int8_t* values = array.GetValues<int8_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ case Type::UINT16: {
+ const uint16_t* values = array.GetValues<uint16_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ case Type::INT16: {
+ const int16_t* values = array.GetValues<int16_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ case Type::UINT32: {
+ const uint32_t* values = array.GetValues<uint32_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ case Type::INT32: {
+ const int32_t* values = array.GetValues<int32_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ case Type::UINT64: {
+ const uint64_t* values = array.GetValues<uint64_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ case Type::INT64: {
+ const int64_t* values = array.GetValues<int64_t>(1) + offset;
+ return VisitBitBlocks(
+ array.buffers[0], array.offset + offset, std::min(array.length, length),
+ [&](int64_t position) { return Append(dict.GetView(values[position])); },
+ [&]() { return AppendNull(); });
+ }
+ default:
+ return Status::TypeError("Invalid index type: ", dict_ty);
+ }
+ return Status::OK();
+ }
+
/// \brief Insert values into the dictionary's memo, but do not append any
/// indices. Can be used to initialize a new builder with known dictionary
/// values
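[Editor's sketch, not part of the patch: a minimal example of exercising the new
DictionaryBuilderBase::AppendScalar override above through the generic ArrayBuilder
interface. ArrayFromJSON is the test-only helper from arrow/testing/gtest_util.h; the
function name AppendDictScalarExample is illustrative.]

  #include <memory>
  #include "arrow/api.h"
  #include "arrow/testing/gtest_util.h"  // ArrayFromJSON (test-only helper)

  arrow::Status AppendDictScalarExample() {
    // Create a builder for dictionary<int32, utf8> via the generic factory.
    auto type = arrow::dictionary(arrow::int32(), arrow::utf8());
    std::unique_ptr<arrow::ArrayBuilder> builder;
    ARROW_RETURN_NOT_OK(arrow::MakeBuilder(arrow::default_memory_pool(), type, &builder));

    // Build a dictionary scalar referring to entry 1 ("tango") and append it three
    // times; AppendScalar unpacks the value and memoizes it in the builder's own
    // dictionary rather than reusing the scalar's index.
    auto dict = arrow::ArrayFromJSON(arrow::utf8(), R"(["whiskey", "tango", "foxtrot"])");
    ARROW_ASSIGN_OR_RAISE(auto index, arrow::MakeScalar(arrow::int32(), 1));
    auto scalar = arrow::DictionaryScalar::Make(std::move(index), dict);
    ARROW_RETURN_NOT_OK(builder->AppendScalar(*scalar, /*n_repeats=*/3));

    std::shared_ptr<arrow::Array> out;
    return builder->Finish(&out);
  }
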
diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc
index 37cc9e07ad4..115a97e9389 100644
--- a/cpp/src/arrow/builder.cc
+++ b/cpp/src/arrow/builder.cc
@@ -41,14 +41,10 @@ struct DictionaryBuilderCase {
}
Status Visit(const NullType&) { return CreateFor<NullType>(); }
- Status Visit(const BinaryType&) { return Create<BinaryDictionaryBuilder>(); }
- Status Visit(const StringType&) { return Create<StringDictionaryBuilder>(); }
- Status Visit(const LargeBinaryType&) {
- return Create<DictionaryBuilder<LargeBinaryType>>();
- }
- Status Visit(const LargeStringType&) {
- return Create<DictionaryBuilder<LargeStringType>>();
- }
+ Status Visit(const BinaryType&) { return CreateFor<BinaryType>(); }
+ Status Visit(const StringType&) { return CreateFor<StringType>(); }
+ Status Visit(const LargeBinaryType&) { return CreateFor<LargeBinaryType>(); }
+ Status Visit(const LargeStringType&) { return CreateFor<LargeStringType>(); }
Status Visit(const FixedSizeBinaryType&) { return CreateFor<FixedSizeBinaryType>(); }
Status Visit(const Decimal128Type&) { return CreateFor<Decimal128Type>(); }
Status Visit(const Decimal256Type&) { return CreateFor<Decimal256Type>(); }
@@ -63,19 +59,50 @@ struct DictionaryBuilderCase {
template <typename ValueType>
Status CreateFor() {
- return Create<DictionaryBuilder<ValueType>>();
- }
-
- template <typename BuilderType>
- Status Create() {
- BuilderType* builder;
+ using AdaptiveBuilderType = DictionaryBuilder<ValueType>;
if (dictionary != nullptr) {
- builder = new BuilderType(dictionary, pool);
+ out->reset(new AdaptiveBuilderType(dictionary, pool));
+ } else if (exact_index_type) {
+ switch (index_type->id()) {
+ case Type::UINT8:
+ out->reset(new internal::DictionaryBuilderBase<UInt8Builder, ValueType>(
+ value_type, pool));
+ break;
+ case Type::INT8:
+ out->reset(new internal::DictionaryBuilderBase<Int8Builder, ValueType>(
+ value_type, pool));
+ break;
+ case Type::UINT16:
+ out->reset(new internal::DictionaryBuilderBase<UInt16Builder, ValueType>(
+ value_type, pool));
+ break;
+ case Type::INT16:
+ out->reset(new internal::DictionaryBuilderBase<Int16Builder, ValueType>(
+ value_type, pool));
+ break;
+ case Type::UINT32:
+ out->reset(new internal::DictionaryBuilderBase<UInt32Builder, ValueType>(
+ value_type, pool));
+ break;
+ case Type::INT32:
+ out->reset(new internal::DictionaryBuilderBase<Int32Builder, ValueType>(
+ value_type, pool));
+ break;
+ case Type::UINT64:
+ out->reset(new internal::DictionaryBuilderBase<UInt64Builder, ValueType>(
+ value_type, pool));
+ break;
+ case Type::INT64:
+ out->reset(new internal::DictionaryBuilderBase<Int64Builder, ValueType>(
+ value_type, pool));
+ break;
+ default:
+ return Status::TypeError("MakeBuilder: invalid index type ", *index_type);
+ }
} else {
auto start_int_size = internal::GetByteWidth(*index_type);
- builder = new BuilderType(start_int_size, value_type, pool);
+ out->reset(new AdaptiveBuilderType(start_int_size, value_type, pool));
}
- out->reset(builder);
return Status::OK();
}
@@ -85,138 +112,130 @@ struct DictionaryBuilderCase {
const std::shared_ptr<DataType>& index_type;
const std::shared_ptr<DataType>& value_type;
const std::shared_ptr<Array>& dictionary;
+ bool exact_index_type;
std::unique_ptr<ArrayBuilder>* out;
};
-#define BUILDER_CASE(TYPE_CLASS) \
- case TYPE_CLASS##Type::type_id: \
- out->reset(new TYPE_CLASS##Builder(type, pool)); \
+struct MakeBuilderImpl {
+ template <typename T>
+ enable_if_not_nested<T, Status> Visit(const T&) {
+ out.reset(new typename TypeTraits<T>::BuilderType(type, pool));
return Status::OK();
+ }
-Result<std::vector<std::shared_ptr<ArrayBuilder>>> FieldBuilders(const DataType& type,
- MemoryPool* pool) {
- std::vector<std::shared_ptr<ArrayBuilder>> field_builders;
+ Status Visit(const DictionaryType& dict_type) {
+ DictionaryBuilderCase visitor = {pool,
+ dict_type.index_type(),
+ dict_type.value_type(),
+ /*dictionary=*/nullptr,
+ exact_index_type,
+ &out};
+ return visitor.Make();
+ }
- for (const auto& field : type.fields()) {
- std::unique_ptr<ArrayBuilder> builder;
- RETURN_NOT_OK(MakeBuilder(pool, field->type(), &builder));
- field_builders.emplace_back(std::move(builder));
+ Status Visit(const ListType& list_type) {
+ std::shared_ptr<DataType> value_type = list_type.value_type();
+ ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+ out.reset(new ListBuilder(pool, std::move(value_builder), type));
+ return Status::OK();
}
- return field_builders;
-}
+ Status Visit(const LargeListType& list_type) {
+ std::shared_ptr<DataType> value_type = list_type.value_type();
+ ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+ out.reset(new LargeListBuilder(pool, std::move(value_builder), type));
+ return Status::OK();
+ }
-Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
- std::unique_ptr<ArrayBuilder>* out) {
- switch (type->id()) {
- case Type::NA: {
- out->reset(new NullBuilder(pool));
- return Status::OK();
- }
- BUILDER_CASE(UInt8);
- BUILDER_CASE(Int8);
- BUILDER_CASE(UInt16);
- BUILDER_CASE(Int16);
- BUILDER_CASE(UInt32);
- BUILDER_CASE(Int32);
- BUILDER_CASE(UInt64);
- BUILDER_CASE(Int64);
- BUILDER_CASE(Date32);
- BUILDER_CASE(Date64);
- BUILDER_CASE(Duration);
- BUILDER_CASE(Time32);
- BUILDER_CASE(Time64);
- BUILDER_CASE(Timestamp);
- BUILDER_CASE(MonthInterval);
- BUILDER_CASE(DayTimeInterval);
- BUILDER_CASE(MonthDayNanoInterval);
- BUILDER_CASE(Boolean);
- BUILDER_CASE(HalfFloat);
- BUILDER_CASE(Float);
- BUILDER_CASE(Double);
- BUILDER_CASE(String);
- BUILDER_CASE(Binary);
- BUILDER_CASE(LargeString);
- BUILDER_CASE(LargeBinary);
- BUILDER_CASE(FixedSizeBinary);
- BUILDER_CASE(Decimal128);
- BUILDER_CASE(Decimal256);
-
- case Type::DICTIONARY: {
- const auto& dict_type = static_cast<const DictionaryType&>(*type);
- DictionaryBuilderCase visitor = {pool, dict_type.index_type(),
- dict_type.value_type(), nullptr, out};
- return visitor.Make();
- }
+ Status Visit(const MapType& map_type) {
+ ARROW_ASSIGN_OR_RAISE(auto key_builder, ChildBuilder(map_type.key_type()));
+ ARROW_ASSIGN_OR_RAISE(auto item_builder, ChildBuilder(map_type.item_type()));
+ out.reset(
+ new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type));
+ return Status::OK();
+ }
- case Type::LIST: {
- std::unique_ptr<ArrayBuilder> value_builder;
- std::shared_ptr<DataType> value_type =
- internal::checked_cast<const ListType&>(*type).value_type();
- RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder));
- out->reset(new ListBuilder(pool, std::move(value_builder), type));
- return Status::OK();
- }
+ Status Visit(const FixedSizeListType& list_type) {
+ auto value_type = list_type.value_type();
+ ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+ out.reset(new FixedSizeListBuilder(pool, std::move(value_builder), type));
+ return Status::OK();
+ }
- case Type::LARGE_LIST: {
- std::unique_ptr<ArrayBuilder> value_builder;
- std::shared_ptr<DataType> value_type =
- internal::checked_cast<const LargeListType&>(*type).value_type();
- RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder));
- out->reset(new LargeListBuilder(pool, std::move(value_builder), type));
- return Status::OK();
- }
+ Status Visit(const StructType& struct_type) {
+ ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+ out.reset(new StructBuilder(type, pool, std::move(field_builders)));
+ return Status::OK();
+ }
- case Type::MAP: {
- const auto& map_type = internal::checked_cast<const MapType&>(*type);
- std::unique_ptr<ArrayBuilder> key_builder, item_builder;
- RETURN_NOT_OK(MakeBuilder(pool, map_type.key_type(), &key_builder));
- RETURN_NOT_OK(MakeBuilder(pool, map_type.item_type(), &item_builder));
- out->reset(
- new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type));
- return Status::OK();
- }
+ Status Visit(const SparseUnionType&) {
+ ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+ out.reset(new SparseUnionBuilder(pool, std::move(field_builders), type));
+ return Status::OK();
+ }
- case Type::FIXED_SIZE_LIST: {
- const auto& list_type = internal::checked_cast<const FixedSizeListType&>(*type);
- std::unique_ptr<ArrayBuilder> value_builder;
- auto value_type = list_type.value_type();
- RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder));
- out->reset(new FixedSizeListBuilder(pool, std::move(value_builder), type));
- return Status::OK();
- }
+ Status Visit(const DenseUnionType&) {
+ ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+ out.reset(new DenseUnionBuilder(pool, std::move(field_builders), type));
+ return Status::OK();
+ }
- case Type::STRUCT: {
- ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
- out->reset(new StructBuilder(type, pool, std::move(field_builders)));
- return Status::OK();
- }
+ Status Visit(const ExtensionType&) { return NotImplemented(); }
+ Status Visit(const DataType&) { return NotImplemented(); }
- case Type::SPARSE_UNION: {
- ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
- out->reset(new SparseUnionBuilder(pool, std::move(field_builders), type));
- return Status::OK();
- }
+ Status NotImplemented() {
+ return Status::NotImplemented("MakeBuilder: cannot construct builder for type ",
+ type->ToString());
+ }
- case Type::DENSE_UNION: {
- ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
- out->reset(new DenseUnionBuilder(pool, std::move(field_builders), type));
- return Status::OK();
- }
+ Result<std::unique_ptr<ArrayBuilder>> ChildBuilder(
+ const std::shared_ptr<DataType>& type) {
+ MakeBuilderImpl impl{pool, type, exact_index_type, /*out=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+ return std::move(impl.out);
+ }
- default:
- break;
+ Result<std::vector<std::shared_ptr<ArrayBuilder>>> FieldBuilders(const DataType& type,
+ MemoryPool* pool) {
+ std::vector<std::shared_ptr<ArrayBuilder>> field_builders;
+ for (const auto& field : type.fields()) {
+ std::unique_ptr<ArrayBuilder> builder;
+ MakeBuilderImpl impl{pool, field->type(), exact_index_type, /*out=*/nullptr};
+ RETURN_NOT_OK(VisitTypeInline(*field->type(), &impl));
+ field_builders.emplace_back(std::move(impl.out));
+ }
+ return field_builders;
}
- return Status::NotImplemented("MakeBuilder: cannot construct builder for type ",
- type->ToString());
+
+ MemoryPool* pool;
+ const std::shared_ptr<DataType>& type;
+ bool exact_index_type;
+ std::unique_ptr<ArrayBuilder> out;
+};
+
+Status MakeBuilder(MemoryPool* pool, const std::shared_ptr