From 9695c78ce0a1bff9bbdd083d97d06ed3b6ff0680 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Mon, 20 Jun 2022 19:45:52 -0500 Subject: [PATCH 1/9] Various porting toward removing ValueDescr from kernel APIs More refactoring More refactoring More refactoring Checkpoint, still some ValueDescr to remove Clean up and delete some more scalar output code More refactoring More refactoring More refactoring, code cleaning more cleaning More cleaning checkpoint Get everything compiling again More refactoring exec.cc refactoring, get compiling again Handle scalar -> array span promotions Work to make all scalars more well formed Make union scalars more 'well formed' All things compiling again Fix some more stuff Fix some more tedious errors Fix some more things Fixed MapBuilder::AppendArraySlice bug checkpoint Fix more bugs C++ tests passing again, restore cumulative_sum scalar tests --- .../arrow/compute_register_example.cc | 3 +- cpp/examples/arrow/udf_example.cc | 6 +- cpp/gdb_arrow.py | 7 +- cpp/src/arrow/array/array_base.cc | 23 +- cpp/src/arrow/array/array_test.cc | 12 +- cpp/src/arrow/array/builder_base.cc | 95 ++-- cpp/src/arrow/array/builder_nested.h | 6 +- cpp/src/arrow/array/data.cc | 221 ++++++-- cpp/src/arrow/array/data.h | 21 +- cpp/src/arrow/array/util.cc | 16 +- cpp/src/arrow/compare.cc | 11 +- cpp/src/arrow/compute/api_vector.cc | 4 +- cpp/src/arrow/compute/cast.cc | 68 +-- cpp/src/arrow/compute/cast.h | 55 +- cpp/src/arrow/compute/cast_internal.h | 29 ++ cpp/src/arrow/compute/exec.cc | 337 ++++++------ cpp/src/arrow/compute/exec.h | 87 +--- cpp/src/arrow/compute/exec/aggregate.cc | 60 +-- cpp/src/arrow/compute/exec/aggregate.h | 6 +- cpp/src/arrow/compute/exec/aggregate_node.cc | 36 +- cpp/src/arrow/compute/exec/expression.cc | 119 ++--- cpp/src/arrow/compute/exec/expression.h | 15 +- .../arrow/compute/exec/expression_internal.h | 21 +- cpp/src/arrow/compute/exec/expression_test.cc | 26 +- cpp/src/arrow/compute/exec/hash_join.cc | 6 +- cpp/src/arrow/compute/exec/hash_join_dict.cc | 18 +- .../arrow/compute/exec/hash_join_node_test.cc | 14 +- cpp/src/arrow/compute/exec/plan_test.cc | 29 +- cpp/src/arrow/compute/exec/project_node.cc | 4 +- cpp/src/arrow/compute/exec/test_util.cc | 23 +- cpp/src/arrow/compute/exec/test_util.h | 8 +- cpp/src/arrow/compute/exec_internal.h | 15 +- cpp/src/arrow/compute/exec_test.cc | 48 +- cpp/src/arrow/compute/function.cc | 155 +++--- cpp/src/arrow/compute/function.h | 24 +- cpp/src/arrow/compute/function_benchmark.cc | 49 +- cpp/src/arrow/compute/function_internal.h | 10 + cpp/src/arrow/compute/function_test.cc | 22 +- cpp/src/arrow/compute/kernel.cc | 90 ++-- cpp/src/arrow/compute/kernel.h | 153 ++---- cpp/src/arrow/compute/kernel_test.cc | 284 ++++------- .../arrow/compute/kernels/aggregate_basic.cc | 84 ++- .../compute/kernels/aggregate_basic_avx2.cc | 8 +- .../compute/kernels/aggregate_basic_avx512.cc | 8 +- .../kernels/aggregate_basic_internal.h | 11 +- .../arrow/compute/kernels/aggregate_mode.cc | 12 +- .../compute/kernels/aggregate_quantile.cc | 10 +- .../compute/kernels/aggregate_tdigest.cc | 9 +- .../arrow/compute/kernels/codegen_internal.cc | 177 ++++--- .../arrow/compute/kernels/codegen_internal.h | 139 ++--- .../compute/kernels/codegen_internal_test.cc | 139 ++--- .../arrow/compute/kernels/hash_aggregate.cc | 142 +++--- .../compute/kernels/hash_aggregate_test.cc | 96 ++-- cpp/src/arrow/compute/kernels/row_encoder.cc | 25 +- cpp/src/arrow/compute/kernels/row_encoder.h | 2 +- .../compute/kernels/scalar_arithmetic.cc | 264 +++++----- 
.../arrow/compute/kernels/scalar_boolean.cc | 89 +--- .../compute/kernels/scalar_cast_dictionary.cc | 34 -- .../compute/kernels/scalar_cast_internal.cc | 142 ++---- .../compute/kernels/scalar_cast_internal.h | 17 +- .../compute/kernels/scalar_cast_nested.cc | 35 -- .../compute/kernels/scalar_cast_numeric.cc | 58 +-- .../compute/kernels/scalar_cast_string.cc | 67 +-- .../compute/kernels/scalar_cast_temporal.cc | 3 +- .../arrow/compute/kernels/scalar_cast_test.cc | 12 +- .../arrow/compute/kernels/scalar_compare.cc | 149 ++---- .../arrow/compute/kernels/scalar_if_else.cc | 338 +++++------- .../compute/kernels/scalar_if_else_test.cc | 66 ++- .../arrow/compute/kernels/scalar_nested.cc | 374 ++++---------- .../compute/kernels/scalar_nested_test.cc | 12 +- .../arrow/compute/kernels/scalar_random.cc | 4 +- .../compute/kernels/scalar_set_lookup.cc | 25 +- .../compute/kernels/scalar_string_ascii.cc | 479 +++++------------- .../compute/kernels/scalar_string_internal.h | 111 +--- .../compute/kernels/scalar_string_utf8.cc | 24 +- .../compute/kernels/scalar_temporal_unary.cc | 104 +--- .../arrow/compute/kernels/scalar_validity.cc | 212 +++----- cpp/src/arrow/compute/kernels/test_util.cc | 18 +- cpp/src/arrow/compute/kernels/test_util.h | 10 +- .../arrow/compute/kernels/util_internal.cc | 64 --- cpp/src/arrow/compute/kernels/util_internal.h | 11 - .../compute/kernels/vector_array_sort.cc | 17 +- .../compute/kernels/vector_cumulative_ops.cc | 20 +- .../kernels/vector_cumulative_ops_test.cc | 48 +- cpp/src/arrow/compute/kernels/vector_hash.cc | 37 +- .../arrow/compute/kernels/vector_nested.cc | 9 +- .../arrow/compute/kernels/vector_replace.cc | 21 +- .../arrow/compute/kernels/vector_selection.cc | 118 ++--- cpp/src/arrow/compute/row/grouper.cc | 55 +- cpp/src/arrow/compute/row/grouper.h | 2 +- cpp/src/arrow/compute/type_fwd.h | 2 +- cpp/src/arrow/dataset/partition.cc | 2 +- cpp/src/arrow/dataset/scanner.cc | 4 +- cpp/src/arrow/datum.cc | 73 --- cpp/src/arrow/datum.h | 68 --- cpp/src/arrow/datum_test.cc | 26 - cpp/src/arrow/ipc/json_simple.cc | 3 +- cpp/src/arrow/python/gdb.cc | 36 +- cpp/src/arrow/python/udf.cc | 57 +-- cpp/src/arrow/scalar.cc | 328 +++++++----- cpp/src/arrow/scalar.h | 88 ++-- cpp/src/arrow/scalar_test.cc | 180 +++++-- cpp/src/arrow/type.cc | 25 +- cpp/src/arrow/type.h | 61 ++- cpp/src/arrow/type_fwd.h | 40 +- cpp/src/arrow/type_traits.h | 13 + 106 files changed, 3008 insertions(+), 4145 deletions(-) diff --git a/cpp/examples/arrow/compute_register_example.cc b/cpp/examples/arrow/compute_register_example.cc index 13d80b29631..113dfd0faf3 100644 --- a/cpp/examples/arrow/compute_register_example.cc +++ b/cpp/examples/arrow/compute_register_example.cc @@ -127,8 +127,7 @@ const cp::FunctionDoc func_doc{ int main(int argc, char** argv) { const std::string name = "compute_register_example"; auto func = std::make_shared(name, cp::Arity::Unary(), func_doc); - cp::ScalarKernel kernel({cp::InputType::Array(arrow::int64())}, arrow::int64(), - ExampleFunctionImpl); + cp::ScalarKernel kernel({arrow::int64()}, arrow::int64(), ExampleFunctionImpl); kernel.mem_allocation = cp::MemAllocation::NO_PREALLOCATE; ABORT_ON_FAILURE(func->AddKernel(std::move(kernel))); diff --git a/cpp/examples/arrow/udf_example.cc b/cpp/examples/arrow/udf_example.cc index 47c45411477..ccd804339a2 100644 --- a/cpp/examples/arrow/udf_example.cc +++ b/cpp/examples/arrow/udf_example.cc @@ -75,10 +75,8 @@ arrow::Status SampleFunction(cp::KernelContext* ctx, const cp::ExecSpan& batch, arrow::Status Execute() { const std::string name = 
"add_three"; auto func = std::make_shared(name, cp::Arity::Ternary(), func_doc); - cp::ScalarKernel kernel( - {cp::InputType::Array(arrow::int64()), cp::InputType::Array(arrow::int64()), - cp::InputType::Array(arrow::int64())}, - arrow::int64(), SampleFunction); + cp::ScalarKernel kernel({arrow::int64(), arrow::int64(), arrow::int64()}, + arrow::int64(), SampleFunction); kernel.mem_allocation = cp::MemAllocation::PREALLOCATE; kernel.null_handling = cp::NullHandling::INTERSECTION; diff --git a/cpp/gdb_arrow.py b/cpp/gdb_arrow.py index cd687ec8b2e..2237da4cc98 100644 --- a/cpp/gdb_arrow.py +++ b/cpp/gdb_arrow.py @@ -1406,13 +1406,12 @@ class FixedSizeBinaryScalarPrinter(BaseBinaryScalarPrinter): def to_string(self): size = self.type['byte_width_'] - if not self.is_valid: - return f"{self._format_type()} of size {size}, null value" bufptr = BufferPtr(SharedPtr(self.val['value']).get()) if bufptr.data is None: return f"{self._format_type()} of size {size}, " - return (f"{self._format_type()} of size {size}, " - f"value {self._format_buf(bufptr)}") + nullness = 'non-null' if self.is_valid else 'null' + return (f"{self._format_type()} {nullness} of size {size}, " + f"value buffer {self._format_buf(bufptr)}") class DictionaryScalarPrinter(ScalarPrinter): diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index b36fb0fb94a..5d27b2aedfb 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -104,16 +104,15 @@ struct ScalarFromArraySlotImpl { } Status Visit(const SparseUnionArray& a) { - const auto type_code = a.type_code(index_); - // child array which stores the actual value - const auto arr = a.field(a.child_id(index_)); - // no need to adjust the index - ARROW_ASSIGN_OR_RAISE(auto value, arr->GetScalar(index_)); - if (value->is_valid) { - out_ = std::shared_ptr(new SparseUnionScalar(value, type_code, a.type())); - } else { - out_ = std::shared_ptr(new SparseUnionScalar(type_code, a.type())); + int8_t type_code = a.type_code(index_); + + ScalarVector children; + for (int i = 0; i < a.type()->num_fields(); ++i) { + children.emplace_back(); + ARROW_ASSIGN_OR_RAISE(children.back(), a.field(i)->GetScalar(index_)); } + + out_ = std::make_shared(std::move(children), type_code, a.type()); return Status::OK(); } @@ -124,11 +123,7 @@ struct ScalarFromArraySlotImpl { // need to look up the value based on offsets auto offset = a.value_offset(index_); ARROW_ASSIGN_OR_RAISE(auto value, arr->GetScalar(offset)); - if (value->is_valid) { - out_ = std::shared_ptr(new DenseUnionScalar(value, type_code, a.type())); - } else { - out_ = std::shared_ptr(new DenseUnionScalar(type_code, a.type())); - } + out_ = std::make_shared(value, type_code, a.type()); return Status::OK(); } diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 0d9afba6ece..d438557a330 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -561,16 +561,16 @@ static ScalarVector GetScalars() { }, struct_({field("min", int32()), field("max", int32())})), // Same values, different union type codes - std::make_shared(std::make_shared(100), 6, - sparse_union_ty), - std::make_shared(std::make_shared(100), 42, - sparse_union_ty), - std::make_shared(42, sparse_union_ty), + SparseUnionScalar::FromValue(std::make_shared(100), 1, + sparse_union_ty), + SparseUnionScalar::FromValue(std::make_shared(100), 2, + sparse_union_ty), + SparseUnionScalar::FromValue(MakeNullScalar(int32()), 2, sparse_union_ty), 
std::make_shared(std::make_shared(101), 6, dense_union_ty), std::make_shared(std::make_shared(101), 42, dense_union_ty), - std::make_shared(42, dense_union_ty), + std::make_shared(MakeNullScalar(int32()), 42, dense_union_ty), DictionaryScalar::Make(ScalarFromJSON(int8(), "1"), ArrayFromJSON(utf8(), R"(["foo", "bar"])")), DictionaryScalar::Make(ScalarFromJSON(uint8(), "1"), diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index 49abd8e0234..deadef061df 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -34,6 +34,8 @@ namespace arrow { +using internal::checked_cast; + Status ArrayBuilder::CheckArrayType(const std::shared_ptr& expected_type, const Array& array, const char* message) { if (!expected_type->Equals(*array.type())) { @@ -105,14 +107,13 @@ struct AppendScalarImpl { is_fixed_size_binary_type::value, Status> Visit(const T&) { - auto builder = internal::checked_cast::BuilderType*>(builder_); + auto builder = checked_cast::BuilderType*>(builder_); RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_))); for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; raw++) { - auto scalar = - internal::checked_cast::ScalarType*>(raw->get()); + auto scalar = checked_cast::ScalarType*>(raw->get()); if (scalar->is_valid) { builder->UnsafeAppend(scalar->value); } else { @@ -128,22 +129,20 @@ struct AppendScalarImpl { int64_t data_size = 0; for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; raw++) { - auto scalar = - internal::checked_cast::ScalarType*>(raw->get()); + auto scalar = checked_cast::ScalarType*>(raw->get()); if (scalar->is_valid) { data_size += scalar->value->size(); } } - auto builder = internal::checked_cast::BuilderType*>(builder_); + auto builder = checked_cast::BuilderType*>(builder_); RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_))); RETURN_NOT_OK(builder->ReserveData(n_repeats_ * data_size)); for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; raw++) { - auto scalar = - internal::checked_cast::ScalarType*>(raw->get()); + auto scalar = checked_cast::ScalarType*>(raw->get()); if (scalar->is_valid) { builder->UnsafeAppend(util::string_view{*scalar->value}); } else { @@ -156,13 +155,12 @@ struct AppendScalarImpl { template enable_if_list_like Visit(const T&) { - auto builder = internal::checked_cast::BuilderType*>(builder_); + auto builder = checked_cast::BuilderType*>(builder_); int64_t num_children = 0; for (const std::shared_ptr* scalar = scalars_begin_; scalar != scalars_end_; scalar++) { if (!(*scalar)->is_valid) continue; - num_children += - internal::checked_cast(**scalar).value->length(); + num_children += checked_cast(**scalar).value->length(); } RETURN_NOT_OK(builder->value_builder()->Reserve(num_children * n_repeats_)); @@ -171,8 +169,7 @@ struct AppendScalarImpl { scalar++) { if ((*scalar)->is_valid) { RETURN_NOT_OK(builder->Append()); - const Array& list = - *internal::checked_cast(**scalar).value; + const Array& list = *checked_cast(**scalar).value; for (int64_t i = 0; i < list.length(); i++) { ARROW_ASSIGN_OR_RAISE(auto scalar, list.GetScalar(i)); RETURN_NOT_OK(builder->value_builder()->AppendScalar(*scalar)); @@ -186,7 +183,7 @@ struct AppendScalarImpl { } Status Visit(const StructType& type) { - auto* builder = internal::checked_cast(builder_); + auto* builder = checked_cast(builder_); auto count = 
n_repeats_ * (scalars_end_ - scalars_begin_); RETURN_NOT_OK(builder->Reserve(count)); for (int field_index = 0; field_index < type.num_fields(); ++field_index) { @@ -194,7 +191,7 @@ struct AppendScalarImpl { } for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* s = scalars_begin_; s != scalars_end_; s++) { - const auto& scalar = internal::checked_cast(**s); + const auto& scalar = checked_cast(**s); for (int field_index = 0; field_index < type.num_fields(); ++field_index) { if (!scalar.is_valid || !scalar.value[field_index]) { RETURN_NOT_OK(builder->field_builder(field_index)->AppendNull()); @@ -213,12 +210,55 @@ struct AppendScalarImpl { Status Visit(const DenseUnionType& type) { return MakeUnionArray(type); } + template ::BuilderType> + Status AppendUnionScalar(const T& type, const Scalar& s, BuilderType* builder) { + const auto& scalar = checked_cast(s); + const auto scalar_field_index = type.child_ids()[scalar.type_code]; + RETURN_NOT_OK(builder->Append(scalar.type_code)); + + for (int field_index = 0; field_index < type.num_fields(); ++field_index) { + auto* child_builder = builder->child_builder(field_index).get(); + if (field_index == scalar_field_index) { + if (scalar.is_valid) { + RETURN_NOT_OK(child_builder->AppendScalar(*scalar.value)); + } else { + RETURN_NOT_OK(child_builder->AppendNull()); + } + } + } + return Status::OK(); + } + + template <> + Status AppendUnionScalar(const SparseUnionType& type, const Scalar& s, + SparseUnionBuilder* builder) { + // For each scalar, + // 1. append the type code, + // 2. append the value to the corresponding child, + // 3. append null to the other children. + const auto& scalar = checked_cast(s); + RETURN_NOT_OK(builder->Append(scalar.type_code)); + + for (int field_index = 0; field_index < type.num_fields(); ++field_index) { + auto* child_builder = builder->child_builder(field_index).get(); + if (field_index == scalar.child_id) { + if (scalar.is_valid) { + RETURN_NOT_OK(child_builder->AppendScalar(*scalar.value[field_index])); + } else { + RETURN_NOT_OK(child_builder->AppendNull()); + } + } else { + RETURN_NOT_OK(child_builder->AppendNull()); + } + } + return Status::OK(); + } + template Status MakeUnionArray(const T& type) { using BuilderType = typename TypeTraits::BuilderType; - constexpr bool is_dense = std::is_same::value; - auto* builder = internal::checked_cast(builder_); + auto* builder = checked_cast(builder_); const auto count = n_repeats_ * (scalars_end_ - scalars_begin_); RETURN_NOT_OK(builder->Reserve(count)); @@ -230,26 +270,7 @@ struct AppendScalarImpl { for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* s = scalars_begin_; s != scalars_end_; s++) { - // For each scalar, - // 1. append the type code, - // 2. append the value to the corresponding child, - // 3. if the union is sparse, append null to the other children. 
- const auto& scalar = internal::checked_cast(**s); - const auto scalar_field_index = type.child_ids()[scalar.type_code]; - RETURN_NOT_OK(builder->Append(scalar.type_code)); - - for (int field_index = 0; field_index < type.num_fields(); ++field_index) { - auto* child_builder = builder->child_builder(field_index).get(); - if (field_index == scalar_field_index) { - if (scalar.is_valid) { - RETURN_NOT_OK(child_builder->AppendScalar(*scalar.value)); - } else { - RETURN_NOT_OK(child_builder->AppendNull()); - } - } else if (!is_dense) { - RETURN_NOT_OK(child_builder->AppendNull()); - } - } + RETURN_NOT_OK(AppendUnionScalar(type, **s, builder)); } } return Status::OK(); diff --git a/cpp/src/arrow/array/builder_nested.h b/cpp/src/arrow/array/builder_nested.h index 3d36cb5f65e..306d861b09f 100644 --- a/cpp/src/arrow/array/builder_nested.h +++ b/cpp/src/arrow/array/builder_nested.h @@ -304,10 +304,12 @@ class ARROW_EXPORT MapBuilder : public ArrayBuilder { if (!validity || bit_util::GetBit(validity, array.offset + row)) { ARROW_RETURN_NOT_OK(Append()); const int64_t slot_length = offsets[row + 1] - offsets[row]; + // Add together the inner StructArray offset to the Map/List offset + int64_t key_value_offset = array.child_data[0].offset + offsets[row]; ARROW_RETURN_NOT_OK(key_builder_->AppendArraySlice( - array.child_data[0].child_data[0], offsets[row], slot_length)); + array.child_data[0].child_data[0], key_value_offset, slot_length)); ARROW_RETURN_NOT_OK(item_builder_->AppendArraySlice( - array.child_data[0].child_data[1], offsets[row], slot_length)); + array.child_data[0].child_data[1], key_value_offset, slot_length)); } else { ARROW_RETURN_NOT_OK(AppendNull()); } diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 37db8ccb775..970bcaaaeb2 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -38,6 +38,7 @@ namespace arrow { +using internal::checked_cast; using internal::CountSetBits; static inline void AdjustNonNullable(Type::type type_id, int64_t length, @@ -174,27 +175,197 @@ void ArraySpan::SetMembers(const ArrayData& data) { } } +template +void SetOffsetsForScalar(ArraySpan* span, uint8_t* buffer, int64_t value_size, + int buffer_index = 1) { + auto offsets = reinterpret_cast(buffer); + offsets[0] = 0; + offsets[1] = static_cast(value_size); + span->buffers[buffer_index].data = buffer; + span->buffers[buffer_index].size = 2 * sizeof(offset_type); +} + +int GetNumBuffers(const DataType& type) { + switch (type.id()) { + case Type::NA: + case Type::STRUCT: + case Type::FIXED_SIZE_LIST: + return 1; + case Type::BINARY: + case Type::LARGE_BINARY: + case Type::STRING: + case Type::LARGE_STRING: + case Type::DENSE_UNION: + return 3; + case Type::EXTENSION: + // The number of buffers depends on the storage type + return GetNumBuffers( + *internal::checked_cast(type).storage_type()); + default: + // Everything else has 2 buffers + return 2; + } +} + +namespace internal { + +void FillZeroLengthArray(const DataType* type, ArraySpan* span) { + memset(span->scratch_space, 0x00, 16); + + span->type = type; + span->length = 0; + int num_buffers = GetNumBuffers(*type); + for (int i = 0; i < num_buffers; ++i) { + span->buffers[i].data = span->scratch_space; + span->buffers[i].size = 0; + } + + for (int i = num_buffers; i < 3; ++i) { + span->ClearBuffer(i); + } + + // Fill children + span->child_data.resize(type->num_fields()); + for (int i = 0; i < type->num_fields(); ++i) { + FillZeroLengthArray(type->field(i)->type().get(), &span->child_data[i]); + } +} + +} 
// namespace internal + void ArraySpan::FillFromScalar(const Scalar& value) { - static const uint8_t kValidByte = 0x01; - static const uint8_t kNullByte = 0x00; + static uint8_t kTrueBit = 0x01; + static uint8_t kFalseBit = 0x00; this->type = value.type.get(); this->length = 1; - // Populate null count and validity bitmap + Type::type type_id = value.type->id(); + + // Populate null count and validity bitmap (only for non-union types) this->null_count = value.is_valid ? 0 : 1; - this->buffers[0].data = const_cast(value.is_valid ? &kValidByte : &kNullByte); - this->buffers[0].size = 1; + if (!is_union(type_id)) { + this->buffers[0].data = value.is_valid ? &kTrueBit : &kFalseBit; + this->buffers[0].size = 1; + } - if (is_primitive(value.type->id())) { - const auto& scalar = - internal::checked_cast(value); + if (type_id == Type::BOOL) { + const auto& scalar = checked_cast(value); + this->buffers[1].data = scalar.value ? &kTrueBit : &kFalseBit; + this->buffers[1].size = 1; + } else if (is_primitive(type_id) || is_decimal(type_id) || + type_id == Type::DICTIONARY) { + const auto& scalar = checked_cast(value); const uint8_t* scalar_data = reinterpret_cast(scalar.view().data()); this->buffers[1].data = const_cast(scalar_data); this->buffers[1].size = scalar.type->byte_width(); + if (type_id == Type::DICTIONARY) { + // Populate dictionary data + const auto& dict_scalar = checked_cast(value); + this->child_data.resize(1); + this->child_data[0].SetMembers(*dict_scalar.value.dictionary->data()); + } + } else if (is_base_binary_like(type_id)) { + const auto& scalar = checked_cast(value); + this->buffers[1].data = this->scratch_space; + const uint8_t* data_buffer = nullptr; + int64_t data_size = 0; + if (scalar.is_valid) { + data_buffer = scalar.value->data(); + data_size = scalar.value->size(); + } + if (is_binary_like(type_id)) { + SetOffsetsForScalar(this, this->scratch_space, data_size); + } else { + // is_large_binary_like + SetOffsetsForScalar(this, this->scratch_space, data_size); + } + this->buffers[2].data = const_cast(data_buffer); + this->buffers[2].size = data_size; + } else if (type_id == Type::FIXED_SIZE_BINARY) { + const auto& scalar = checked_cast(value); + this->buffers[1].data = const_cast(scalar.value->data()); + this->buffers[1].size = scalar.value->size(); + } else if (is_list_like(type_id)) { + const auto& scalar = checked_cast(value); + + int64_t value_length = 0; + this->child_data.resize(1); + if (scalar.value != nullptr) { + // When the scalar is null, scalar.value can also be null + this->child_data[0].SetMembers(*scalar.value->data()); + value_length = scalar.value->length(); + } else { + // Even when the value is null, we still must populate the + // child_data to yield a valid array. 
Tedious + internal::FillZeroLengthArray(this->type->field(0)->type().get(), + &this->child_data[0]); + } + + if (type_id == Type::LIST || type_id == Type::MAP) { + SetOffsetsForScalar(this, this->scratch_space, value_length); + } else if (type_id == Type::LARGE_LIST) { + SetOffsetsForScalar(this, this->scratch_space, value_length); + } else { + // FIXED_SIZE_LIST: does not have a second buffer + this->buffers[1].data = nullptr; + this->buffers[1].size = 0; + } + } else if (type_id == Type::STRUCT) { + const auto& scalar = checked_cast(value); + this->child_data.resize(this->type->num_fields()); + DCHECK_EQ(this->type->num_fields(), static_cast(scalar.value.size())); + for (size_t i = 0; i < scalar.value.size(); ++i) { + this->child_data[i].FillFromScalar(*scalar.value[i]); + } + } else if (is_union(type_id)) { + // First buffer is kept null since unions have no validity vector + this->buffers[0].data = nullptr; + this->buffers[0].size = 0; + + this->buffers[1].data = this->scratch_space; + this->buffers[1].size = 1; + int8_t* type_codes = reinterpret_cast(this->scratch_space); + type_codes[0] = checked_cast(value).type_code; + + this->child_data.resize(this->type->num_fields()); + if (type_id == Type::DENSE_UNION) { + const auto& scalar = checked_cast(value); + // Has offset; start 4 bytes in so it's aligned to a 32-bit boundaries + SetOffsetsForScalar(this, this->scratch_space + sizeof(int32_t), 1, + /*buffer_index=*/2); + // We can't "see" the other arrays in the union, but we put the "active" + // union array in the right place and fill zero-length arrays for the + // others + const std::vector& child_ids = + static_cast(this->type)->child_ids(); + DCHECK_GE(scalar.type_code, 0); + DCHECK_LT(scalar.type_code, static_cast(child_ids.size())); + for (int i = 0; i < static_cast(this->child_data.size()); ++i) { + if (i == child_ids[scalar.type_code]) { + this->child_data[i].FillFromScalar(*scalar.value); + } else { + internal::FillZeroLengthArray(this->type->field(i)->type().get(), + &this->child_data[i]); + } + } + } else { + const auto& scalar = checked_cast(value); + // Sparse union scalars have a full complement of child values even + // though only one of them is relevant, so we just fill them in here + for (int i = 0; i < static_cast(this->child_data.size()); ++i) { + this->child_data[i].FillFromScalar(*scalar.value[i]); + } + } + } else if (type_id == Type::EXTENSION) { + // Pass through storage + const auto& scalar = checked_cast(value); + FillFromScalar(*scalar.value); + + // Restore the extension type + this->type = value.type.get(); } else { - // TODO(wesm): implement for other types - DCHECK(false) << "need to implement for other types"; + DCHECK_EQ(Type::NA, type_id) << "should be unreachable: " << *value.type; } } @@ -212,40 +383,14 @@ int64_t ArraySpan::GetNullCount() const { return precomputed; } -int GetNumBuffers(const DataType& type) { - switch (type.id()) { - case Type::NA: - case Type::STRUCT: - case Type::FIXED_SIZE_LIST: - return 1; - case Type::BINARY: - case Type::LARGE_BINARY: - case Type::STRING: - case Type::LARGE_STRING: - case Type::DENSE_UNION: - return 3; - case Type::EXTENSION: - // The number of buffers depends on the storage type - return GetNumBuffers( - *internal::checked_cast(type).storage_type()); - default: - // Everything else has 2 buffers - return 2; - } -} - int ArraySpan::num_buffers() const { return GetNumBuffers(*this->type); } std::shared_ptr ArraySpan::ToArrayData() const { - auto result = std::make_shared(this->type->Copy(), this->length, + 
auto result = std::make_shared(this->type->GetSharedPtr(), this->length, this->null_count, this->offset); for (int i = 0; i < this->num_buffers(); ++i) { - if (this->buffers[i].owner) { - result->buffers.emplace_back(this->GetBuffer(i)); - } else { - result->buffers.push_back(nullptr); - } + result->buffers.emplace_back(this->GetBuffer(i)); } if (this->type->id() == Type::NA) { diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index df547aedfaf..b76ab597107 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -266,6 +266,11 @@ struct ARROW_EXPORT ArraySpan { int64_t offset = 0; BufferSpan buffers[3]; + // 16 bytes of scratch space to enable this ArraySpan to be a view onto + // scalar values including binary scalars (where we need to create a buffer + // that looks like two 32-bit or 64-bit offsets) + uint8_t scratch_space[16]; + ArraySpan() = default; explicit ArraySpan(const DataType* type, int64_t length) : type(type), length(length) {} @@ -273,9 +278,7 @@ struct ARROW_EXPORT ArraySpan { ArraySpan(const ArrayData& data) { // NOLINT implicit conversion SetMembers(data); } - ArraySpan(const Scalar& data) { // NOLINT implicit converstion - FillFromScalar(data); - } + explicit ArraySpan(const Scalar& data) { FillFromScalar(data); } /// If dictionary-encoded, put dictionary in the first entry std::vector child_data; @@ -343,10 +346,14 @@ struct ARROW_EXPORT ArraySpan { std::shared_ptr ToArray() const; std::shared_ptr GetBuffer(int index) const { - if (this->buffers[index].owner == NULLPTR) { - return NULLPTR; + const BufferSpan& buf = this->buffers[index]; + if (buf.owner) { + return *buf.owner; + } else if (buf.data != NULLPTR) { + // Buffer points to some memory without an owning buffer + return std::make_shared(buf.data, buf.size); } else { - return *this->buffers[index].owner; + return NULLPTR; } } @@ -372,6 +379,8 @@ struct ARROW_EXPORT ArraySpan { namespace internal { +void FillZeroLengthArray(const DataType* type, ArraySpan* span); + /// Construct a zero-copy view of this ArrayData with the given type. /// /// This method checks if the types are layout-compatible. diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index e5b4ab39493..c0cdcab730c 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -664,22 +664,20 @@ class RepeatedArrayFactory { } Status Visit(const SparseUnionType& type) { - const auto& union_scalar = checked_cast(scalar_); - const auto& union_type = checked_cast(*scalar_.type); + const auto& union_scalar = checked_cast(scalar_); const auto scalar_type_code = union_scalar.type_code; - const auto scalar_child_id = union_type.child_ids()[scalar_type_code]; // Create child arrays: most of them are all-null, except for the child array // for the given type code (if the scalar is valid). 
ArrayVector fields; for (int i = 0; i < type.num_fields(); ++i) { fields.emplace_back(); - if (i == scalar_child_id && scalar_.is_valid) { - ARROW_ASSIGN_OR_RAISE(fields.back(), - MakeArrayFromScalar(*union_scalar.value, length_, pool_)); - } else { + if (i == union_scalar.child_id && scalar_.is_valid) { ARROW_ASSIGN_OR_RAISE( - fields.back(), MakeArrayOfNull(union_type.field(i)->type(), length_, pool_)); + fields.back(), MakeArrayFromScalar(*union_scalar.value[i], length_, pool_)); + } else { + ARROW_ASSIGN_OR_RAISE(fields.back(), + MakeArrayOfNull(type.field(i)->type(), length_, pool_)); } } @@ -691,7 +689,7 @@ class RepeatedArrayFactory { } Status Visit(const DenseUnionType& type) { - const auto& union_scalar = checked_cast(scalar_); + const auto& union_scalar = checked_cast(scalar_); const auto& union_type = checked_cast(*scalar_.type); const auto scalar_type_code = union_scalar.type_code; const auto scalar_child_id = union_type.child_ids()[scalar_type_code]; diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 8af319ed9ea..c5406ee583f 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -796,12 +796,19 @@ class ScalarEqualsVisitor { return Status::OK(); } - Status Visit(const UnionScalar& left) { - const auto& right = checked_cast(right_); + Status Visit(const DenseUnionScalar& left) { + const auto& right = checked_cast(right_); result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); return Status::OK(); } + Status Visit(const SparseUnionScalar& left) { + const auto& right = checked_cast(right_); + result_ = ScalarEquals(*left.value[left.child_id], *right.value[right.child_id], + options_, floating_approximate_); + return Status::OK(); + } + Status Visit(const DictionaryScalar& left) { const auto& right = checked_cast(right_); result_ = ScalarEquals(*left.value.index, *right.value.index, options_, diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 4ebdecf5e78..ff1d6619905 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -347,11 +347,11 @@ Result Filter(const Datum& values, const Datum& filter, return CallFunction("filter", {values, filter}, &options, ctx); } -Result Take(const Datum& values, const Datum& filter, const TakeOptions& options, +Result Take(const Datum& values, const Datum& indices, const TakeOptions& options, ExecContext* ctx) { // Invoke metafunction which deals with Datum kinds other than just Array, // ChunkedArray. - return CallFunction("take", {values, filter}, &options, ctx); + return CallFunction("take", {values, indices}, &options, ctx); } Result> Take(const Array& values, const Array& indices, diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index bd49041b4f3..21257e05602 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -69,9 +69,9 @@ void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTabl // Private version of GetCastFunction with better error reporting // if the input type is known. 
Result> GetCastFunctionInternal( - const std::shared_ptr& to_type, const DataType* from_type = nullptr) { + const TypeHolder& to_type, const DataType* from_type = nullptr) { internal::EnsureInitCastTable(); - auto it = internal::g_cast_table.find(static_cast(to_type->id())); + auto it = internal::g_cast_table.find(static_cast(to_type.id())); if (it == internal::g_cast_table.end()) { if (from_type != nullptr) { return Status::NotImplemented("Unsupported cast from ", *from_type, " to ", @@ -139,18 +139,6 @@ void RegisterScalarCast(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::make_shared())); DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType)); } -} // namespace internal - -CastOptions::CastOptions(bool safe) - : FunctionOptions(internal::kCastOptionsType), - allow_int_overflow(!safe), - allow_time_truncate(!safe), - allow_time_overflow(!safe), - allow_decimal_truncate(!safe), - allow_float_truncate(!safe), - allow_invalid_utf8(!safe) {} - -constexpr char CastOptions::kTypeName[]; CastFunction::CastFunction(std::string name, Type::type out_type_id) : ScalarFunction(std::move(name), Arity::Unary(), FunctionDoc::Empty()), @@ -177,18 +165,18 @@ Status CastFunction::AddKernel(Type::type in_type_id, std::vector in_ } Result CastFunction::DispatchExact( - const std::vector& values) const { - RETURN_NOT_OK(CheckArity(values)); + const std::vector& types) const { + RETURN_NOT_OK(CheckArity(types.size())); std::vector candidate_kernels; for (const auto& kernel : kernels_) { - if (kernel.signature->MatchesInputs(values)) { + if (kernel.signature->MatchesInputs(types)) { candidate_kernels.push_back(&kernel); } } if (candidate_kernels.size() == 0) { - return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(), + return Status::NotImplemented("Unsupported cast from ", types[0].type->ToString(), " to ", ToTypeName(out_type_id_), " using function ", this->name()); } @@ -213,28 +201,40 @@ Result CastFunction::DispatchExact( return candidate_kernels[0]; } +Result> GetCastFunction(const TypeHolder& to_type) { + return internal::GetCastFunctionInternal(to_type); +} + +} // namespace internal + +CastOptions::CastOptions(bool safe) + : FunctionOptions(internal::kCastOptionsType), + allow_int_overflow(!safe), + allow_time_truncate(!safe), + allow_time_overflow(!safe), + allow_decimal_truncate(!safe), + allow_float_truncate(!safe), + allow_invalid_utf8(!safe) {} + +constexpr char CastOptions::kTypeName[]; + Result Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { return CallFunction("cast", {value}, &options, ctx); } -Result Cast(const Datum& value, std::shared_ptr to_type, +Result Cast(const Datum& value, const TypeHolder& to_type, const CastOptions& options, ExecContext* ctx) { CastOptions options_with_to_type = options; options_with_to_type.to_type = to_type; return Cast(value, options_with_to_type, ctx); } -Result> Cast(const Array& value, std::shared_ptr to_type, +Result> Cast(const Array& value, const TypeHolder& to_type, const CastOptions& options, ExecContext* ctx) { ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx)); return result.make_array(); } -Result> GetCastFunction( - const std::shared_ptr& to_type) { - return internal::GetCastFunctionInternal(to_type); -} - bool CanCast(const DataType& from_type, const DataType& to_type) { internal::EnsureInitCastTable(); auto it = internal::g_cast_table.find(static_cast(to_type.id())); @@ -242,7 +242,7 @@ bool CanCast(const DataType& from_type, const DataType& 
to_type) { return false; } - const CastFunction* function = it->second.get(); + const internal::CastFunction* function = it->second.get(); DCHECK_EQ(function->out_type_id(), to_type.id()); for (auto from_id : function->in_type_ids()) { @@ -253,21 +253,5 @@ bool CanCast(const DataType& from_type, const DataType& to_type) { return false; } -Result> Cast(std::vector datums, std::vector descrs, - ExecContext* ctx) { - for (size_t i = 0; i != datums.size(); ++i) { - if (descrs[i] != datums[i].descr()) { - if (descrs[i].shape != datums[i].shape()) { - return Status::NotImplemented("casting between Datum shapes"); - } - - ARROW_ASSIGN_OR_RAISE(datums[i], - Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx)); - } - } - - return datums; -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index e9c3cf55da9..7432933a124 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -22,8 +22,7 @@ #include #include "arrow/compute/function.h" -#include "arrow/compute/kernel.h" -#include "arrow/datum.h" +#include "arrow/compute/type_fwd.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" @@ -46,13 +45,13 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { explicit CastOptions(bool safe = true); static constexpr char const kTypeName[] = "CastOptions"; - static CastOptions Safe(std::shared_ptr to_type = NULLPTR) { + static CastOptions Safe(TypeHolder to_type = {}) { CastOptions safe(true); safe.to_type = std::move(to_type); return safe; } - static CastOptions Unsafe(std::shared_ptr to_type = NULLPTR) { + static CastOptions Unsafe(TypeHolder to_type = {}) { CastOptions unsafe(false); unsafe.to_type = std::move(to_type); return unsafe; @@ -60,7 +59,7 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { // Type being casted to. 
May be passed separate to eager function // compute::Cast - std::shared_ptr to_type; + TypeHolder to_type; bool allow_int_overflow; bool allow_time_truncate; @@ -74,36 +73,6 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { /// @} -// Cast functions are _not_ registered in the FunctionRegistry, though they use -// the same execution machinery -class CastFunction : public ScalarFunction { - public: - CastFunction(std::string name, Type::type out_type_id); - - Type::type out_type_id() const { return out_type_id_; } - const std::vector& in_type_ids() const { return in_type_ids_; } - - Status AddKernel(Type::type in_type_id, std::vector in_types, - OutputType out_type, ArrayKernelExec exec, - NullHandling::type = NullHandling::INTERSECTION, - MemAllocation::type = MemAllocation::PREALLOCATE); - - // Note, this function toggles off memory allocation and sets the init - // function to CastInit - Status AddKernel(Type::type in_type_id, ScalarKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; - - private: - std::vector in_type_ids_; - const Type::type out_type_id_; -}; - -ARROW_EXPORT -Result> GetCastFunction( - const std::shared_ptr& to_type); - /// \brief Return true if a cast function is defined ARROW_EXPORT bool CanCast(const DataType& from_type, const DataType& to_type); @@ -121,7 +90,7 @@ bool CanCast(const DataType& from_type, const DataType& to_type); /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result> Cast(const Array& value, std::shared_ptr to_type, +Result> Cast(const Array& value, const TypeHolder& to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); @@ -147,21 +116,9 @@ Result Cast(const Datum& value, const CastOptions& options, /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result Cast(const Datum& value, std::shared_ptr to_type, +Result Cast(const Datum& value, const TypeHolder& to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); -/// \brief Cast several values simultaneously. Safe cast options are used. 
-/// \param[in] values datums to cast -/// \param[in] descrs ValueDescrs to cast to -/// \param[in] ctx the function execution context, optional -/// \return the resulting datums -/// -/// \since 4.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result> Cast(std::vector values, std::vector descrs, - ExecContext* ctx = NULLPTR); - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/cast_internal.h b/cpp/src/arrow/compute/cast_internal.h index 0105d08a573..bfa2a110cd7 100644 --- a/cpp/src/arrow/compute/cast_internal.h +++ b/cpp/src/arrow/compute/cast_internal.h @@ -30,6 +30,32 @@ namespace internal { using CastState = OptionsWrapper; +// Cast functions are _not_ registered in the FunctionRegistry, though they use +// the same execution machinery +class CastFunction : public ScalarFunction { + public: + CastFunction(std::string name, Type::type out_type_id); + + Type::type out_type_id() const { return out_type_id_; } + const std::vector& in_type_ids() const { return in_type_ids_; } + + Status AddKernel(Type::type in_type_id, std::vector in_types, + OutputType out_type, ArrayKernelExec exec, + NullHandling::type = NullHandling::INTERSECTION, + MemAllocation::type = MemAllocation::PREALLOCATE); + + // Note, this function toggles off memory allocation and sets the init + // function to CastInit + Status AddKernel(Type::type in_type_id, ScalarKernel kernel); + + Result DispatchExact( + const std::vector& types) const override; + + private: + std::vector in_type_ids_; + const Type::type out_type_id_; +}; + // See kernels/scalar_cast_*.cc for these std::vector> GetBooleanCasts(); std::vector> GetNumericCasts(); @@ -38,6 +64,9 @@ std::vector> GetBinaryLikeCasts(); std::vector> GetNestedCasts(); std::vector> GetDictionaryCasts(); +ARROW_EXPORT +Result> GetCastFunction(const TypeHolder& to_type); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index a612a83e7a8..c5b1dfaca0e 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -219,16 +219,6 @@ void ComputeDataPreallocate(const DataType& type, namespace detail { -Status CheckAllValues(const std::vector& values) { - for (const auto& value : values) { - if (!value.is_value()) { - return Status::Invalid("Tried executing function with non-value type: ", - value.ToString()); - } - } - return Status::OK(); -} - ExecBatchIterator::ExecBatchIterator(std::vector args, int64_t length, int64_t max_chunksize) : args_(std::move(args)), @@ -249,9 +239,7 @@ Result> ExecBatchIterator::Make( } } - // If the arguments are all scalars, then the length is 1 - int64_t length = 1; - + int64_t length = -1; bool length_set = false; for (auto& arg : args) { if (arg.is_scalar()) { @@ -267,6 +255,11 @@ Result> ExecBatchIterator::Make( } } + if (!length_set) { + // All scalar case, to be removed soon + length = 1; + } + max_chunksize = std::min(length, max_chunksize); return std::unique_ptr( @@ -328,8 +321,17 @@ bool ExecBatchIterator::Next(ExecBatch* batch) { // ---------------------------------------------------------------------- // ExecSpanIterator; to eventually replace ExecBatchIterator -Status ExecSpanIterator::Init(const ExecBatch& batch, ValueDescr::Shape output_shape, - int64_t max_chunksize) { +bool CheckIfAllScalar(const ExecBatch& batch) { + for (const Datum& value : batch.values) { + if (!value.is_scalar()) { + DCHECK(value.is_arraylike()); + return false; + } + } + return batch.num_values() > 0; +} + +Status 
ExecSpanIterator::Init(const ExecBatch& batch, int64_t max_chunksize) { if (batch.num_values() > 0) { // Validate arguments bool all_args_same_length = false; @@ -343,8 +345,9 @@ Status ExecSpanIterator::Init(const ExecBatch& batch, ValueDescr::Shape output_s } args_ = &batch.values; initialized_ = have_chunked_arrays_ = false; + have_all_scalars_ = CheckIfAllScalar(batch); position_ = 0; - length_ = output_shape == ValueDescr::SCALAR ? 1 : batch.length; + length_ = batch.length; chunk_indexes_.clear(); chunk_indexes_.resize(args_->size(), 0); value_positions_.clear(); @@ -358,8 +361,7 @@ Status ExecSpanIterator::Init(const ExecBatch& batch, ValueDescr::Shape output_s int64_t ExecSpanIterator::GetNextChunkSpan(int64_t iteration_size, ExecSpan* span) { for (size_t i = 0; i < args_->size() && iteration_size > 0; ++i) { // If the argument is not a chunked array, it's either a Scalar or Array, - // in which case it doesn't influence the size of this span. Note that if - // the args are all scalars the span length is 1 + // in which case it doesn't influence the size of this span if (!args_->at(i).is_chunked_array()) { continue; } @@ -385,13 +387,20 @@ int64_t ExecSpanIterator::GetNextChunkSpan(int64_t iteration_size, ExecSpan* spa return iteration_size; } -bool ExecSpanIterator::Next(ExecSpan* span) { - if (position_ == length_) { - // This also protects from degenerate cases like ChunkedArrays - // without any chunks - return false; +void PromoteExecSpanScalars(ExecSpan* span) { + // In the "all scalar" case, we "promote" the scalars to ArraySpans of + // length 1, since the kernel implementations do not handle the all + // scalar case + for (int i = 0; i < span->num_values(); ++i) { + ExecValue* value = &span->values[i]; + if (value->is_scalar()) { + value->array.FillFromScalar(*value->scalar); + value->scalar = nullptr; + } } +} +bool ExecSpanIterator::Next(ExecSpan* span) { if (!initialized_) { span->length = 0; @@ -402,25 +411,36 @@ bool ExecSpanIterator::Next(ExecSpan* span) { // iteration span->values.resize(args_->size()); for (size_t i = 0; i < args_->size(); ++i) { - if (args_->at(i).is_scalar()) { - span->values[i].SetScalar(args_->at(i).scalar().get()); - } else if (args_->at(i).is_array()) { - const ArrayData& arr = *args_->at(i).array(); + const Datum& arg = (*args_)[i]; + if (arg.is_scalar()) { + span->values[i].SetScalar(arg.scalar().get()); + } else if (arg.is_array()) { + const ArrayData& arr = *arg.array(); span->values[i].SetArray(arr); value_offsets_[i] = arr.offset; } else { // Populate members from the first chunk - const Array* first_chunk = args_->at(i).chunked_array()->chunk(0).get(); - const ArrayData& arr = *first_chunk->data(); - span->values[i].SetArray(arr); - value_offsets_[i] = arr.offset; + const ChunkedArray& carr = *arg.chunked_array(); + if (carr.num_chunks() > 0) { + const ArrayData& arr = *carr.chunk(0)->data(); + span->values[i].SetArray(arr); + value_offsets_[i] = arr.offset; + } else { + // Fill as zero-length array + internal::FillZeroLengthArray(carr.type().get(), &span->values[i].array); + span->values[i].scalar = nullptr; + } have_chunked_arrays_ = true; } } - initialized_ = true; - } - if (position_ == length_) { + if (have_all_scalars_) { + PromoteExecSpanScalars(span); + } + + initialized_ = true; + } else if (position_ == length_) { + // We've emitted at least one span and we're at the end so we are done return false; } @@ -441,6 +461,7 @@ bool ExecSpanIterator::Next(ExecSpan* span) { value_positions_[i] += iteration_size; } } + position_ += 
iteration_size; DCHECK_LE(position_, length_); return true; @@ -662,7 +683,7 @@ class NullPropagator { }; std::shared_ptr ToChunkedArray(const std::vector& values, - const std::shared_ptr& type) { + const TypeHolder& type) { std::vector> arrays; arrays.reserve(values.size()); for (const Datum& val : values) { @@ -672,7 +693,7 @@ std::shared_ptr ToChunkedArray(const std::vector& values, } arrays.emplace_back(val.make_array()); } - return std::make_shared(std::move(arrays), type); + return std::make_shared(std::move(arrays), type.GetSharedPtr()); } bool HaveChunkedArray(const std::vector& values) { @@ -691,9 +712,9 @@ class KernelExecutorImpl : public KernelExecutor { kernel_ctx_ = kernel_ctx; kernel_ = static_cast(args.kernel); - // Resolve the output descriptor for this kernel + // Resolve the output type for this kernel ARROW_ASSIGN_OR_RAISE( - output_descr_, kernel_->signature->out_type().Resolve(kernel_ctx_, args.inputs)); + output_type_, kernel_->signature->out_type().Resolve(kernel_ctx_, args.inputs)); return Status::OK(); } @@ -703,7 +724,7 @@ class KernelExecutorImpl : public KernelExecutor { // Kernel::mem_allocation is not MemAllocation::PREALLOCATE, then no // data buffers will be set Result> PrepareOutput(int64_t length) { - auto out = std::make_shared(output_descr_.type, length); + auto out = std::make_shared(output_type_.GetSharedPtr(), length); out->buffers.resize(output_num_buffers_); if (validity_preallocated_) { @@ -726,10 +747,10 @@ class KernelExecutorImpl : public KernelExecutor { Status CheckResultType(const Datum& out, const char* function_name) override { const auto& type = out.type(); - if (type != nullptr && !type->Equals(output_descr_.type)) { + if (type != nullptr && !type->Equals(*output_type_.type)) { return Status::TypeError( "kernel type result mismatch for function '", function_name, "': declared as ", - output_descr_.type->ToString(), ", actual is ", type->ToString()); + output_type_.type->ToString(), ", actual is ", type->ToString()); } return Status::OK(); } @@ -741,7 +762,7 @@ class KernelExecutorImpl : public KernelExecutor { KernelContext* kernel_ctx_; const KernelType* kernel_; - ValueDescr output_descr_; + TypeHolder output_type_; int output_num_buffers_; @@ -757,18 +778,23 @@ class KernelExecutorImpl : public KernelExecutor { class ScalarExecutor : public KernelExecutorImpl { public: Status Execute(const ExecBatch& batch, ExecListener* listener) override { - RETURN_NOT_OK(span_iterator_.Init(batch, output_descr_.shape, - exec_context()->exec_chunksize())); + RETURN_NOT_OK(span_iterator_.Init(batch, exec_context()->exec_chunksize())); - // TODO(wesm): remove if with ARROW-16757 - if (output_descr_.shape != ValueDescr::SCALAR) { - // If the executor is configured to produce a single large Array output for - // kernels supporting preallocation, then we do so up front and then - // iterate over slices of that large array. 
Otherwise, we preallocate prior - // to processing each span emitted from the ExecSpanIterator - RETURN_NOT_OK(SetupPreallocation(span_iterator_.length(), batch.values)); + if (batch.length == 0) { + // For zero-length batches, we do nothing except return a zero-length + // array of the correct output type + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, + MakeArrayOfNull(output_type_.GetSharedPtr(), /*length=*/0, + exec_context()->memory_pool())); + return EmitResult(result->data(), listener); } + // If the executor is configured to produce a single large Array output for + // kernels supporting preallocation, then we do so up front and then + // iterate over slices of that large array. Otherwise, we preallocate prior + // to processing each span emitted from the ExecSpanIterator + RETURN_NOT_OK(SetupPreallocation(span_iterator_.length(), batch.values)); + // ARROW-16756: Here we have to accommodate the distinct cases // // * Fully-preallocated contiguous output @@ -784,30 +810,28 @@ class ScalarExecutor : public KernelExecutorImpl { Datum WrapResults(const std::vector& inputs, const std::vector& outputs) override { - if (output_descr_.shape == ValueDescr::SCALAR) { - // TODO(wesm): to remove, see ARROW-16757 - DCHECK_EQ(outputs.size(), 1); - // Return as SCALAR - return outputs[0]; + // If execution yielded multiple chunks (because large arrays were split + // based on the ExecContext parameters, then the result is a ChunkedArray + if (HaveChunkedArray(inputs) || outputs.size() > 1) { + return ToChunkedArray(outputs, output_type_); } else { - // If execution yielded multiple chunks (because large arrays were split - // based on the ExecContext parameters, then the result is a ChunkedArray - if (HaveChunkedArray(inputs) || outputs.size() > 1) { - return ToChunkedArray(outputs, output_descr_.type); - } else if (outputs.size() == 1) { - // Outputs have just one element - return outputs[0]; - } else { - // XXX: In the case where no outputs are omitted, is returning a 0-length - // array always the correct move? - return MakeArrayOfNull(output_descr_.type, /*length=*/0, - exec_context()->memory_pool()) - .ValueOrDie(); - } + // Outputs have just one element + return outputs[0]; } } protected: + Status EmitResult(std::shared_ptr out, ExecListener* listener) { + if (span_iterator_.have_all_scalars()) { + // ARROW-16757 We boxed scalar inputs as ArraySpan, so now we have to + // unbox the output as a scalar + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, MakeArray(out)->GetScalar(0)); + return listener->OnResult(std::move(scalar)); + } else { + return listener->OnResult(std::move(out)); + } + } + Status ExecuteSpans(ExecListener* listener) { // We put the preallocation in an ArraySpan to be passed to the // kernel which is expecting to receive that. 
More @@ -817,6 +841,7 @@ class ScalarExecutor : public KernelExecutorImpl { ExecSpan input; ExecResult output; ArraySpan* output_span = output.array_span(); + if (preallocate_contiguous_) { // Make one big output allocation ARROW_ASSIGN_OR_RAISE(preallocation, PrepareOutput(span_iterator_.length())); @@ -832,7 +857,7 @@ class ScalarExecutor : public KernelExecutorImpl { } // Kernel execution is complete; emit result - RETURN_NOT_OK(listener->OnResult(std::move(preallocation))); + return EmitResult(std::move(preallocation), listener); } else { // Fully preallocating, but not contiguously // We preallocate (maybe) only for the output of processing the current @@ -842,15 +867,15 @@ class ScalarExecutor : public KernelExecutorImpl { output_span->SetMembers(*preallocation); RETURN_NOT_OK(ExecuteSingleSpan(input, &output)); // Emit the result for this chunk - RETURN_NOT_OK(listener->OnResult(std::move(preallocation))); + RETURN_NOT_OK(EmitResult(std::move(preallocation), listener)); } + return Status::OK(); } - return Status::OK(); } Status ExecuteSingleSpan(const ExecSpan& input, ExecResult* out) { ArraySpan* result_span = out->array_span(); - if (output_descr_.type->id() == Type::NA) { + if (output_type_.type->id() == Type::NA) { result_span->null_count = result_span->length; } else if (kernel_->null_handling == NullHandling::INTERSECTION) { if (!elide_validity_bitmap_) { @@ -859,7 +884,10 @@ class ScalarExecutor : public KernelExecutorImpl { } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { result_span->null_count = 0; } - return kernel_->exec(kernel_ctx_, input, out); + RETURN_NOT_OK(kernel_->exec(kernel_ctx_, input, out)); + // Output type didn't change + DCHECK(out->is_array_span()); + return Status::OK(); } Status ExecuteNonSpans(ExecListener* listener) { @@ -873,60 +901,32 @@ class ScalarExecutor : public KernelExecutorImpl { ExecSpan input; ExecResult output; while (span_iterator_.Next(&input)) { - if (output_descr_.shape == ValueDescr::ARRAY) { - ARROW_ASSIGN_OR_RAISE(output.value, PrepareOutput(input.length)); - DCHECK(output.is_array_data()); - } else { - // For scalar outputs, we set a null scalar of the correct type to - // communicate the output type to the kernel if needed - // - // XXX: Is there some way to avoid this step? 
- // TODO: Remove this path in ARROW-16757 - output.value = MakeNullScalar(output_descr_.type); - } + ARROW_ASSIGN_OR_RAISE(output.value, PrepareOutput(input.length)); + DCHECK(output.is_array_data()); - if (output_descr_.shape == ValueDescr::ARRAY) { - ArrayData* out_arr = output.array_data().get(); - if (output_descr_.type->id() == Type::NA) { - out_arr->null_count = out_arr->length; - } else if (kernel_->null_handling == NullHandling::INTERSECTION) { - RETURN_NOT_OK(PropagateNulls(kernel_ctx_, input, out_arr)); - } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { - out_arr->null_count = 0; - } - } else { - // TODO(wesm): to remove, see ARROW-16757 - if (kernel_->null_handling == NullHandling::INTERSECTION) { - // set scalar validity - output.scalar()->is_valid = - std::all_of(input.values.begin(), input.values.end(), - [](const ExecValue& input) { return input.scalar->is_valid; }); - } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { - output.scalar()->is_valid = true; - } + ArrayData* out_arr = output.array_data().get(); + if (output_type_.type->id() == Type::NA) { + out_arr->null_count = out_arr->length; + } else if (kernel_->null_handling == NullHandling::INTERSECTION) { + RETURN_NOT_OK(PropagateNulls(kernel_ctx_, input, out_arr)); + } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { + out_arr->null_count = 0; } RETURN_NOT_OK(kernel_->exec(kernel_ctx_, input, &output)); - // Assert that the kernel did not alter the shape of the output - // type. After ARROW-16577 delete this since ValueDescr::SCALAR will not - // exist anymore - DCHECK(((output_descr_.shape == ValueDescr::ARRAY) && output.is_array_data()) || - ((output_descr_.shape == ValueDescr::SCALAR) && output.is_scalar())); + // Output type didn't change + DCHECK(output.is_array_data()); // Emit a result for each chunk - if (output_descr_.shape == ValueDescr::ARRAY) { - RETURN_NOT_OK(listener->OnResult(output.array_data())); - } else { - RETURN_NOT_OK(listener->OnResult(output.scalar())); - } + RETURN_NOT_OK(EmitResult(std::move(output.array_data()), listener)); } return Status::OK(); } Status SetupPreallocation(int64_t total_length, const std::vector& args) { - output_num_buffers_ = static_cast(output_descr_.type->layout().buffers.size()); - auto out_type_id = output_descr_.type->id(); + output_num_buffers_ = static_cast(output_type_.type->layout().buffers.size()); + auto out_type_id = output_type_.type->id(); // Default to no validity pre-allocation for following cases: // - Output Array is NullArray // - kernel_->null_handling is COMPUTED_NO_PREALLOCATE or OUTPUT_NOT_NULL @@ -950,7 +950,7 @@ class ScalarExecutor : public KernelExecutorImpl { } } if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { - ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); + ComputeDataPreallocate(*output_type_.type, &data_preallocated_); } // Validity bitmap either preallocated or elided, and all data @@ -995,14 +995,24 @@ class ScalarExecutor : public KernelExecutorImpl { ExecSpanIterator span_iterator_; }; +Status CheckCanExecuteChunked(const VectorKernel* kernel) { + if (kernel->exec_chunked == nullptr) { + return Status::Invalid( + "Vector kernel cannot execute chunkwise and no " + "chunked exec function was defined"); + } + + if (kernel->null_handling == NullHandling::INTERSECTION) { + return Status::Invalid( + "Null pre-propagation is unsupported for ChunkedArray " + "execution in vector kernels"); + } + return Status::OK(); +} + class VectorExecutor : public 
KernelExecutorImpl { public: Status Execute(const ExecBatch& batch, ExecListener* listener) override { - // TODO(wesm): remove in ARROW-16577 - if (output_descr_.shape == ValueDescr::SCALAR) { - return Status::Invalid("VectorExecutor only supports array output types"); - } - // Some vector kernels have a separate code path for handling // chunked arrays (VectorKernel::exec_chunked) so we check if we // have any chunked arrays. If we do and an exec_chunked function @@ -1012,19 +1022,18 @@ class VectorExecutor : public KernelExecutorImpl { if (arg.is_chunked_array()) have_chunked_arrays = true; } - output_num_buffers_ = static_cast(output_descr_.type->layout().buffers.size()); + output_num_buffers_ = static_cast(output_type_.type->layout().buffers.size()); // Decide if we need to preallocate memory for this kernel validity_preallocated_ = (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { - ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); + ComputeDataPreallocate(*output_type_.type, &data_preallocated_); } if (kernel_->can_execute_chunkwise) { - RETURN_NOT_OK(span_iterator_.Init(batch, output_descr_.shape, - exec_context()->exec_chunksize())); + RETURN_NOT_OK(span_iterator_.Init(batch, exec_context()->exec_chunksize())); ExecSpan span; while (span_iterator_.Next(&span)) { RETURN_NOT_OK(Exec(span, listener)); @@ -1038,7 +1047,11 @@ class VectorExecutor : public KernelExecutorImpl { } else { // No chunked arrays. We pack the args into an ExecSpan and // call the regular exec code path - RETURN_NOT_OK(Exec(ExecSpan(batch), listener)); + ExecSpan span(batch); + if (CheckIfAllScalar(batch)) { + PromoteExecSpanScalars(&span); + } + RETURN_NOT_OK(Exec(span, listener)); } } @@ -1058,63 +1071,46 @@ class VectorExecutor : public KernelExecutorImpl { // If execution yielded multiple chunks (because large arrays were split // based on the ExecContext parameters, then the result is a ChunkedArray if (kernel_->output_chunked && (HaveChunkedArray(inputs) || outputs.size() > 1)) { - return ToChunkedArray(outputs, output_descr_.type); - } else if (outputs.size() == 1) { + return ToChunkedArray(outputs, output_type_.GetSharedPtr()); + } else { // Outputs have just one element return outputs[0]; - } else { - // XXX: In the case where no outputs are omitted, is returning a 0-length - // array always the correct move? - return MakeArrayOfNull(output_descr_.type, /*length=*/0).ValueOrDie(); } } protected: - Status Exec(const ExecSpan& span, ExecListener* listener) { - ExecResult out; - - // We preallocate (maybe) only for the output of processing the current - // batch, but create an output ArrayData instance regardless - ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(span.length)); - - if (kernel_->null_handling == NullHandling::INTERSECTION) { - RETURN_NOT_OK(PropagateNulls(kernel_ctx_, span, out.array_data().get())); - } - RETURN_NOT_OK(kernel_->exec(kernel_ctx_, span, &out)); + Status EmitResult(Datum result, ExecListener* listener) { if (!kernel_->finalize) { // If there is no result finalizer (e.g. 
for hash-based functions, we can // emit the processed batch right away rather than waiting - RETURN_NOT_OK(listener->OnResult(out.array_data())); + RETURN_NOT_OK(listener->OnResult(std::move(result))); } else { - results_.emplace_back(out.array_data()); + results_.emplace_back(std::move(result)); } return Status::OK(); } - Status ExecChunked(const ExecBatch& batch, ExecListener* listener) { - if (kernel_->exec_chunked == nullptr) { - return Status::Invalid( - "Vector kernel cannot execute chunkwise and no " - "chunked exec function was defined"); - } - + Status Exec(const ExecSpan& span, ExecListener* listener) { + ExecResult out; + ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(span.length)); if (kernel_->null_handling == NullHandling::INTERSECTION) { - return Status::Invalid( - "Null pre-propagation is unsupported for ChunkedArray " - "execution in vector kernels"); + RETURN_NOT_OK(PropagateNulls(kernel_ctx_, span, out.array_data().get())); } + RETURN_NOT_OK(kernel_->exec(kernel_ctx_, span, &out)); + return EmitResult(std::move(out.array_data()), listener); + } + Status ExecChunked(const ExecBatch& batch, ExecListener* listener) { + RETURN_NOT_OK(CheckCanExecuteChunked(kernel_)); Datum out; ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(batch.length)); RETURN_NOT_OK(kernel_->exec_chunked(kernel_ctx_, batch, &out)); - if (!kernel_->finalize) { - // If there is no result finalizer (e.g. for hash-based functions, we can - // emit the processed batch right away rather than waiting - RETURN_NOT_OK(listener->OnResult(std::move(out))); + if (out.is_array()) { + return EmitResult(std::move(out.array()), listener); } else { - results_.emplace_back(std::move(out)); + DCHECK(out.is_chunked_array()); + return EmitResult(std::move(out.chunked_array()), listener); } - return Status::OK(); } ExecSpanIterator span_iterator_; @@ -1124,7 +1120,7 @@ class VectorExecutor : public KernelExecutorImpl { class ScalarAggExecutor : public KernelExecutorImpl { public: Status Init(KernelContext* ctx, KernelInitArgs args) override { - input_descrs_ = &args.inputs; + input_types_ = &args.inputs; options_ = args.options; return KernelExecutorImpl::Init(ctx, args); } @@ -1160,9 +1156,8 @@ class ScalarAggExecutor : public KernelExecutorImpl { private: Status Consume(const ExecBatch& batch) { // FIXME(ARROW-11840) don't merge *any* aggegates for every batch - ARROW_ASSIGN_OR_RAISE( - auto batch_state, - kernel_->init(kernel_ctx_, {kernel_, *input_descrs_, options_})); + ARROW_ASSIGN_OR_RAISE(auto batch_state, + kernel_->init(kernel_ctx_, {kernel_, *input_types_, options_})); if (batch_state == nullptr) { return Status::Invalid("ScalarAggregation requires non-null kernel state"); @@ -1177,7 +1172,7 @@ class ScalarAggExecutor : public KernelExecutorImpl { } std::unique_ptr batch_iterator_; - const std::vector* input_descrs_; + const std::vector* input_types_; const FunctionOptions* options_; }; @@ -1358,8 +1353,7 @@ Result> SelectionVector::FromMask( Result CallFunction(const std::string& func_name, const std::vector& args, const FunctionOptions* options, ExecContext* ctx) { if (ctx == nullptr) { - ExecContext default_ctx; - return CallFunction(func_name, args, options, &default_ctx); + ctx = default_exec_context(); } ARROW_ASSIGN_OR_RAISE(std::shared_ptr func, ctx->func_registry()->GetFunction(func_name)); @@ -1374,8 +1368,7 @@ Result CallFunction(const std::string& func_name, const std::vector CallFunction(const std::string& func_name, const ExecBatch& batch, const FunctionOptions* options, ExecContext* ctx) { if (ctx 
== nullptr) { - ExecContext default_ctx; - return CallFunction(func_name, batch, options, &default_ctx); + ctx = default_exec_context(); } ARROW_ASSIGN_OR_RAISE(std::shared_ptr func, ctx->func_registry()->GetFunction(func_name)); diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index 8fd938ce299..f0b951dccb8 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -235,12 +235,11 @@ struct ARROW_EXPORT ExecBatch { ExecBatch Slice(int64_t offset, int64_t length) const; - /// \brief A convenience for returning the ValueDescr objects (types and - /// shapes) from the batch. - std::vector GetDescriptors() const { - std::vector result; + /// \brief A convenience for returning the types from the batch. + std::vector GetTypes() const { + std::vector result; for (const auto& value : this->values) { - result.emplace_back(value.descr()); + result.emplace_back(value.type()); } return result; } @@ -254,19 +253,16 @@ inline bool operator==(const ExecBatch& l, const ExecBatch& r) { return l.Equals inline bool operator!=(const ExecBatch& l, const ExecBatch& r) { return !l.Equals(r); } struct ExecValue { - enum Kind { ARRAY, SCALAR }; - Kind kind = ARRAY; ArraySpan array; - const Scalar* scalar; + const Scalar* scalar = NULLPTR; ExecValue(Scalar* scalar) // NOLINT implicit conversion - : kind(SCALAR), scalar(scalar) {} + : scalar(scalar) {} ExecValue(ArraySpan array) // NOLINT implicit conversion - : kind(ARRAY), array(std::move(array)) {} + : array(std::move(array)) {} - ExecValue(const ArrayData& array) // NOLINT implicit conversion - : kind(ARRAY) { + ExecValue(const ArrayData& array) { // NOLINT implicit conversion this->array.SetMembers(array); } @@ -278,31 +274,21 @@ struct ExecValue { int64_t length() const { return this->is_array() ? this->array.length : 1; } - bool is_array() const { return this->kind == ARRAY; } - bool is_scalar() const { return this->kind == SCALAR; } + bool is_array() const { return this->scalar == NULLPTR; } + bool is_scalar() const { return !this->is_array(); } void SetArray(const ArrayData& array) { - this->kind = ARRAY; this->array.SetMembers(array); + this->scalar = NULLPTR; } - void SetScalar(const Scalar* scalar) { - this->kind = SCALAR; - this->scalar = scalar; - } + void SetScalar(const Scalar* scalar) { this->scalar = scalar; } template const ExactType& scalar_as() const { return ::arrow::internal::checked_cast(*this->scalar); } - /// XXX: here only temporarily until type resolution can be cleaned - /// up to not use ValueDescr - ValueDescr descr() const { - ValueDescr::Shape shape = this->is_array() ? ValueDescr::ARRAY : ValueDescr::SCALAR; - return ValueDescr(const_cast(this->type())->shared_from_this(), shape); - } - /// XXX: here temporarily for compatibility with datum, see /// e.g. 
MakeStructExec in scalar_nested.cc int64_t null_count() const { @@ -314,7 +300,7 @@ struct ExecValue { } const DataType* type() const { - if (this->kind == ARRAY) { + if (this->is_array()) { return array.type; } else { return scalar->type.get(); @@ -324,29 +310,21 @@ struct ExecValue { struct ARROW_EXPORT ExecResult { // The default value of the variant is ArraySpan - // TODO(wesm): remove Scalar output modality in ARROW-16577 - util::Variant, std::shared_ptr> value; + util::Variant> value; int64_t length() const { if (this->is_array_span()) { return this->array_span()->length; - } else if (this->is_array_data()) { - return this->array_data()->length; } else { - // Should not reach here - return 1; + return this->array_data()->length; } } const DataType* type() const { - switch (this->value.index()) { - case 0: - return this->array_span()->type; - case 1: - return this->array_data()->type.get(); - default: - // scalar - return this->scalar()->type.get(); + if (this->is_array_span()) { + return this->array_span()->type; + } else { + return this->array_data()->type.get(); } } @@ -360,12 +338,6 @@ struct ARROW_EXPORT ExecResult { } bool is_array_data() const { return this->value.index() == 1; } - - const std::shared_ptr& scalar() const { - return util::get>(this->value); - } - - bool is_scalar() const { return this->value.index() == 2; } }; /// \brief A "lightweight" column batch object which contains no @@ -395,15 +367,6 @@ struct ARROW_EXPORT ExecSpan { } } - bool is_all_scalar() const { - for (const ExecValue& value : this->values) { - if (value.is_array()) { - return false; - } - } - return true; - } - /// \brief Return the value at the i-th index template inline const ExecValue& operator[](index_type i) const { @@ -412,7 +375,7 @@ struct ARROW_EXPORT ExecSpan { void AddOffset(int64_t offset) { for (ExecValue& value : values) { - if (value.kind == ExecValue::ARRAY) { + if (value.is_array()) { value.array.AddOffset(offset); } } @@ -420,7 +383,7 @@ struct ARROW_EXPORT ExecSpan { void SetOffset(int64_t offset) { for (ExecValue& value : values) { - if (value.kind == ExecValue::ARRAY) { + if (value.is_array()) { value.array.SetOffset(offset); } } @@ -429,12 +392,10 @@ struct ARROW_EXPORT ExecSpan { /// \brief A convenience for the number of values / arguments. 
int num_values() const { return static_cast(values.size()); } - // XXX: eliminate the need for ValueDescr; copied temporarily from - // ExecBatch - std::vector GetDescriptors() const { - std::vector result; + std::vector GetTypes() const { + std::vector result; for (const auto& value : this->values) { - result.emplace_back(value.descr()); + result.emplace_back(value.type()); } return result; } diff --git a/cpp/src/arrow/compute/exec/aggregate.cc b/cpp/src/arrow/compute/exec/aggregate.cc index 41b5bb75b66..5cb9a9c5633 100644 --- a/cpp/src/arrow/compute/exec/aggregate.cc +++ b/cpp/src/arrow/compute/exec/aggregate.cc @@ -31,20 +31,19 @@ namespace internal { Result> GetKernels( ExecContext* ctx, const std::vector& aggregates, - const std::vector& in_descrs) { - if (aggregates.size() != in_descrs.size()) { + const std::vector& in_types) { + if (aggregates.size() != in_types.size()) { return Status::Invalid(aggregates.size(), " aggregate functions were specified but ", - in_descrs.size(), " arguments were provided."); + in_types.size(), " arguments were provided."); } - std::vector kernels(in_descrs.size()); + std::vector kernels(in_types.size()); for (size_t i = 0; i < aggregates.size(); ++i) { ARROW_ASSIGN_OR_RAISE(auto function, ctx->func_registry()->GetFunction(aggregates[i].function)); - ARROW_ASSIGN_OR_RAISE( - const Kernel* kernel, - function->DispatchExact({in_descrs[i], ValueDescr::Array(uint32())})); + ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, + function->DispatchExact({in_types[i], uint32()})); kernels[i] = static_cast(kernel); } return kernels; @@ -52,7 +51,7 @@ Result> GetKernels( Result>> InitKernels( const std::vector& kernels, ExecContext* ctx, - const std::vector& aggregates, const std::vector& in_descrs) { + const std::vector& aggregates, const std::vector& in_types) { std::vector> states(kernels.size()); for (size_t i = 0; i < aggregates.size(); ++i) { @@ -69,14 +68,13 @@ Result>> InitKernels( } KernelContext kernel_ctx{ctx}; - ARROW_ASSIGN_OR_RAISE( - states[i], - kernels[i]->init(&kernel_ctx, KernelInitArgs{kernels[i], - { - in_descrs[i], - ValueDescr::Array(uint32()), - }, - options})); + ARROW_ASSIGN_OR_RAISE(states[i], + kernels[i]->init(&kernel_ctx, KernelInitArgs{kernels[i], + { + in_types[i], + uint32(), + }, + options})); } return std::move(states); @@ -86,19 +84,16 @@ Result ResolveKernels( const std::vector& aggregates, const std::vector& kernels, const std::vector>& states, ExecContext* ctx, - const std::vector& descrs) { - FieldVector fields(descrs.size()); + const std::vector& types) { + FieldVector fields(types.size()); for (size_t i = 0; i < kernels.size(); ++i) { KernelContext kernel_ctx{ctx}; kernel_ctx.SetState(states[i].get()); - ARROW_ASSIGN_OR_RAISE(auto descr, kernels[i]->signature->out_type().Resolve( - &kernel_ctx, { - descrs[i], - ValueDescr::Array(uint32()), - })); - fields[i] = field(aggregates[i].function, std::move(descr.type)); + ARROW_ASSIGN_OR_RAISE(auto type, kernels[i]->signature->out_type().Resolve( + &kernel_ctx, {types[i], uint32()})); + fields[i] = field(aggregates[i].function, type.GetSharedPtr()); } return fields; } @@ -122,18 +117,17 @@ Result GroupBy(const std::vector& arguments, const std::vectorparallelism()); for (auto& state : states) { - ARROW_ASSIGN_OR_RAISE(state, - InitKernels(kernels, ctx, aggregates, argument_descrs)); + ARROW_ASSIGN_OR_RAISE(state, InitKernels(kernels, ctx, aggregates, argument_types)); } ARROW_ASSIGN_OR_RAISE( - out_fields, ResolveKernels(aggregates, kernels, states[0], ctx, argument_descrs)); + 
out_fields, ResolveKernels(aggregates, kernels, states[0], ctx, argument_types)); ARROW_ASSIGN_OR_RAISE( argument_batch_iterator, @@ -142,19 +136,19 @@ Result GroupBy(const std::vector& arguments, const std::vector> groupers(task_group->parallelism()); for (auto& grouper : groupers) { - ARROW_ASSIGN_OR_RAISE(grouper, Grouper::Make(key_descrs, ctx)); + ARROW_ASSIGN_OR_RAISE(grouper, Grouper::Make(key_types, ctx)); } std::mutex mutex; std::unordered_map thread_ids; int i = 0; - for (ValueDescr& key_descr : key_descrs) { - out_fields.push_back(field("key_" + std::to_string(i++), std::move(key_descr.type))); + for (const TypeHolder& key_type : key_types) { + out_fields.push_back(field("key_" + std::to_string(i++), key_type.GetSharedPtr())); } ARROW_ASSIGN_OR_RAISE( diff --git a/cpp/src/arrow/compute/exec/aggregate.h b/cpp/src/arrow/compute/exec/aggregate.h index 753b0a8c47e..72990f3b6e7 100644 --- a/cpp/src/arrow/compute/exec/aggregate.h +++ b/cpp/src/arrow/compute/exec/aggregate.h @@ -42,17 +42,17 @@ Result GroupBy(const std::vector& arguments, const std::vector> GetKernels( ExecContext* ctx, const std::vector& aggregates, - const std::vector& in_descrs); + const std::vector& in_types); Result>> InitKernels( const std::vector& kernels, ExecContext* ctx, - const std::vector& aggregates, const std::vector& in_descrs); + const std::vector& aggregates, const std::vector& in_types); Result ResolveKernels( const std::vector& aggregates, const std::vector& kernels, const std::vector>& states, ExecContext* ctx, - const std::vector& descrs); + const std::vector& in_types); } // namespace internal } // namespace compute diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 8c7899c41ec..0131319be3b 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -104,8 +104,7 @@ class ScalarAggregateNode : public ExecNode { aggregates[i].function); } - auto in_type = ValueDescr::Array(input_schema.field(target_field_ids[i])->type()); - + TypeHolder in_type(input_schema.field(target_field_ids[i])->type().get()); ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact({in_type})); kernels[i] = static_cast(kernel); @@ -125,10 +124,10 @@ class ScalarAggregateNode : public ExecNode { // pick one to resolve the kernel signature kernel_ctx.SetState(states[i][0].get()); - ARROW_ASSIGN_OR_RAISE( - auto descr, kernels[i]->signature->out_type().Resolve(&kernel_ctx, {in_type})); + ARROW_ASSIGN_OR_RAISE(auto out_type, kernels[i]->signature->out_type().Resolve( + &kernel_ctx, {in_type})); - fields[i] = field(aggregate_options.aggregates[i].name, std::move(descr.type)); + fields[i] = field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr()); } return plan->EmplaceNode( @@ -313,25 +312,24 @@ class GroupByNode : public ExecNode { } // Build vector of aggregate source field data types - std::vector agg_src_descrs(aggs.size()); + std::vector agg_src_types(aggs.size()); for (size_t i = 0; i < aggs.size(); ++i) { auto agg_src_field_id = agg_src_field_ids[i]; - agg_src_descrs[i] = - ValueDescr(input_schema->field(agg_src_field_id)->type(), ValueDescr::ARRAY); + agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get(); } auto ctx = input->plan()->exec_context(); // Construct aggregates ARROW_ASSIGN_OR_RAISE(auto agg_kernels, - internal::GetKernels(ctx, aggs, agg_src_descrs)); + internal::GetKernels(ctx, aggs, agg_src_types)); ARROW_ASSIGN_OR_RAISE(auto agg_states, - 
internal::InitKernels(agg_kernels, ctx, aggs, agg_src_descrs)); + internal::InitKernels(agg_kernels, ctx, aggs, agg_src_types)); ARROW_ASSIGN_OR_RAISE( FieldVector agg_result_fields, - internal::ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_descrs)); + internal::ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_types)); // Build field vector for output schema FieldVector output_fields{keys.size() + aggs.size()}; @@ -621,26 +619,24 @@ class GroupByNode : public ExecNode { if (state->grouper != nullptr) return Status::OK(); // Build vector of key field data types - std::vector key_descrs(key_field_ids_.size()); + std::vector key_types(key_field_ids_.size()); for (size_t i = 0; i < key_field_ids_.size(); ++i) { auto key_field_id = key_field_ids_[i]; - key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type()); + key_types[i] = input_schema->field(key_field_id)->type().get(); } // Construct grouper - ARROW_ASSIGN_OR_RAISE(state->grouper, Grouper::Make(key_descrs, ctx_)); + ARROW_ASSIGN_OR_RAISE(state->grouper, Grouper::Make(key_types, ctx_)); // Build vector of aggregate source field data types - std::vector agg_src_descrs(agg_kernels_.size()); + std::vector agg_src_types(agg_kernels_.size()); for (size_t i = 0; i < agg_kernels_.size(); ++i) { auto agg_src_field_id = agg_src_field_ids_[i]; - agg_src_descrs[i] = - ValueDescr(input_schema->field(agg_src_field_id)->type(), ValueDescr::ARRAY); + agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get(); } - ARROW_ASSIGN_OR_RAISE( - state->agg_states, - internal::InitKernels(agg_kernels_, ctx_, aggs_, agg_src_descrs)); + ARROW_ASSIGN_OR_RAISE(state->agg_states, internal::InitKernels(agg_kernels_, ctx_, + aggs_, agg_src_types)); return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index b796f5cda3b..c890b3c5935 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -64,7 +64,7 @@ Expression::Expression(Parameter parameter) Expression literal(Datum lit) { return Expression(std::move(lit)); } Expression field_ref(FieldRef ref) { - return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, {-1}}); + return Expression(Expression::Parameter{std::move(ref), TypeHolder{}, {-1}}); } Expression call(std::string function, std::vector arguments, @@ -93,36 +93,18 @@ const Expression::Call* Expression::call() const { return util::get_if(impl_.get()); } -ValueDescr Expression::descr() const { - if (impl_ == nullptr) return {}; +const DataType* Expression::type() const { + if (impl_ == nullptr) return nullptr; - if (auto lit = literal()) { - return lit->descr(); - } - - if (auto parameter = this->parameter()) { - return parameter->descr; - } - - return CallNotNull(*this)->descr; -} - -// This is a module-global singleton to avoid synchronization costs of a -// function-static singleton. 
-static const std::shared_ptr kNoType; - -const std::shared_ptr& Expression::type() const { - if (impl_ == nullptr) return kNoType; - - if (auto lit = literal()) { - return lit->type(); + if (const Datum* lit = literal()) { + return lit->type().get(); } - if (auto parameter = this->parameter()) { - return parameter->descr.type; + if (const Parameter* parameter = this->parameter()) { + return parameter->type.type; } - return CallNotNull(*this)->descr.type; + return CallNotNull(*this)->type.type; } namespace { @@ -276,7 +258,7 @@ size_t Expression::hash() const { bool Expression::IsBound() const { if (type() == nullptr) return false; - if (auto call = this->call()) { + if (const Call* call = this->call()) { if (call->kernel == nullptr) return false; for (const Expression& arg : call->arguments) { @@ -338,7 +320,7 @@ util::optional GetNullHandling( } // namespace bool Expression::IsSatisfiable() const { - if (!type()) return true; + if (type() == nullptr) return true; if (type()->id() != Type::BOOL) return true; if (auto lit = literal()) { @@ -382,25 +364,20 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ DCHECK(std::all_of(call.arguments.begin(), call.arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - auto descrs = GetDescriptors(call.arguments); + std::vector types = GetTypes(call.arguments); ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context)); if (!insert_implicit_casts) { - ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(descrs)); + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(types)); } else { - ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&descrs)); + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types)); - for (size_t i = 0; i < descrs.size(); ++i) { - if (descrs[i] == call.arguments[i].descr()) continue; + for (size_t i = 0; i < types.size(); ++i) { + if (types[i] == call.arguments[i].type()) continue; - if (descrs[i].shape != call.arguments[i].descr().shape) { - return Status::NotImplemented( - "Automatic broadcasting of scalars arguments to arrays in ", - Expression(std::move(call)).ToString()); - } - - if (auto lit = call.arguments[i].literal()) { - ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, descrs[i].type)); + if (const Datum* lit = call.arguments[i].literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, + compute::Cast(*lit, types[i].GetSharedPtr())); call.arguments[i] = literal(std::move(new_lit)); continue; } @@ -409,8 +386,10 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ Expression::Call implicit_cast; implicit_cast.function_name = "cast"; implicit_cast.arguments = {std::move(call.arguments[i])}; + + // TODO(wesm): Use TypeHolder in options implicit_cast.options = std::make_shared( - compute::CastOptions::Safe(descrs[i].type)); + compute::CastOptions::Safe(types[i].GetSharedPtr())); ARROW_ASSIGN_OR_RAISE( call.arguments[i], @@ -425,43 +404,41 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ call.options ? 
call.options.get() : call.function->default_options(); ARROW_ASSIGN_OR_RAISE( call.kernel_state, - call.kernel->init(&kernel_context, {call.kernel, descrs, options})); + call.kernel->init(&kernel_context, {call.kernel, types, options})); kernel_context.SetState(call.kernel_state.get()); } ARROW_ASSIGN_OR_RAISE( - call.descr, call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); + call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types)); return Expression(std::move(call)); } template Result BindImpl(Expression expr, const TypeOrSchema& in, - ValueDescr::Shape shape, compute::ExecContext* exec_context) { + compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; - return BindImpl(std::move(expr), in, shape, &exec_context); + return BindImpl(std::move(expr), in, &exec_context); } if (expr.literal()) return expr; - if (auto ref = expr.field_ref()) { - ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(in)); + if (const FieldRef* ref = expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(FieldPath path, ref->FindOne(in)); - auto bound = *expr.parameter(); - bound.indices.resize(path.indices().size()); - std::copy(path.indices().begin(), path.indices().end(), bound.indices.begin()); + Expression::Parameter param = *expr.parameter(); + param.indices.resize(path.indices().size()); + std::copy(path.indices().begin(), path.indices().end(), param.indices.begin()); ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in)); - bound.descr.type = field->type(); - bound.descr.shape = shape; - return Expression{std::move(bound)}; + param.type = field->type(); + return Expression{std::move(param)}; } auto call = *CallNotNull(expr); for (auto& argument : call.arguments) { - ARROW_ASSIGN_OR_RAISE(argument, - BindImpl(std::move(argument), in, shape, exec_context)); + ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context)); } return BindNonRecursive(std::move(call), /*insert_implicit_casts=*/true, exec_context); @@ -469,14 +446,14 @@ Result BindImpl(Expression expr, const TypeOrSchema& in, } // namespace -Result Expression::Bind(const ValueDescr& in, +Result Expression::Bind(const TypeHolder& in, compute::ExecContext* exec_context) const { - return BindImpl(*this, *in.type, in.shape, exec_context); + return BindImpl(*this, *in.type, exec_context); } Result Expression::Bind(const Schema& in_schema, compute::ExecContext* exec_context) const { - return BindImpl(*this, in_schema, ValueDescr::ARRAY, exec_context); + return BindImpl(*this, in_schema, exec_context); } Result MakeExecBatch(const Schema& full_schema, const Datum& partial) { @@ -558,7 +535,7 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i if (auto lit = expr.literal()) return *lit; if (auto param = expr.parameter()) { - if (param->descr.type->id() == Type::NA) { + if (param->type.id() == Type::NA) { return MakeNullScalar(null()); } @@ -569,10 +546,10 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i ARROW_ASSIGN_OR_RAISE( field, compute::CallFunction("struct_field", {std::move(field)}, &options)); } - if (!field.type()->Equals(param->descr.type)) { + if (!field.type()->Equals(*param->type.type)) { return Status::Invalid("Referenced field ", expr.ToString(), " was ", field.type()->ToString(), " but should have been ", - param->descr.type->ToString()); + param->type.ToString()); } return field; @@ -596,10 +573,10 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i compute::KernelContext 
kernel_context(exec_context, call->kernel); kernel_context.SetState(call->kernel_state.get()); - auto kernel = call->kernel; - auto descrs = GetDescriptors(arguments); + const Kernel* kernel = call->kernel; + std::vector types = GetTypes(arguments); auto options = call->options.get(); - RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options})); + RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, types, options})); compute::detail::DatumAccumulator listener; RETURN_NOT_OK(executor->Execute( @@ -683,16 +660,16 @@ Result FoldConstants(Expression expr) { if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { // kernels which always produce intersected validity can be resolved // to null *now* if any of their inputs is a null literal - if (!call->descr.type) { + if (!call->type.type) { return Status::Invalid("Cannot fold constants for unbound expression ", expr.ToString()); } - for (const auto& argument : call->arguments) { + for (const Expression& argument : call->arguments) { if (argument.IsNullLiteral()) { - if (argument.type()->Equals(*call->descr.type)) { + if (argument.type()->Equals(*call->type.type)) { return argument; } else { - return literal(MakeNullScalar(call->descr.type)); + return literal(MakeNullScalar(call->type.GetSharedPtr())); } } } @@ -815,7 +792,7 @@ Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_va auto it = known_values.map.find(*ref); if (it != known_values.map.end()) { Datum lit = it->second; - if (lit.descr() == expr.descr()) return literal(std::move(lit)); + if (lit.type()->Equals(*expr.type())) return literal(std::move(lit)); // type mismatch, try casting the known value to the correct type if (expr.type()->id() == Type::DICTIONARY && @@ -836,7 +813,7 @@ Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_va } } - ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type())); + ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type()->GetSharedPtr())); return literal(std::move(lit)); } } diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index a1765d0fcca..e9026961aa9 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -55,7 +55,7 @@ class ARROW_EXPORT Expression { std::shared_ptr function; const Kernel* kernel = NULLPTR; std::shared_ptr kernel_state; - ValueDescr descr; + TypeHolder type; void ComputeHash(); }; @@ -70,7 +70,7 @@ class ARROW_EXPORT Expression { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - Result Bind(const ValueDescr& in, ExecContext* = NULLPTR) const; + Result Bind(const TypeHolder& in, ExecContext* = NULLPTR) const; Result Bind(const Schema& in_schema, ExecContext* = NULLPTR) const; // XXX someday @@ -82,8 +82,8 @@ class ARROW_EXPORT Expression { // Result CloneState() const; // Status SetState(ExpressionState); - /// Return true if all an expression's field references have explicit ValueDescr and all - /// of its functions' kernels are looked up. + /// Return true if all an expression's field references have explicit types + /// and all of its functions' kernels are looked up. 
bool IsBound() const; /// Return true if this expression is composed only of Scalar literals, field @@ -107,9 +107,8 @@ class ARROW_EXPORT Expression { /// Access a FieldRef or return nullptr if this expression is not a field_ref const FieldRef* field_ref() const; - /// The type and shape to which this expression will evaluate - ValueDescr descr() const; - const std::shared_ptr& type() const; + /// The type to which this expression will evaluate + const DataType* type() const; // XXX someday // NullGeneralization::type nullable() const; @@ -117,7 +116,7 @@ class ARROW_EXPORT Expression { FieldRef ref; // post-bind properties - ValueDescr descr; + TypeHolder type; ::arrow::internal::SmallVector indices; }; const Parameter* parameter() const; diff --git a/cpp/src/arrow/compute/exec/expression_internal.h b/cpp/src/arrow/compute/exec/expression_internal.h index f8c686d2c81..7490d116c54 100644 --- a/cpp/src/arrow/compute/exec/expression_internal.h +++ b/cpp/src/arrow/compute/exec/expression_internal.h @@ -23,6 +23,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/cast.h" +#include "arrow/compute/cast_internal.h" #include "arrow/compute/registry.h" #include "arrow/record_batch.h" #include "arrow/table.h" @@ -31,6 +32,8 @@ namespace arrow { namespace compute { +using internal::GetCastFunction; + struct KnownFieldValues { std::unordered_map map; }; @@ -41,21 +44,21 @@ inline const Expression::Call* CallNotNull(const Expression& expr) { return call; } -inline std::vector GetDescriptors(const std::vector& exprs) { - std::vector descrs(exprs.size()); +inline std::vector GetTypes(const std::vector& exprs) { + std::vector types(exprs.size()); for (size_t i = 0; i < exprs.size(); ++i) { DCHECK(exprs[i].IsBound()); - descrs[i] = exprs[i].descr(); + types[i] = exprs[i].type(); } - return descrs; + return types; } -inline std::vector GetDescriptors(const std::vector& values) { - std::vector descrs(values.size()); +inline std::vector GetTypes(const std::vector& values) { + std::vector types(values.size()); for (size_t i = 0; i < values.size(); ++i) { - descrs[i] = values[i].descr(); + types[i] = values[i].type(); } - return descrs; + return types; } struct Comparison { @@ -281,7 +284,7 @@ inline Result> GetFunction( // XXX this special case is strange; why not make "cast" a ScalarFunction? const auto& to_type = ::arrow::internal::checked_cast(*call.options).to_type; - return compute::GetCastFunction(to_type); + return GetCastFunction(to_type); } /// Modify an Expression with pre-order and post-order visitation. 
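(Editorial sketch, not part of the patch: the expression_internal.h hunk above replaces the ValueDescr-based GetDescriptors() helpers with GetTypes() overloads that collect TypeHolder values and feed them straight to kernel dispatch. The fragment below is a minimal illustration of that caller-side flow, assuming only the post-refactor APIs shown in this series -- TypeHolder, Function::DispatchExact, the function registry -- and using the stock "add" function purely as an example name.)

#include <memory>
#include <vector>

#include "arrow/compute/api.h"
#include "arrow/compute/registry.h"
#include "arrow/result.h"
#include "arrow/status.h"

arrow::Status DispatchWithTypeHolders(arrow::compute::ExecContext* ctx) {
  namespace cp = arrow::compute;
  // Look up a registered function; "add" is only an illustrative choice.
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<cp::Function> func,
                        ctx->func_registry()->GetFunction("add"));
  // TypeHolder is constructible from std::shared_ptr<DataType> (and from a
  // raw const DataType*), so plain types stand in for the old
  // ValueDescr::Array(...) / ValueDescr::Scalar(...) wrappers.
  std::vector<cp::TypeHolder> types = {arrow::int32(), arrow::int32()};
  ARROW_ASSIGN_OR_RAISE(const cp::Kernel* kernel, func->DispatchExact(types));
  // The selected kernel would then be initialized and executed as in the
  // exec.cc changes earlier in this patch.
  (void)kernel;
  return arrow::Status::OK();
}
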
diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 95adb1652eb..b4466d827eb 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -493,8 +493,8 @@ TEST(Expression, BindLiteral) { Datum(ArrayFromJSON(int32(), "[1,2,3]")), }) { // literals are always considered bound - auto expr = literal(dat); - EXPECT_EQ(expr.descr(), dat.descr()); + Expression expr = literal(dat); + EXPECT_TRUE(dat.type()->Equals(*expr.type())); EXPECT_TRUE(expr.IsBound()); } } @@ -518,13 +518,13 @@ void ExpectBindsTo(Expression expr, util::optional expected, } TEST(Expression, BindFieldRef) { - // an unbound field_ref does not have the output ValueDescr set + // an unbound field_ref does not have the output type set auto expr = field_ref("alpha"); - EXPECT_EQ(expr.descr(), ValueDescr{}); + EXPECT_EQ(expr.type(), nullptr); EXPECT_FALSE(expr.IsBound()); ExpectBindsTo(field_ref("i32"), no_change, &expr); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); // if the field is not found, an error will be raised ASSERT_RAISES(Invalid, field_ref("no such field").Bind(*kBoringSchema)); @@ -541,11 +541,11 @@ TEST(Expression, BindNestedFieldRef) { ExpectBindsTo(field_ref(FieldRef("a", "b")), no_change, &expr, schema); EXPECT_TRUE(expr.IsBound()); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); ExpectBindsTo(field_ref(FieldRef(FieldPath({0, 0}))), no_change, &expr, schema); EXPECT_TRUE(expr.IsBound()); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); ASSERT_RAISES(Invalid, field_ref(FieldPath({0, 1})).Bind(schema)); ASSERT_RAISES(Invalid, field_ref(FieldRef("a", "b")) @@ -558,7 +558,7 @@ TEST(Expression, BindCall) { EXPECT_FALSE(expr.IsBound()); ExpectBindsTo(expr, no_change, &expr); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}), call("add", {field_ref("f32"), literal(3.0F)})); @@ -607,7 +607,7 @@ TEST(Expression, BindNestedCall) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", int32()), field("b", int32()), field("c", int32()), field("d", int32())}))); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); EXPECT_TRUE(expr.IsBound()); } @@ -615,7 +615,7 @@ TEST(Expression, ExecuteFieldRef) { auto ExpectRefIs = [](FieldRef ref, Datum in, Datum expected) { auto expr = field_ref(ref); - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.type())); ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, Schema(in.type()->fields()), in)); @@ -716,8 +716,8 @@ Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& compute::ExecContext exec_context; ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(*call, &exec_context)); - auto descrs = GetDescriptors(call->arguments); - ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs)); + std::vector types = GetTypes(call->arguments); + ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(types)); EXPECT_EQ(call->kernel, expected_kernel); return function->Execute(arguments, call->options.get(), &exec_context); @@ -726,7 +726,7 @@ Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& void ExpectExecute(Expression expr, Datum in, Datum* 
actual_out = NULLPTR) { std::shared_ptr schm; if (in.is_value()) { - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.type())); schm = schema(in.type()->fields()); } else { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.schema())); diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index a145863e597..a376fb5f57b 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -84,13 +84,11 @@ class HashJoinBasicImpl : public HashJoinImpl { private: void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) { - std::vector data_types; + std::vector data_types; int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); data_types.resize(num_cols); for (int icol = 0; icol < num_cols; ++icol) { - data_types[icol] = - ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol), - ValueDescr::ARRAY); + data_types[icol] = schema_mgr_->proj_maps[side].data_type(projection_handle, icol); } encoder->Init(data_types, ctx_); encoder->Clear(); diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.cc b/cpp/src/arrow/compute/exec/hash_join_dict.cc index 731a5662d7d..560b0ea8d4d 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.cc +++ b/cpp/src/arrow/compute/exec/hash_join_dict.cc @@ -224,8 +224,8 @@ Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr dictiona // Initialize encoder internal::RowEncoder encoder; - std::vector encoder_types; - encoder_types.emplace_back(value_type_, ValueDescr::ARRAY); + std::vector encoder_types; + encoder_types.emplace_back(value_type_); encoder.Init(encoder_types, ctx); // Encode all dictionary values @@ -285,8 +285,7 @@ Result> HashJoinDictBuild::RemapInputValues( // Initialize encoder // internal::RowEncoder encoder; - std::vector encoder_types; - encoder_types.emplace_back(value_type_, ValueDescr::ARRAY); + std::vector encoder_types = {value_type_}; encoder.Init(encoder_types, ctx); // Encode all @@ -422,8 +421,7 @@ Result> HashJoinDictProbe::RemapInput( remapped_ids_, opt_build_side->RemapInputValues(ctx, Datum(dict->data()), dict->length())); } else { - std::vector encoder_types; - encoder_types.emplace_back(dict_type.value_type(), ValueDescr::ARRAY); + std::vector encoder_types = {dict_type.value_type()}; encoder_.Init(encoder_types, ctx); RETURN_NOT_OK( encoder_.EncodeAndAppend(ExecSpan({*dict->data()}, dict->length()))); @@ -516,14 +514,14 @@ void HashJoinDictBuildMulti::InitEncoder( const SchemaProjectionMaps& proj_map, RowEncoder* encoder, ExecContext* ctx) { int num_cols = proj_map.num_cols(HashJoinProjection::KEY); - std::vector data_types(num_cols); + std::vector data_types(num_cols); for (int icol = 0; icol < num_cols; ++icol) { std::shared_ptr data_type = proj_map.data_type(HashJoinProjection::KEY, icol); if (HashJoinDictBuild::KeyNeedsProcessing(data_type)) { data_type = HashJoinDictBuild::DataTypeAfterRemapping(); } - data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY); + data_types[icol] = data_type; } encoder->Init(data_types, ctx); } @@ -610,7 +608,7 @@ void HashJoinDictProbeMulti::InitEncoder( const SchemaProjectionMaps& proj_map_build, RowEncoder* encoder, ExecContext* ctx) { int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY); - std::vector data_types(num_cols); + std::vector data_types(num_cols); for (int icol = 0; icol < num_cols; ++icol) { std::shared_ptr data_type = proj_map_probe.data_type(HashJoinProjection::KEY, icol); @@ -619,7 
+617,7 @@ void HashJoinDictProbeMulti::InitEncoder( if (HashJoinDictProbe::KeyNeedsProcessing(data_type, build_data_type)) { data_type = HashJoinDictProbe::DataTypeAfterRemapping(build_data_type); } - data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY); + data_types[icol] = data_type; } encoder->Init(data_types, ctx); } diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 46600a96da3..9a3c7342788 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -44,13 +44,13 @@ BatchesWithSchema GenerateBatchesFromString( const std::vector& json_strings, int multiplicity = 1) { BatchesWithSchema out_batches{{}, schema}; - std::vector descrs; + std::vector types; for (auto&& field : schema->fields()) { - descrs.emplace_back(field->type()); + types.emplace_back(field->type()); } for (auto&& s : json_strings) { - out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + out_batches.batches.push_back(ExecBatchFromJSON(types, s)); } size_t batch_count = out_batches.batches.size(); @@ -473,7 +473,7 @@ void TakeUsingVector(ExecContext* ctx, const std::vector> } } -// Generate random arrays given list of data type descriptions and null probabilities. +// Generate random arrays given list of data types and null probabilities. // Make sure that all generated records are unique. // The actual number of generated records may be lower than desired because duplicates // will be removed without replacement. @@ -485,12 +485,12 @@ std::vector> GenRandomUniqueRecords( GenRandomRecords(rng, data_types.data_types, num_desired); ExecContext* ctx = default_exec_context(); - std::vector val_descrs; + std::vector val_types; for (size_t i = 0; i < result.size(); ++i) { - val_descrs.push_back(ValueDescr(result[i]->type(), ValueDescr::ARRAY)); + val_types.push_back(result[i]->type()); } internal::RowEncoder encoder; - encoder.Init(val_descrs, ctx); + encoder.Init(val_types, ctx); ExecBatch batch({}, num_desired); batch.values.resize(result.size()); for (size_t i = 0; i < result.size(); ++i) { diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 9efa6623e5a..f67d541e1ea 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -1133,12 +1133,11 @@ TEST(ExecPlanExecution, SourceScalarAggSink) { }) .AddToPlan(plan.get())); - ASSERT_THAT( - StartAndCollect(plan.get(), sink_gen), - Finishes(ResultWith(UnorderedElementsAreArray({ - ExecBatchFromJSON({ValueDescr::Scalar(int64()), ValueDescr::Scalar(boolean())}, - "[[22, true]]"), - })))); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray({ + ExecBatchFromJSON({int64(), boolean()}, + {ArgShape::SCALAR, ArgShape::SCALAR}, "[[22, true]]"), + })))); } TEST(ExecPlanExecution, AggregationPreservesOptions) { @@ -1168,7 +1167,7 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) { ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), Finishes(ResultWith(UnorderedElementsAreArray({ - ExecBatchFromJSON({ValueDescr::Array(float64())}, "[[5.5]]"), + ExecBatchFromJSON({float64()}, "[[5.5]]"), })))); } { @@ -1209,7 +1208,7 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { BatchesWithSchema scalar_data; scalar_data.batches = { - ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Scalar(boolean())}, + ExecBatchFromJSON({int32(), boolean()}, {ArgShape::SCALAR, ArgShape::SCALAR}, "[[5, false], [5, 
false], [5, false]]"), ExecBatchFromJSON({int32(), boolean()}, "[[5, true], [6, false], [7, true]]")}; scalar_data.schema = schema({field("a", int32()), field("b", boolean())}); @@ -1239,11 +1238,11 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { StartAndCollect(plan.get(), sink_gen), Finishes(ResultWith(UnorderedElementsAreArray({ ExecBatchFromJSON( - {ValueDescr::Scalar(boolean()), ValueDescr::Scalar(boolean()), - ValueDescr::Scalar(int64()), ValueDescr::Scalar(float64()), - ValueDescr::Scalar(int64()), ValueDescr::Scalar(float64()), - ValueDescr::Scalar(int64()), ValueDescr::Array(float64()), - ValueDescr::Scalar(float64())}, + {boolean(), boolean(), int64(), float64(), int64(), float64(), int64(), + float64(), float64()}, + {ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, + ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::ARRAY, + ArgShape::SCALAR}, R"([[false, true, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])"), })))); } @@ -1255,9 +1254,9 @@ TEST(ExecPlanExecution, ScalarSourceGroupedSum) { BatchesWithSchema scalar_data; scalar_data.batches = { - ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())}, + ExecBatchFromJSON({int32(), boolean()}, {ArgShape::ARRAY, ArgShape::SCALAR}, "[[5, false], [6, false], [7, false]]"), - ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())}, + ExecBatchFromJSON({int32(), boolean()}, {ArgShape::ARRAY, ArgShape::SCALAR}, "[[1, true], [2, true], [3, true]]"), }; scalar_data.schema = schema({field("a", int32()), field("b", boolean())}); diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc index cad8d7c45ae..de01899b485 100644 --- a/cpp/src/arrow/compute/exec/project_node.cc +++ b/cpp/src/arrow/compute/exec/project_node.cc @@ -67,7 +67,7 @@ class ProjectNode : public MapNode { ARROW_ASSIGN_OR_RAISE( expr, expr.Bind(*inputs[0]->output_schema(), plan->exec_context())); } - fields[i] = field(std::move(names[i]), expr.type()); + fields[i] = field(std::move(names[i]), expr.type()->GetSharedPtr()); ++i; } return plan->EmplaceNode(plan, std::move(inputs), @@ -82,7 +82,7 @@ class ProjectNode : public MapNode { for (size_t i = 0; i < exprs_.size(); ++i) { util::tracing::Span span; START_COMPUTE_SPAN(span, "Project", - {{"project.descr", exprs_[i].descr().ToString()}, + {{"project.type", exprs_[i].type().ToString()}, {"project.length", target.length}, {"project.expression", exprs_[i].ToString()}}); ARROW_ASSIGN_OR_RAISE(Expression simplified_expr, diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 1e09cb742fa..330ee471126 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -143,16 +143,25 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector& descrs, +ExecBatch ExecBatchFromJSON(const std::vector& types, util::string_view json) { auto fields = ::arrow::internal::MapVector( - [](const ValueDescr& descr) { return field("", descr.type); }, descrs); + [](const TypeHolder& th) { return field("", th.GetSharedPtr()); }, types); ExecBatch batch{*RecordBatchFromJSON(schema(std::move(fields)), json)}; + return batch; +} + +ExecBatch ExecBatchFromJSON(const std::vector& types, + const std::vector& shapes, util::string_view json) { + DCHECK_EQ(types.size(), shapes.size()); + + ExecBatch batch = ExecBatchFromJSON(types, json); + auto value_it = batch.values.begin(); - for (const auto& descr : descrs) { - if (descr.shape == 
ValueDescr::SCALAR) { + for (ArgShape shape : shapes) { + if (shape == ArgShape::SCALAR) { if (batch.length == 0) { *value_it = MakeNullScalar(value_it->type()); } else { @@ -232,13 +241,13 @@ BatchesWithSchema MakeBatchesFromString( const std::vector& json_strings, int multiplicity) { BatchesWithSchema out_batches{{}, schema}; - std::vector descrs; + std::vector types; for (auto&& field : schema->fields()) { - descrs.emplace_back(field->type()); + types.emplace_back(field->type()); } for (auto&& s : json_strings) { - out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + out_batches.batches.push_back(ExecBatchFromJSON(types, s)); } size_t batch_count = out_batches.batches.size(); diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index ba7e4bb3411..64f725deafd 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -27,6 +27,7 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/kernel.h" #include "arrow/testing/visibility.h" #include "arrow/util/async_generator.h" #include "arrow/util/pcg_random.h" @@ -44,8 +45,11 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector& descrs, - util::string_view json); +ExecBatch ExecBatchFromJSON(const std::vector& types, util::string_view json); + +ARROW_TESTING_EXPORT +ExecBatch ExecBatchFromJSON(const std::vector& types, + const std::vector& shapes, util::string_view json); struct BatchesWithSchema { std::vector batches; diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h index c475a61c1ba..afca289c20e 100644 --- a/cpp/src/arrow/compute/exec_internal.h +++ b/cpp/src/arrow/compute/exec_internal.h @@ -84,8 +84,7 @@ class ARROW_EXPORT ExecSpanIterator { /// \param[in] batch the input ExecBatch /// \param[in] max_chunksize the maximum length of each ExecSpan. Depending /// on the chunk layout of ChunkedArray. - Status Init(const ExecBatch& batch, ValueDescr::Shape output_shape = ValueDescr::ARRAY, - int64_t max_chunksize = kDefaultMaxChunksize); + Status Init(const ExecBatch& batch, int64_t max_chunksize = kDefaultMaxChunksize); /// \brief Compute the next span by updating the state of the /// previous span object. You must keep passing in the previous @@ -101,6 +100,8 @@ class ARROW_EXPORT ExecSpanIterator { int64_t length() const { return length_; } int64_t position() const { return position_; } + bool have_all_scalars() const { return have_all_scalars_; } + private: ExecSpanIterator(const std::vector& args, int64_t length, int64_t max_chunksize); @@ -108,6 +109,7 @@ class ARROW_EXPORT ExecSpanIterator { bool initialized_ = false; bool have_chunked_arrays_ = false; + bool have_all_scalars_ = false; const std::vector* args_; std::vector chunk_indexes_; std::vector value_positions_; @@ -117,8 +119,8 @@ class ARROW_EXPORT ExecSpanIterator { // from the relative position within each chunk (which is in // value_positions_) std::vector value_offsets_; - int64_t position_; - int64_t length_; + int64_t position_ = 0; + int64_t length_ = 0; int64_t max_chunksize_; }; @@ -147,11 +149,6 @@ class DatumAccumulator : public ExecListener { std::vector values_; }; -/// \brief Check that each Datum is of a "value" type, which means either -/// SCALAR, ARRAY, or CHUNKED_ARRAY. 
If there are chunked inputs, then these -/// inputs will be split into non-chunked ExecBatch values for execution -Status CheckAllValues(const std::vector& values); - class ARROW_EXPORT KernelExecutor { public: virtual ~KernelExecutor() = default; diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index bd344fb2297..573f4aee4a0 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -728,10 +728,10 @@ TEST_F(TestExecBatchIterator, Basics) { ASSERT_EQ(3, batch.num_values()); ASSERT_EQ(length, batch.length); - std::vector descrs = batch.GetDescriptors(); - ASSERT_EQ(ValueDescr::Array(int32()), descrs[0]); - ASSERT_EQ(ValueDescr::Array(float64()), descrs[1]); - ASSERT_EQ(ValueDescr::Scalar(int32()), descrs[2]); + std::vector types = batch.GetTypes(); + ASSERT_EQ(types[0], int32()); + ASSERT_EQ(types[1], float64()); + ASSERT_EQ(types[2], int32()); AssertArraysEqual(*args[0].make_array(), *batch[0].make_array()); AssertArraysEqual(*args[1].make_array(), *batch[1].make_array()); @@ -795,13 +795,12 @@ TEST_F(TestExecBatchIterator, ZeroLengthInputs) { class TestExecSpanIterator : public TestComputeInternals { public: void SetupIterator(const ExecBatch& batch, - ValueDescr::Shape output_shape = ValueDescr::ARRAY, int64_t max_chunksize = kDefaultMaxChunksize) { - ASSERT_OK(iterator_.Init(batch, output_shape, max_chunksize)); + ASSERT_OK(iterator_.Init(batch, max_chunksize)); } void CheckIteration(const ExecBatch& input, int chunksize, const std::vector& ex_batch_sizes) { - SetupIterator(input, ValueDescr::ARRAY, chunksize); + SetupIterator(input, chunksize); ExecSpan batch; int64_t position = 0; for (size_t i = 0; i < ex_batch_sizes.size(); ++i) { @@ -902,8 +901,10 @@ TEST_F(TestExecSpanIterator, ZeroLengthInputs) { auto CheckArgs = [&](const ExecBatch& batch) { ExecSpanIterator iterator; - ASSERT_OK(iterator.Init(batch, ValueDescr::ARRAY)); + ASSERT_OK(iterator.Init(batch)); ExecSpan iter_span; + ASSERT_TRUE(iterator.Next(&iter_span)); + ASSERT_EQ(0, iter_span.length); ASSERT_FALSE(iterator.Next(&iter_span)); }; @@ -1045,11 +1046,13 @@ Status ExecStateful(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) return Status::OK(); } -// TODO: remove this / refactor it in ARROW-16577 Status ExecAddInt32(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - const Int32Scalar& arg0 = batch[0].scalar_as(); - const Int32Scalar& arg1 = batch[1].scalar_as(); - out->value = std::make_shared(arg0.value + arg1.value); + const int32_t* left_data = batch[0].array.GetValues(1); + const int32_t* right_data = batch[1].array.GetValues(1); + int32_t* out_data = out->array_span()->GetValues(1); + for (int64_t i = 0; i < batch.length; ++i) { + *out_data++ = *left_data++ + *right_data++; + } return Status::OK(); } @@ -1078,16 +1081,15 @@ class TestCallScalarFunction : public TestComputeInternals { /*doc=*/FunctionDoc::Empty()); // Add a few kernels. 
Our implementation only accepts arrays - ASSERT_OK(func->AddKernel({InputType::Array(uint8())}, uint8(), ExecCopyArraySpan)); - ASSERT_OK(func->AddKernel({InputType::Array(int32())}, int32(), ExecCopyArraySpan)); - ASSERT_OK( - func->AddKernel({InputType::Array(float64())}, float64(), ExecCopyArraySpan)); + ASSERT_OK(func->AddKernel({uint8()}, uint8(), ExecCopyArraySpan)); + ASSERT_OK(func->AddKernel({int32()}, int32(), ExecCopyArraySpan)); + ASSERT_OK(func->AddKernel({float64()}, float64(), ExecCopyArraySpan)); ASSERT_OK(registry->AddFunction(func)); // A version which doesn't want the executor to call PropagateNulls auto func2 = std::make_shared( "test_copy_computed_bitmap", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); - ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecComputedBitmap); + ScalarKernel kernel({uint8()}, uint8(), ExecComputedBitmap); kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; ASSERT_OK(func2->AddKernel(kernel)); ASSERT_OK(registry->AddFunction(func2)); @@ -1103,7 +1105,7 @@ class TestCallScalarFunction : public TestComputeInternals { auto f2 = std::make_shared( "test_nopre_validity_or_data", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); - ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecNoPreallocatedData); + ScalarKernel kernel({uint8()}, uint8(), ExecNoPreallocatedData); kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; ASSERT_OK(f1->AddKernel(kernel)); @@ -1123,7 +1125,7 @@ class TestCallScalarFunction : public TestComputeInternals { auto func = std::make_shared("test_stateful", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); - ScalarKernel kernel({InputType::Array(int32())}, int32(), ExecStateful, InitStateful); + ScalarKernel kernel({int32()}, int32(), ExecStateful, InitStateful); ASSERT_OK(func->AddKernel(kernel)); ASSERT_OK(registry->AddFunction(func)); } @@ -1133,8 +1135,7 @@ class TestCallScalarFunction : public TestComputeInternals { auto func = std::make_shared("test_scalar_add_int32", Arity::Binary(), /*doc=*/FunctionDoc::Empty()); - ASSERT_OK(func->AddKernel({InputType::Scalar(int32()), InputType::Scalar(int32())}, - int32(), ExecAddInt32)); + ASSERT_OK(func->AddKernel({int32(), int32()}, int32(), ExecAddInt32)); ASSERT_OK(registry->AddFunction(func)); } }; @@ -1154,8 +1155,9 @@ TEST_F(TestCallScalarFunction, ArgumentValidation) { ASSERT_RAISES(Invalid, CallFunction("test_copy", args)); // Cannot do scalar - args = {Datum(std::make_shared(5))}; - ASSERT_RAISES(NotImplemented, CallFunction("test_copy", args)); + Datum d1_scalar(std::make_shared(5)); + ASSERT_OK_AND_ASSIGN(auto result, CallFunction("test_copy", {d1})); + ASSERT_OK_AND_ASSIGN(result, CallFunction("test_copy", {d1_scalar})); } TEST_F(TestCallScalarFunction, PreallocationCases) { diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index b5ebc67d180..dd67de023e8 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -79,51 +79,35 @@ static const FunctionDoc kEmptyFunctionDoc{}; const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; } -static Status CheckArityImpl(const Function& function, int passed_num_args, - const char* passed_num_args_label) { - if (function.arity().is_varargs && passed_num_args < function.arity().num_args) { - return Status::Invalid("VarArgs function '", function.name(), "' needs at least ", - function.arity().num_args, " arguments but ", - passed_num_args_label, " only ", passed_num_args); +static Status CheckArityImpl(const Function& func, int 
num_args) { + if (func.arity().is_varargs && num_args < func.arity().num_args) { + return Status::Invalid("VarArgs function '", func.name(), "' needs at least ", + func.arity().num_args, " arguments but only ", num_args, + " passed"); } - if (!function.arity().is_varargs && passed_num_args != function.arity().num_args) { - return Status::Invalid("Function '", function.name(), "' accepts ", - function.arity().num_args, " arguments but ", - passed_num_args_label, " ", passed_num_args); + if (!func.arity().is_varargs && num_args != func.arity().num_args) { + return Status::Invalid("Function '", func.name(), "' accepts ", func.arity().num_args, + " arguments but ", num_args, " passed"); } - return Status::OK(); } -Status Function::CheckArity(const std::vector& in_types) const { - return CheckArityImpl(*this, static_cast(in_types.size()), "kernel accepts"); -} - -Status Function::CheckArity(const std::vector& descrs) const { - return CheckArityImpl(*this, static_cast(descrs.size()), - "attempted to look up kernel(s) with"); -} - -static Status CheckOptions(const Function& function, const FunctionOptions* options) { - if (options == nullptr && function.doc().options_required) { - return Status::Invalid("Function '", function.name(), - "' cannot be called without options"); - } - return Status::OK(); +Status Function::CheckArity(size_t num_args) const { + return CheckArityImpl(*this, static_cast(num_args)); } namespace detail { -Status NoMatchingKernel(const Function* func, const std::vector& descrs) { +Status NoMatchingKernel(const Function* func, const std::vector& types) { return Status::NotImplemented("Function '", func->name(), "' has no kernel matching input types ", - ValueDescr::ToString(descrs)); + TypeHolder::ToString(types)); } template const KernelType* DispatchExactImpl(const std::vector& kernels, - const std::vector& values) { + const std::vector& values) { const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr}; // Validate arity @@ -159,7 +143,7 @@ const KernelType* DispatchExactImpl(const std::vector& kernels, } const Kernel* DispatchExactImpl(const Function* func, - const std::vector& values) { + const std::vector& values) { if (func->kind() == Function::SCALAR) { return DispatchExactImpl(checked_cast(func)->kernels(), values); @@ -186,11 +170,11 @@ const Kernel* DispatchExactImpl(const Function* func, } // namespace detail Result Function::DispatchExact( - const std::vector& values) const { + const std::vector& values) const { if (kind_ == Function::META) { return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); } - RETURN_NOT_OK(CheckArity(values)); + RETURN_NOT_OK(CheckArity(values.size())); if (auto kernel = detail::DispatchExactImpl(this, values)) { return kernel; @@ -198,32 +182,44 @@ Result Function::DispatchExact( return detail::NoMatchingKernel(this, values); } -Result Function::DispatchBest(std::vector* values) const { +Result Function::DispatchBest(std::vector* values) const { // TODO(ARROW-11508) permit generic conversions here return DispatchExact(*values); } -Result Function::Execute(const std::vector& args, - const FunctionOptions* options, ExecContext* ctx) const { - return ExecuteInternal(args, /*passed_length=*/-1, options, ctx); +namespace { + +/// \brief Check that each Datum is of a "value" type, which means either +/// SCALAR, ARRAY, or CHUNKED_ARRAY. 
+Status CheckAllValues(const std::vector& values) { + for (const auto& value : values) { + if (!value.is_value()) { + return Status::Invalid("Tried executing function with non-value type: ", + value.ToString()); + } + } + return Status::OK(); } -Result Function::Execute(const ExecBatch& batch, const FunctionOptions* options, - ExecContext* ctx) const { - return ExecuteInternal(batch.values, batch.length, options, ctx); +Status CheckOptions(const Function& function, const FunctionOptions* options) { + if (options == nullptr && function.doc().options_required) { + return Status::Invalid("Function '", function.name(), + "' cannot be called without options"); + } + return Status::OK(); } -Result Function::ExecuteInternal(const std::vector& args, - int64_t passed_length, - const FunctionOptions* options, - ExecContext* ctx) const { +Result ExecuteInternal(const Function& func, std::vector args, + int64_t passed_length, const FunctionOptions* options, + ExecContext* ctx) { + std::unique_ptr default_ctx; if (options == nullptr) { - RETURN_NOT_OK(CheckOptions(*this, options)); - options = default_options(); + RETURN_NOT_OK(CheckOptions(func, options)); + options = func.default_options(); } if (ctx == nullptr) { - ExecContext default_ctx; - return ExecuteInternal(args, passed_length, options, &default_ctx); + default_ctx.reset(new ExecContext()); + ctx = default_ctx.get(); } util::tracing::Span span; @@ -235,38 +231,45 @@ Result Function::ExecuteInternal(const std::vector& args, // type-check Datum arguments here. Really we'd like to avoid this as much as // possible - RETURN_NOT_OK(detail::CheckAllValues(args)); - std::vector inputs(args.size()); + RETURN_NOT_OK(CheckAllValues(args)); + std::vector in_types(args.size()); for (size_t i = 0; i != args.size(); ++i) { - inputs[i] = args[i].descr(); + in_types[i] = args[i].type().get(); } std::unique_ptr executor; - if (kind() == Function::SCALAR) { + if (func.kind() == Function::SCALAR) { executor = detail::KernelExecutor::MakeScalar(); - } else if (kind() == Function::VECTOR) { + } else if (func.kind() == Function::VECTOR) { executor = detail::KernelExecutor::MakeVector(); - } else if (kind() == Function::SCALAR_AGGREGATE) { + } else if (func.kind() == Function::SCALAR_AGGREGATE) { executor = detail::KernelExecutor::MakeScalarAggregate(); } else { return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions"); } - ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, DispatchBest(&inputs)); - ARROW_ASSIGN_OR_RAISE(std::vector args_with_casts, Cast(args, inputs, ctx)); + ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, func.DispatchBest(&in_types)); + + // Cast arguments if necessary + for (size_t i = 0; i != args.size(); ++i) { + if (in_types[i] != args[i].type()) { + ARROW_ASSIGN_OR_RAISE(args[i], Cast(args[i], CastOptions::Safe(in_types[i]), ctx)); + } + } - std::unique_ptr state; KernelContext kernel_ctx{ctx, kernel}; + + std::unique_ptr state; if (kernel->init) { - ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, inputs, options})); + ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, in_types, options})); kernel_ctx.SetState(state.get()); } - RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options})); + RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, in_types, options})); detail::DatumAccumulator listener; - ExecBatch input(std::move(args_with_casts), /*length=*/0); + ExecBatch input(std::move(args), /*length=*/0); if (input.num_values() == 0) { if (passed_length != -1) { input.length = passed_length; @@ 
-275,9 +278,9 @@ Result Function::ExecuteInternal(const std::vector& args, bool all_same_length = false; int64_t inferred_length = detail::InferBatchLength(input.values, &all_same_length); input.length = inferred_length; - if (kind() == Function::SCALAR) { + if (func.kind() == Function::SCALAR) { DCHECK(passed_length == -1 || passed_length == inferred_length); - } else if (kind() == Function::VECTOR) { + } else if (func.kind() == Function::VECTOR) { auto vkernel = static_cast(kernel); if (!(all_same_length || !vkernel->can_execute_chunkwise)) { return Status::Invalid("Vector kernel arguments must all be the same length"); @@ -287,12 +290,25 @@ Result Function::ExecuteInternal(const std::vector& args, RETURN_NOT_OK(executor->Execute(input, &listener)); const auto out = executor->WrapResults(input.values, listener.values()); #ifndef NDEBUG - DCHECK_OK(executor->CheckResultType(out, name_.c_str())); + DCHECK_OK(executor->CheckResultType(out, func.name().c_str())); #endif return out; } +} // namespace + +Result Function::Execute(const std::vector& args, + const FunctionOptions* options, ExecContext* ctx) const { + return ExecuteInternal(*this, args, /*passed_length=*/-1, options, ctx); +} + +Result Function::Execute(const ExecBatch& batch, const FunctionOptions* options, + ExecContext* ctx) const { + return ExecuteInternal(*this, batch.values, batch.length, options, ctx); +} + namespace { + Status ValidateFunctionSummary(const std::string& s) { if (s.find('\n') != s.npos) { return Status::Invalid("summary contains a newline"); @@ -347,7 +363,7 @@ Status Function::Validate() const { Status ScalarFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(in_types)); + RETURN_NOT_OK(CheckArity(in_types.size())); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -359,7 +375,7 @@ Status ScalarFunction::AddKernel(std::vector in_types, OutputType out } Status ScalarFunction::AddKernel(ScalarKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -369,7 +385,7 @@ Status ScalarFunction::AddKernel(ScalarKernel kernel) { Status VectorFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(in_types)); + RETURN_NOT_OK(CheckArity(in_types.size())); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -381,7 +397,7 @@ Status VectorFunction::AddKernel(std::vector in_types, OutputType out } Status VectorFunction::AddKernel(VectorKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -390,7 +406,7 @@ Status VectorFunction::AddKernel(VectorKernel kernel) { } Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts 
varargs but kernel signature does not"); } @@ -399,7 +415,7 @@ Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { } Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -410,8 +426,7 @@ Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) { Result MetaFunction::Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const { - RETURN_NOT_OK( - CheckArityImpl(*this, static_cast(args.size()), "attempted to Execute with")); + RETURN_NOT_OK(CheckArityImpl(*this, static_cast(args.size()))); RETURN_NOT_OK(CheckOptions(*this, options)); if (options == nullptr) { diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index c32c8766a91..7f2fba68caf 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -211,19 +211,19 @@ class ARROW_EXPORT Function { virtual int num_kernels() const = 0; /// \brief Return a kernel that can execute the function given the exact - /// argument types (without implicit type casts or scalar->array promotions). + /// argument types (without implicit type casts). /// /// NB: This function is overridden in CastFunction. - virtual Result DispatchExact( - const std::vector& values) const; + virtual Result DispatchExact(const std::vector& types) const; /// \brief Return a best-match kernel that can execute the function given the argument /// types, after implicit casts are applied. /// - /// \param[in,out] values Argument types. An element may be modified to indicate that - /// the returned kernel only approximately matches the input value descriptors; callers - /// are responsible for casting inputs to the type and shape required by the kernel. - virtual Result DispatchBest(std::vector* values) const; + /// \param[in,out] values Argument types. An element may be modified to + /// indicate that the returned kernel only approximately matches the input + /// value descriptors; callers are responsible for casting inputs to the type + /// required by the kernel. + virtual Result DispatchBest(std::vector* values) const; /// \brief Execute the function eagerly with the passed input arguments with /// kernel dispatch, batch iteration, and memory allocation details taken @@ -255,11 +255,7 @@ class ARROW_EXPORT Function { doc_(std::move(doc)), default_options_(default_options) {} - Result ExecuteInternal(const std::vector& args, int64_t passed_length, - const FunctionOptions* options, ExecContext* ctx) const; - - Status CheckArity(const std::vector&) const; - Status CheckArity(const std::vector&) const; + Status CheckArity(size_t num_args) const; std::string name_; Function::Kind kind_; @@ -294,11 +290,11 @@ class FunctionImpl : public Function { /// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. ARROW_EXPORT -const Kernel* DispatchExactImpl(const Function* func, const std::vector&); +const Kernel* DispatchExactImpl(const Function* func, const std::vector&); /// \brief Return an error message if no Kernel is found. 
ARROW_EXPORT -Status NoMatchingKernel(const Function* func, const std::vector&); +Status NoMatchingKernel(const Function* func, const std::vector&); } // namespace detail diff --git a/cpp/src/arrow/compute/function_benchmark.cc b/cpp/src/arrow/compute/function_benchmark.cc index b508ad047fb..bdd0bb6e986 100644 --- a/cpp/src/arrow/compute/function_benchmark.cc +++ b/cpp/src/arrow/compute/function_benchmark.cc @@ -19,6 +19,7 @@ #include "arrow/array/array_base.h" #include "arrow/compute/api.h" +#include "arrow/compute/cast_internal.h" #include "arrow/compute/exec_internal.h" #include "arrow/memory_pool.h" #include "arrow/scalar.h" @@ -67,14 +68,13 @@ void BM_CastDispatchBaseline(benchmark::State& state) { // Repeatedly invoke a trivial Cast with all dispatch outside the hot loop random::RandomArrayGenerator rag(kSeed); - auto int_scalars = ToScalars(rag.Int64(kScalarCount, 0, 1 << 20)); - + auto int_array = rag.Int64(1, 0, 1 << 20); auto double_type = float64(); CastOptions cast_options; cast_options.to_type = double_type; - ASSERT_OK_AND_ASSIGN(auto cast_function, GetCastFunction(double_type)); + ASSERT_OK_AND_ASSIGN(auto cast_function, internal::GetCastFunction(double_type)); ASSERT_OK_AND_ASSIGN(auto cast_kernel, - cast_function->DispatchExact({int_scalars[0]->type})); + cast_function->DispatchExact({int_array->type()})); const auto& exec = static_cast(cast_kernel)->exec; ExecContext exec_context; @@ -85,15 +85,13 @@ void BM_CastDispatchBaseline(benchmark::State& state) { .ValueOrDie(); kernel_context.SetState(cast_state.get()); - ExecSpan input; - input.length = 1; + ExecSpan input({ExecValue(*int_array->data())}, 1); + ExecResult result; + ASSERT_OK_AND_ASSIGN(std::shared_ptr result_space, + MakeArrayOfNull(double_type, 1)); + result.array_span()->SetMembers(*result_space->data()); for (auto _ : state) { - ExecResult result; - result.value = MakeNullScalar(double_type); - for (const std::shared_ptr& int_scalar : int_scalars) { - input.values = {ExecValue(int_scalar.get())}; - ABORT_NOT_OK(exec(&kernel_context, input, &result)); - } + ABORT_NOT_OK(exec(&kernel_context, input, &result)); } state.SetItemsProcessed(state.iterations() * kScalarCount); @@ -153,31 +151,26 @@ void BM_ExecuteScalarFunctionOnScalar(benchmark::State& state) { void BM_ExecuteScalarKernelOnScalar(benchmark::State& state) { // Execute a trivial function, with argument dispatch outside the hot path - const int64_t N = 10000; - auto function = *GetFunctionRegistry()->GetFunction("is_valid"); - auto kernel = *function->DispatchExact({ValueDescr::Scalar(int64())}); + auto kernel = *function->DispatchExact({int64()}); const auto& exec = static_cast(*kernel).exec; - const auto scalars = MakeScalarsForIsValid(N); - ExecContext exec_context; KernelContext kernel_context(&exec_context); - ExecSpan input; - input.length = 1; + ASSERT_OK_AND_ASSIGN(std::shared_ptr input_arr, MakeArrayOfNull(int64(), 1)); + ExecSpan input({*input_arr->data()}, 1); + + ExecResult output; + ASSERT_OK_AND_ASSIGN(std::shared_ptr output_arr, MakeArrayOfNull(int64(), 1)); + output.array_span()->SetMembers(*output_arr->data()); + + const int64_t N = 10000; for (auto _ : state) { - int64_t total = 0; - for (const std::shared_ptr& scalar : scalars) { - ExecResult result; - result.value = MakeNullScalar(int64()); - input.values = {scalar.get()}; - ABORT_NOT_OK(exec(&kernel_context, input, &result)); - total += result.scalar()->is_valid; + for (int i = 0; i < N; ++i) { + ABORT_NOT_OK(exec(&kernel_context, input, &output)); } - 
benchmark::DoNotOptimize(total); } - state.SetItemsProcessed(state.iterations() * N); } diff --git a/cpp/src/arrow/compute/function_internal.h b/cpp/src/arrow/compute/function_internal.h index f2303b87d90..17261332619 100644 --- a/cpp/src/arrow/compute/function_internal.h +++ b/cpp/src/arrow/compute/function_internal.h @@ -345,6 +345,10 @@ static inline Result> GenericToScalar( return MakeNullScalar(value); } +static inline Result> GenericToScalar(const TypeHolder& value) { + return GenericToScalar(value.GetSharedPtr()); +} + static inline Result> GenericToScalar( const std::shared_ptr& value) { return value; @@ -430,6 +434,12 @@ static inline enable_if_same_result> GenericFromSca return value->type; } +template +static inline enable_if_same_result GenericFromScalar( + const std::shared_ptr& value) { + return value->type; +} + template static inline enable_if_same_result> GenericFromScalar( const std::shared_ptr& value) { diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index f06f225f5b9..94daa6baa96 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -230,9 +230,9 @@ void CheckAddDispatch(FunctionType* func, ExecType exec) { // Duplicate sig is okay ASSERT_OK(func->AddKernel(in_types1, out_type1, exec)); - // Add given a descr - KernelType descr({float64(), float64()}, float64(), exec); - ASSERT_OK(func->AddKernel(descr)); + // Add a kernel + KernelType kernel({float64(), float64()}, float64(), exec); + ASSERT_OK(func->AddKernel(kernel)); ASSERT_EQ(4, func->num_kernels()); ASSERT_EQ(4, func->kernels().size()); @@ -249,9 +249,9 @@ void CheckAddDispatch(FunctionType* func, ExecType exec) { KernelType invalid_kernel({boolean()}, boolean(), exec); ASSERT_RAISES(Invalid, func->AddKernel(invalid_kernel)); - ASSERT_OK_AND_ASSIGN(const Kernel* kernel, func->DispatchExact({int32(), int32()})); + ASSERT_OK_AND_ASSIGN(const Kernel* dispatched, func->DispatchExact({int32(), int32()})); KernelSignature expected_sig(in_types1, out_type1); - ASSERT_TRUE(kernel->signature->Equals(expected_sig)); + ASSERT_TRUE(dispatched->signature->Equals(expected_sig)); // No kernel available ASSERT_RAISES(NotImplemented, func->DispatchExact({utf8(), utf8()})); @@ -288,7 +288,7 @@ TEST(ArrayFunction, VarArgs) { ScalarKernel non_va_kernel(std::make_shared(va_args, int8()), ExecNYI); ASSERT_RAISES(Invalid, va_func.AddKernel(non_va_kernel)); - std::vector args = {ValueDescr::Scalar(int8()), int8(), int8()}; + std::vector args = {int8(), int8(), int8()}; ASSERT_OK_AND_ASSIGN(const Kernel* kernel, va_func.DispatchExact(args)); ASSERT_TRUE(kernel->signature->MatchesInputs(args)); @@ -319,7 +319,7 @@ Status NoopFinalize(KernelContext*, Datum*) { return Status::OK(); } TEST(ScalarAggregateFunction, DispatchExact) { ScalarAggregateFunction func("agg_test", Arity::Unary(), FunctionDoc::Empty()); - std::vector in_args = {ValueDescr::Array(int8())}; + std::vector in_args = {int8()}; ScalarAggregateKernel kernel(std::move(in_args), int64(), NoopInit, NoopConsume, NoopMerge, NoopFinalize); ASSERT_OK(func.AddKernel(kernel)); @@ -341,18 +341,14 @@ TEST(ScalarAggregateFunction, DispatchExact) { kernel.signature = std::make_shared(in_args, float64()); ASSERT_RAISES(Invalid, func.AddKernel(kernel)); - std::vector dispatch_args = {ValueDescr::Array(int8())}; + std::vector dispatch_args = {int8()}; ASSERT_OK_AND_ASSIGN(const Kernel* selected_kernel, func.DispatchExact(dispatch_args)); ASSERT_EQ(func.kernels()[0], selected_kernel); 
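Since dispatch now traffics in TypeHolder rather than ValueDescr, a brief usage sketch may help; it is illustrative only, and the "add" lookup is an assumption about the default registry rather than part of this patch:

// TypeHolder converts implicitly from std::shared_ptr<DataType>, so argument
// types can be collected and handed straight to DispatchExact/DispatchBest.
std::vector<TypeHolder> arg_types = {int32(), int32()};
ASSERT_OK_AND_ASSIGN(std::shared_ptr<Function> add_fn,
                     GetFunctionRegistry()->GetFunction("add"));
ASSERT_OK_AND_ASSIGN(const Kernel* add_kernel, add_fn->DispatchExact(arg_types));
ASSERT_TRUE(add_kernel->signature->MatchesInputs(arg_types));
// When shared ownership is needed again (e.g. in a kernel init),
// TypeHolder::GetSharedPtr() recovers the std::shared_ptr<DataType>.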
ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); - // We declared that only arrays are accepted - dispatch_args[0] = {ValueDescr::Scalar(int8())}; - ASSERT_RAISES(NotImplemented, func.DispatchExact(dispatch_args)); - // Didn't qualify the float64() kernel so this actually dispatches (even // though that may not be what you want) - dispatch_args[0] = {ValueDescr::Scalar(float64())}; + dispatch_args[0] = {float64()}; ASSERT_OK_AND_ASSIGN(selected_kernel, func.DispatchExact(dispatch_args)); ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); } diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 909c2399c8e..1e3303473ef 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -87,7 +87,9 @@ class SameTypeIdMatcher : public TypeMatcher { public: explicit SameTypeIdMatcher(Type::type accepted_id) : accepted_id_(accepted_id) {} - bool Matches(const DataType& type) const override { return type.id() == accepted_id_; } + bool Matches(const TypeHolder& type) const override { + return type.id() == accepted_id_; + } std::string ToString() const override { std::stringstream ss; @@ -122,11 +124,11 @@ class TimeUnitMatcher : public TypeMatcher { explicit TimeUnitMatcher(TimeUnit::type accepted_unit) : accepted_unit_(accepted_unit) {} - bool Matches(const DataType& type) const override { + bool Matches(const TypeHolder& type) const override { if (type.id() != ArrowType::type_id) { return false; } - const auto& time_type = checked_cast(type); + const auto& time_type = checked_cast(*type.type); return time_type.unit() == accepted_unit_; } @@ -177,7 +179,7 @@ class IntegerMatcher : public TypeMatcher { public: IntegerMatcher() {} - bool Matches(const DataType& type) const override { return is_integer(type.id()); } + bool Matches(const TypeHolder& type) const override { return is_integer(type.id()); } bool Equals(const TypeMatcher& other) const override { if (this == &other) { @@ -196,7 +198,7 @@ class PrimitiveMatcher : public TypeMatcher { public: PrimitiveMatcher() {} - bool Matches(const DataType& type) const override { return is_primitive(type.id()); } + bool Matches(const TypeHolder& type) const override { return is_primitive(type.id()); } bool Equals(const TypeMatcher& other) const override { if (this == &other) { @@ -215,7 +217,9 @@ class BinaryLikeMatcher : public TypeMatcher { public: BinaryLikeMatcher() {} - bool Matches(const DataType& type) const override { return is_binary_like(type.id()); } + bool Matches(const TypeHolder& type) const override { + return is_binary_like(type.id()); + } bool Equals(const TypeMatcher& other) const override { if (this == &other) { @@ -235,7 +239,7 @@ class LargeBinaryLikeMatcher : public TypeMatcher { public: LargeBinaryLikeMatcher() {} - bool Matches(const DataType& type) const override { + bool Matches(const TypeHolder& type) const override { return is_large_binary_like(type.id()); } @@ -253,7 +257,7 @@ class FixedSizeBinaryLikeMatcher : public TypeMatcher { public: FixedSizeBinaryLikeMatcher() {} - bool Matches(const DataType& type) const override { + bool Matches(const TypeHolder& type) const override { return is_fixed_size_binary(type.id()); } @@ -282,7 +286,6 @@ std::shared_ptr FixedSizeBinaryLike() { size_t InputType::Hash() const { size_t result = kHashSeed; - hash_combine(result, static_cast(shape_)); hash_combine(result, static_cast(kind_)); switch (kind_) { case InputType::EXACT_TYPE: @@ -296,21 +299,6 @@ size_t InputType::Hash() const { std::string 
InputType::ToString() const { std::stringstream ss; - switch (shape_) { - case ValueDescr::ANY: - ss << "any"; - break; - case ValueDescr::ARRAY: - ss << "array"; - break; - case ValueDescr::SCALAR: - ss << "scalar"; - break; - default: - DCHECK(false); - break; - } - ss << "["; switch (kind_) { case InputType::ANY_TYPE: ss << "any"; @@ -325,7 +313,6 @@ std::string InputType::ToString() const { DCHECK(false); break; } - ss << "]"; return ss.str(); } @@ -333,7 +320,7 @@ bool InputType::Equals(const InputType& other) const { if (this == &other) { return true; } - if (kind_ != other.kind_ || shape_ != other.shape_) { + if (kind_ != other.kind_) { return false; } switch (kind_) { @@ -348,22 +335,30 @@ bool InputType::Equals(const InputType& other) const { } } -bool InputType::Matches(const ValueDescr& descr) const { - if (shape_ != ValueDescr::ANY && descr.shape != shape_) { - return false; - } +bool InputType::Matches(const TypeHolder& type) const { switch (kind_) { case InputType::EXACT_TYPE: - return type_->Equals(*descr.type); + return type_->Equals(*type.type); case InputType::USE_TYPE_MATCHER: - return type_matcher_->Matches(*descr.type); + return type_matcher_->Matches(type); default: // ANY_TYPE return true; } } -bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); } +bool InputType::Matches(const Datum& value) const { + switch (value.kind()) { + case Datum::ARRAY: + case Datum::CHUNKED_ARRAY: + case Datum::SCALAR: + break; + default: + DCHECK(false); + return false; + } + return Matches(value.type().get()); +} const std::shared_ptr& InputType::type() const { DCHECK_EQ(InputType::EXACT_TYPE, kind_); @@ -378,21 +373,12 @@ const TypeMatcher& InputType::type_matcher() const { // ---------------------------------------------------------------------- // OutputType -OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) { - shape_ = descr.shape; -} - -Result OutputType::Resolve(KernelContext* ctx, - const std::vector& args) const { - ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args); +Result OutputType::Resolve(KernelContext* ctx, + const std::vector& types) const { if (kind_ == OutputType::FIXED) { - return ValueDescr(type_, shape_ == ValueDescr::ANY ? 
broadcasted_shape : shape_); + return type_.get(); } else { - ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args)); - if (resolved_descr.shape == ValueDescr::ANY) { - resolved_descr.shape = broadcasted_shape; - } - return resolved_descr; + return resolver_(ctx, types); } } @@ -448,19 +434,19 @@ bool KernelSignature::Equals(const KernelSignature& other) const { return true; } -bool KernelSignature::MatchesInputs(const std::vector& args) const { +bool KernelSignature::MatchesInputs(const std::vector& types) const { if (is_varargs_) { - for (size_t i = 0; i < args.size(); ++i) { - if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) { + for (size_t i = 0; i < types.size(); ++i) { + if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(types[i])) { return false; } } } else { - if (args.size() != in_types_.size()) { + if (types.size() != in_types_.size()) { return false; } for (size_t i = 0; i < in_types_.size(); ++i) { - if (!in_types_[i].Matches(args[i])) { + if (!in_types_[i].Matches(types[i])) { return false; } } @@ -495,7 +481,7 @@ std::string KernelSignature::ToString() const { ss << in_types_[i].ToString(); } if (is_varargs_) { - ss << "]"; + ss << "*]"; } else { ss << ")"; } diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 93a1c605a99..1b412af525e 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -101,7 +101,7 @@ struct ARROW_EXPORT TypeMatcher { virtual ~TypeMatcher() = default; /// \brief Return true if this matcher accepts the data type. - virtual bool Matches(const DataType& type) const = 0; + virtual bool Matches(const TypeHolder& type) const = 0; /// \brief A human-interpretable string representation of what the type /// matcher checks for, usable when printing KernelSignature or formatting @@ -143,10 +143,14 @@ ARROW_EXPORT std::shared_ptr Primitive(); } // namespace match -/// \brief An object used for type- and shape-checking arguments to be passed -/// to a kernel and stored in a KernelSignature. Distinguishes between ARRAY -/// and SCALAR arguments using ValueDescr::Shape. The type-checking rule can be -/// supplied either with an exact DataType instance or a custom TypeMatcher. +/// \brief Shape qualifier for value types. In certain instances +/// (e.g. "map_lookup" kernel), an argument may only be a scalar, where in +/// other kernels arguments can be arrays or scalars +enum class ArgShape { ANY, ARRAY, SCALAR }; + +/// \brief An object used for type-checking arguments to be passed to a kernel +/// and stored in a KernelSignature. The type-checking rule can be supplied +/// either with an exact DataType instance or a custom TypeMatcher. class ARROW_EXPORT InputType { public: /// \brief The kind of type-checking rule that the InputType contains. @@ -163,29 +167,21 @@ class ARROW_EXPORT InputType { USE_TYPE_MATCHER }; - /// \brief Accept any value type but with a specific shape (e.g. any Array or - /// any Scalar). - InputType(ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction - : kind_(ANY_TYPE), shape_(shape) {} + /// \brief Accept any value type + InputType() : kind_(ANY_TYPE) {} /// \brief Accept an exact value type. - InputType(std::shared_ptr type, // NOLINT implicit construction - ValueDescr::Shape shape = ValueDescr::ANY) - : kind_(EXACT_TYPE), shape_(shape), type_(std::move(type)) {} - - /// \brief Accept an exact value type and shape provided by a ValueDescr. 
- InputType(const ValueDescr& descr) // NOLINT implicit construction - : InputType(descr.type, descr.shape) {} + InputType(std::shared_ptr type) // NOLINT implicit construction + : kind_(EXACT_TYPE), type_(std::move(type)) {} /// \brief Use the passed TypeMatcher to type check. - InputType(std::shared_ptr type_matcher, // NOLINT implicit construction - ValueDescr::Shape shape = ValueDescr::ANY) - : kind_(USE_TYPE_MATCHER), shape_(shape), type_matcher_(std::move(type_matcher)) {} + InputType(std::shared_ptr type_matcher) // NOLINT implicit construction + : kind_(USE_TYPE_MATCHER), type_matcher_(std::move(type_matcher)) {} /// \brief Match any type with the given Type::type. Uses a TypeMatcher for /// its implementation. - explicit InputType(Type::type type_id, ValueDescr::Shape shape = ValueDescr::ANY) - : InputType(match::SameTypeId(type_id), shape) {} + InputType(Type::type type_id) // NOLINT implicit construction + : InputType(match::SameTypeId(type_id)) {} InputType(const InputType& other) { CopyInto(other); } @@ -195,23 +191,8 @@ class ARROW_EXPORT InputType { void operator=(InputType&& other) { MoveInto(std::forward(other)); } - // \brief Match an array with the given exact type. Convenience constructor. - static InputType Array(std::shared_ptr type) { - return InputType(std::move(type), ValueDescr::ARRAY); - } - - // \brief Match a scalar with the given exact type. Convenience constructor. - static InputType Scalar(std::shared_ptr type) { - return InputType(std::move(type), ValueDescr::SCALAR); - } - - // \brief Match an array with the given Type::type id. Convenience - // constructor. - static InputType Array(Type::type id) { return InputType(id, ValueDescr::ARRAY); } - - // \brief Match a scalar with the given Type::type id. Convenience - // constructor. - static InputType Scalar(Type::type id) { return InputType(id, ValueDescr::SCALAR); } + // \brief Match any input (array, scalar of any type) + static InputType Any() { return InputType(); } /// \brief Return true if this input type matches the same type cases as the /// other. @@ -227,21 +208,16 @@ class ARROW_EXPORT InputType { /// \brief Render a human-readable string representation. std::string ToString() const; - /// \brief Return true if the value matches this argument kind in type - /// and shape. + /// \brief Return true if the Datum matches this argument kind in + /// type (and only allows scalar or array-like Datums). bool Matches(const Datum& value) const; - /// \brief Return true if the value descriptor matches this argument kind in - /// type and shape. - bool Matches(const ValueDescr& value) const; + /// \brief Return true if the type matches this InputType + bool Matches(const TypeHolder& type) const; /// \brief The type matching rule that this InputType uses. Kind kind() const { return kind_; } - /// \brief Indicates whether this InputType matches Array (ValueDescr::ARRAY), - /// Scalar (ValueDescr::SCALAR) values, or both (ValueDescr::ANY). - ValueDescr::Shape shape() const { return shape_; } - /// \brief For InputType::EXACT_TYPE kind, the exact type that this InputType /// must match. Otherwise this function should not be used and will assert in /// debug builds. 
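For a concrete picture of the slimmed-down types above: an InputType now checks type only (exact DataType, Type::type id, or TypeMatcher), and a computed OutputType is a plain function returning a TypeHolder. A minimal sketch in the spirit of the FirstType/MinMaxType resolvers used elsewhere in this patch (the EchoFirstType name is illustrative):

// Resolver: the output type is simply the first input type.
Result<TypeHolder> EchoFirstType(KernelContext*, const std::vector<TypeHolder>& types) {
  return types[0];
}

// (timestamp of any unit) -> same timestamp type
auto sig = KernelSignature::Make({InputType(Type::TIMESTAMP)}, OutputType(EchoFirstType));
// sig->MatchesInputs({timestamp(TimeUnit::MILLI)}) is true;
// sig->MatchesInputs({int32()}) is false.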
@@ -255,22 +231,18 @@ class ARROW_EXPORT InputType { private: void CopyInto(const InputType& other) { this->kind_ = other.kind_; - this->shape_ = other.shape_; this->type_ = other.type_; this->type_matcher_ = other.type_matcher_; } void MoveInto(InputType&& other) { this->kind_ = other.kind_; - this->shape_ = other.shape_; this->type_ = std::move(other.type_); this->type_matcher_ = std::move(other.type_matcher_); } Kind kind_; - ValueDescr::Shape shape_ = ValueDescr::ANY; - // For EXACT_TYPE Kind std::shared_ptr type_; @@ -279,43 +251,30 @@ class ARROW_EXPORT InputType { }; /// \brief Container to capture both exact and input-dependent output types. -/// -/// The value shape returned by Resolve will be determined by broadcasting the -/// shapes of the input arguments, otherwise this is handled by the -/// user-defined resolver function: -/// -/// * Any ARRAY shape -> output shape is ARRAY -/// * All SCALAR shapes -> output shape is SCALAR class ARROW_EXPORT OutputType { public: /// \brief An enum indicating whether the value type is an invariant fixed /// value or one that's computed by a kernel-defined resolver function. enum ResolveKind { FIXED, COMPUTED }; - /// Type resolution function. Given input types and shapes, return output - /// type and shape. This function MAY may use the kernel state to decide - /// the output type based on the functionoptions. + /// Type resolution function. Given input types, return output type. This + /// function MAY may use the kernel state to decide the output type based on + /// the FunctionOptions. /// /// This function SHOULD _not_ be used to check for arity, that is to be /// performed one or more layers above. - using Resolver = - std::function(KernelContext*, const std::vector&)>; + typedef Result (*Resolver)(KernelContext*, const std::vector&); - /// \brief Output an exact type, but with shape determined by promoting the - /// shapes of the inputs (any ARRAY argument yields ARRAY). + /// \brief Output an exact type OutputType(std::shared_ptr type) // NOLINT implicit construction : kind_(FIXED), type_(std::move(type)) {} - /// \brief Output the exact type and shape provided by a ValueDescr - OutputType(ValueDescr descr); // NOLINT implicit construction - /// \brief Output a computed type depending on actual input types OutputType(Resolver resolver) // NOLINT implicit construction : kind_(COMPUTED), resolver_(std::move(resolver)) {} OutputType(const OutputType& other) { this->kind_ = other.kind_; - this->shape_ = other.shape_; this->type_ = other.type_; this->resolver_ = other.resolver_; } @@ -323,19 +282,17 @@ class ARROW_EXPORT OutputType { OutputType(OutputType&& other) { this->kind_ = other.kind_; this->type_ = std::move(other.type_); - this->shape_ = other.shape_; this->resolver_ = other.resolver_; } OutputType& operator=(const OutputType&) = default; OutputType& operator=(OutputType&&) = default; - /// \brief Return the shape and type of the expected output value of the - /// kernel given the value descriptors (shapes and types) of the input - /// arguments. The resolver may make use of state information kept in the - /// KernelContext. - Result Resolve(KernelContext* ctx, - const std::vector& args) const; + /// \brief Return the type of the expected output value of the kernel given + /// the input argument types. The resolver may make use of state information + /// kept in the KernelContext. + Result Resolve(KernelContext* ctx, + const std::vector& args) const; /// \brief The exact output value type for the FIXED kind. 
const std::shared_ptr& type() const; @@ -352,20 +309,12 @@ class ARROW_EXPORT OutputType { /// fixed/invariant or computed by a resolver. ResolveKind kind() const { return kind_; } - /// \brief If the shape is ANY, then Resolve will compute the shape based on - /// the input arguments. - ValueDescr::Shape shape() const { return shape_; } - private: ResolveKind kind_; // For FIXED resolution std::shared_ptr type_; - /// \brief The shape of the output type to return when using Resolve. If ANY - /// will promote the input shapes. - ValueDescr::Shape shape_ = ValueDescr::ANY; - // For COMPUTED resolution Resolver resolver_; }; @@ -388,7 +337,7 @@ class ARROW_EXPORT KernelSignature { /// \brief Return true if the signature if compatible with the list of input /// value descriptors. - bool MatchesInputs(const std::vector& descriptors) const; + bool MatchesInputs(const std::vector& types) const; /// \brief Returns true if the input types of each signature are /// equal. Well-formed functions should have a deterministic output type @@ -408,9 +357,10 @@ class ARROW_EXPORT KernelSignature { /// function arguments. const std::vector& in_types() const { return in_types_; } - /// \brief The output type for the kernel. Use Resolve to return the exact - /// output given input argument ValueDescrs, since many kernels' output types - /// depend on their input types (or their type metadata). + /// \brief The output type for the kernel. Use Resolve to return the + /// exact output given input argument types, since many kernels' + /// output types depend on their input types (or their type + /// metadata). const OutputType& out_type() const { return out_type_; } /// \brief Render a human-readable string representation @@ -493,12 +443,9 @@ struct KernelInitArgs { /// depend on the kernel's KernelSignature or other data contained there. const Kernel* kernel; - /// \brief The types and shapes of the input arguments that the kernel is + /// \brief The types of the input arguments that the kernel is /// about to be executed against. - /// - /// TODO: should this be const std::vector*? const-ref is being - /// used to avoid the cost of copying the struct into the args struct. - const std::vector& inputs; + const std::vector& inputs; /// \brief Opaque options specific to this kernel. May be nullptr for functions /// that do not require options. @@ -523,7 +470,7 @@ struct Kernel { std::move(init)) {} /// \brief The "signature" of the kernel containing the InputType input - /// argument validators and OutputType output type and shape resolver. + /// argument validators and OutputType output type resolver. std::shared_ptr signature; /// \brief Create a new KernelState for invocations of this kernel, e.g. to @@ -546,6 +493,9 @@ struct Kernel { /// contain multiple kernels with the same signature but different levels of SIMD, /// so that the most optimized kernel supported on a host's processor can be chosen. SimdLevel::type simd_level = SimdLevel::NONE; + + // Additional kernel-specific data + std::shared_ptr data; }; /// \brief The scalar kernel execution API that must be implemented for SCALAR @@ -555,8 +505,7 @@ struct Kernel { /// endeavor to write into pre-allocated memory if they are able, though for /// some kernels (e.g. in cases when a builder like StringBuilder) must be /// employed this may not be possible. -using ArrayKernelExec = - std::function; +typedef Status (*ArrayKernelExec)(KernelContext*, const ExecSpan&, ExecResult*); /// \brief Kernel data structure for implementations of ScalarFunction. 
In /// addition to the members found in Kernel, contains the null handling @@ -566,12 +515,11 @@ struct ScalarKernel : public Kernel { ScalarKernel(std::shared_ptr sig, ArrayKernelExec exec, KernelInit init = NULLPTR) - : Kernel(std::move(sig), init), exec(std::move(exec)) {} + : Kernel(std::move(sig), init), exec(exec) {} ScalarKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init = NULLPTR) - : Kernel(std::move(in_types), std::move(out_type), std::move(init)), - exec(std::move(exec)) {} + : Kernel(std::move(in_types), std::move(out_type), std::move(init)), exec(exec) {} /// \brief Perform a single invocation of this kernel. Depending on the /// implementation, it may only write into preallocated memory, while in some @@ -590,9 +538,6 @@ struct ScalarKernel : public Kernel { // bitmaps is a reasonable default NullHandling::type null_handling = NullHandling::INTERSECTION; MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE; - - // Additional kernel-specific data - std::shared_ptr data; }; // ---------------------------------------------------------------------- @@ -615,13 +560,13 @@ struct VectorKernel : public Kernel { VectorKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init = NULLPTR, FinalizeFunc finalize = NULLPTR) : Kernel(std::move(in_types), std::move(out_type), std::move(init)), - exec(std::move(exec)), + exec(exec), finalize(std::move(finalize)) {} VectorKernel(std::shared_ptr sig, ArrayKernelExec exec, KernelInit init = NULLPTR, FinalizeFunc finalize = NULLPTR) : Kernel(std::move(sig), std::move(init)), - exec(std::move(exec)), + exec(exec), finalize(std::move(finalize)) {} /// \brief Perform a single invocation of this kernel. Any required state is diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index 2d427374426..d995cca354c 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -21,6 +21,7 @@ #include +#include "arrow/array/util.h" #include "arrow/compute/kernel.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" @@ -35,8 +36,8 @@ namespace compute { TEST(TypeMatcher, SameTypeId) { std::shared_ptr matcher = match::SameTypeId(Type::DECIMAL); - ASSERT_TRUE(matcher->Matches(*decimal(12, 2))); - ASSERT_FALSE(matcher->Matches(*int8())); + ASSERT_TRUE(matcher->Matches(decimal(12, 2))); + ASSERT_FALSE(matcher->Matches(int8())); ASSERT_EQ("Type::DECIMAL128", matcher->ToString()); @@ -49,11 +50,11 @@ TEST(TypeMatcher, TimestampTypeUnit) { auto matcher = match::TimestampTypeUnit(TimeUnit::MILLI); auto matcher2 = match::Time32TypeUnit(TimeUnit::MILLI); - ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI))); - ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI, "utc"))); - ASSERT_FALSE(matcher->Matches(*timestamp(TimeUnit::SECOND))); - ASSERT_FALSE(matcher->Matches(*time32(TimeUnit::MILLI))); - ASSERT_TRUE(matcher2->Matches(*time32(TimeUnit::MILLI))); + ASSERT_TRUE(matcher->Matches(timestamp(TimeUnit::MILLI))); + ASSERT_TRUE(matcher->Matches(timestamp(TimeUnit::MILLI, "utc"))); + ASSERT_FALSE(matcher->Matches(timestamp(TimeUnit::SECOND))); + ASSERT_FALSE(matcher->Matches(time32(TimeUnit::MILLI))); + ASSERT_TRUE(matcher2->Matches(time32(TimeUnit::MILLI))); // Check ToString representation ASSERT_EQ("timestamp(s)", match::TimestampTypeUnit(TimeUnit::SECOND)->ToString()); @@ -75,43 +76,23 @@ TEST(InputType, AnyTypeConstructor) { // Check the ANY_TYPE ctors InputType ty; ASSERT_EQ(InputType::ANY_TYPE, 
ty.kind()); - ASSERT_EQ(ValueDescr::ANY, ty.shape()); - - ty = InputType(ValueDescr::SCALAR); - ASSERT_EQ(ValueDescr::SCALAR, ty.shape()); - - ty = InputType(ValueDescr::ARRAY); - ASSERT_EQ(ValueDescr::ARRAY, ty.shape()); } TEST(InputType, Constructors) { // Exact type constructor InputType ty1(int8()); ASSERT_EQ(InputType::EXACT_TYPE, ty1.kind()); - ASSERT_EQ(ValueDescr::ANY, ty1.shape()); AssertTypeEqual(*int8(), *ty1.type()); InputType ty1_implicit = int8(); ASSERT_TRUE(ty1.Equals(ty1_implicit)); - InputType ty1_array(int8(), ValueDescr::ARRAY); - ASSERT_EQ(ValueDescr::ARRAY, ty1_array.shape()); - - InputType ty1_scalar(int8(), ValueDescr::SCALAR); - ASSERT_EQ(ValueDescr::SCALAR, ty1_scalar.shape()); - // Same type id constructor InputType ty2(Type::DECIMAL); ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind()); - ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); - ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2))); - ASSERT_FALSE(ty2.type_matcher().Matches(*int16())); - - InputType ty2_array(Type::DECIMAL, ValueDescr::ARRAY); - ASSERT_EQ(ValueDescr::ARRAY, ty2_array.shape()); - - InputType ty2_scalar(Type::DECIMAL, ValueDescr::SCALAR); - ASSERT_EQ(ValueDescr::SCALAR, ty2_scalar.shape()); + ASSERT_EQ("Type::DECIMAL128", ty2.ToString()); + ASSERT_TRUE(ty2.type_matcher().Matches(decimal(12, 2))); + ASSERT_FALSE(ty2.type_matcher().Matches(int16())); // Implicit construction in a vector std::vector types = {int8(), InputType(Type::DECIMAL)}; @@ -131,69 +112,33 @@ TEST(InputType, Constructors) { ASSERT_TRUE(ty6.Equals(ty2)); // ToString - ASSERT_EQ("any[int8]", ty1.ToString()); - ASSERT_EQ("array[int8]", ty1_array.ToString()); - ASSERT_EQ("scalar[int8]", ty1_scalar.ToString()); - - ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); - ASSERT_EQ("array[Type::DECIMAL128]", ty2_array.ToString()); - ASSERT_EQ("scalar[Type::DECIMAL128]", ty2_scalar.ToString()); + ASSERT_EQ("int8", ty1.ToString()); + ASSERT_EQ("Type::DECIMAL128", ty2.ToString()); InputType ty7(match::TimestampTypeUnit(TimeUnit::MICRO)); - ASSERT_EQ("any[timestamp(us)]", ty7.ToString()); + ASSERT_EQ("timestamp(us)", ty7.ToString()); InputType ty8; - InputType ty9(ValueDescr::ANY); - InputType ty10(ValueDescr::ARRAY); - InputType ty11(ValueDescr::SCALAR); - ASSERT_EQ("any[any]", ty8.ToString()); - ASSERT_EQ("any[any]", ty9.ToString()); - ASSERT_EQ("array[any]", ty10.ToString()); - ASSERT_EQ("scalar[any]", ty11.ToString()); + ASSERT_EQ("any", ty8.ToString()); } TEST(InputType, Equals) { InputType t1 = int8(); InputType t2 = int8(); - InputType t3(int8(), ValueDescr::ARRAY); - InputType t3_i32(int32(), ValueDescr::ARRAY); - InputType t3_scalar(int8(), ValueDescr::SCALAR); - InputType t4(int8(), ValueDescr::ARRAY); - InputType t4_i32(int32(), ValueDescr::ARRAY); + InputType t3 = int32(); InputType t5(Type::DECIMAL); InputType t6(Type::DECIMAL); - InputType t7(Type::DECIMAL, ValueDescr::SCALAR); - InputType t7_i32(Type::INT32, ValueDescr::SCALAR); - InputType t8(Type::DECIMAL, ValueDescr::SCALAR); - InputType t8_i32(Type::INT32, ValueDescr::SCALAR); ASSERT_TRUE(t1.Equals(t2)); ASSERT_EQ(t1, t2); - - // ANY vs SCALAR ASSERT_NE(t1, t3); - ASSERT_EQ(t3, t4); - - // both ARRAY, but different type - ASSERT_NE(t3, t3_i32); - - // ARRAY vs SCALAR - ASSERT_NE(t3, t3_scalar); - - ASSERT_EQ(t3_i32, t4_i32); - ASSERT_FALSE(t1.Equals(t5)); ASSERT_NE(t1, t5); ASSERT_EQ(t5, t5); ASSERT_EQ(t5, t6); - ASSERT_NE(t5, t7); - ASSERT_EQ(t7, t8); - ASSERT_EQ(t7, t8); - ASSERT_NE(t7, t7_i32); - ASSERT_EQ(t7_i32, t8_i32); // NOTE: For the time being, we 
treat int32() and Type::INT32 as being // different. This could obviously be fixed later to make these equivalent @@ -208,9 +153,6 @@ TEST(InputType, Equals) { TEST(InputType, Hash) { InputType t0; - InputType t0_scalar(ValueDescr::SCALAR); - InputType t0_array(ValueDescr::ARRAY); - InputType t1 = int8(); InputType t2(Type::DECIMAL); @@ -218,36 +160,32 @@ TEST(InputType, Hash) { // same value, and whether the elements of the type are all incorporated into // the Hash ASSERT_EQ(t0.Hash(), t0.Hash()); - ASSERT_NE(t0.Hash(), t0_scalar.Hash()); - ASSERT_NE(t0.Hash(), t0_array.Hash()); - ASSERT_NE(t0_scalar.Hash(), t0_array.Hash()); - ASSERT_EQ(t1.Hash(), t1.Hash()); ASSERT_EQ(t2.Hash(), t2.Hash()); - ASSERT_NE(t0.Hash(), t1.Hash()); ASSERT_NE(t0.Hash(), t2.Hash()); ASSERT_NE(t1.Hash(), t2.Hash()); } TEST(InputType, Matches) { - InputType ty1 = int8(); - - ASSERT_TRUE(ty1.Matches(ValueDescr::Scalar(int8()))); - ASSERT_TRUE(ty1.Matches(ValueDescr::Array(int8()))); - ASSERT_TRUE(ty1.Matches(ValueDescr::Any(int8()))); - ASSERT_FALSE(ty1.Matches(ValueDescr::Any(int16()))); - - InputType ty2(Type::DECIMAL); - ASSERT_TRUE(ty2.Matches(ValueDescr::Scalar(decimal(12, 2)))); - ASSERT_TRUE(ty2.Matches(ValueDescr::Array(decimal(12, 2)))); - ASSERT_FALSE(ty2.Matches(ValueDescr::Any(float64()))); - - InputType ty3(int64(), ValueDescr::SCALAR); - ASSERT_FALSE(ty3.Matches(ValueDescr::Array(int64()))); - ASSERT_TRUE(ty3.Matches(ValueDescr::Scalar(int64()))); - ASSERT_FALSE(ty3.Matches(ValueDescr::Scalar(int32()))); - ASSERT_FALSE(ty3.Matches(ValueDescr::Any(int64()))); + InputType input1 = int8(); + + ASSERT_TRUE(input1.Matches(int8())); + ASSERT_TRUE(input1.Matches(int8())); + ASSERT_FALSE(input1.Matches(int16())); + + InputType input2(Type::DECIMAL); + ASSERT_TRUE(input2.Matches(decimal(12, 2))); + + auto ty2 = decimal(12, 2); + auto ty3 = float64(); + ASSERT_OK_AND_ASSIGN(std::shared_ptr arr2, MakeArrayOfNull(ty2, 1)); + ASSERT_OK_AND_ASSIGN(std::shared_ptr arr3, MakeArrayOfNull(ty3, 1)); + ASSERT_OK_AND_ASSIGN(std::shared_ptr scalar2, arr2->GetScalar(0)); + ASSERT_TRUE(input2.Matches(Datum(arr2))); + ASSERT_TRUE(input2.Matches(Datum(scalar2))); + ASSERT_FALSE(input2.Matches(ty3)); + ASSERT_FALSE(input2.Matches(arr3)); } // ---------------------------------------------------------------------- @@ -259,14 +197,14 @@ TEST(OutputType, Constructors) { AssertTypeEqual(*int8(), *ty1.type()); auto DummyResolver = [](KernelContext*, - const std::vector& args) -> Result { - return ValueDescr(int32(), GetBroadcastShape(args)); + const std::vector& args) -> Result { + return int32(); }; OutputType ty2(DummyResolver); ASSERT_EQ(OutputType::COMPUTED, ty2.kind()); - ASSERT_OK_AND_ASSIGN(ValueDescr out_descr2, ty2.Resolve(nullptr, {})); - ASSERT_EQ(ValueDescr::Array(int32()), out_descr2); + ASSERT_OK_AND_ASSIGN(TypeHolder out_type2, ty2.Resolve(nullptr, {})); + ASSERT_EQ(out_type2, int32()); // Copy constructor OutputType ty3 = ty1; @@ -275,8 +213,8 @@ TEST(OutputType, Constructors) { OutputType ty4 = ty2; ASSERT_EQ(OutputType::COMPUTED, ty4.kind()); - ASSERT_OK_AND_ASSIGN(ValueDescr out_descr4, ty4.Resolve(nullptr, {})); - ASSERT_EQ(ValueDescr::Array(int32()), out_descr4); + ASSERT_OK_AND_ASSIGN(TypeHolder out_type4, ty4.Resolve(nullptr, {})); + ASSERT_EQ(out_type4, int32()); // Move constructor OutputType ty5 = std::move(ty1); @@ -285,8 +223,8 @@ TEST(OutputType, Constructors) { OutputType ty6 = std::move(ty4); ASSERT_EQ(OutputType::COMPUTED, ty6.kind()); - ASSERT_OK_AND_ASSIGN(ValueDescr out_descr6, ty6.Resolve(nullptr, 
{})); - ASSERT_EQ(ValueDescr::Array(int32()), out_descr6); + ASSERT_OK_AND_ASSIGN(TypeHolder out_type6, ty6.Resolve(nullptr, {})); + ASSERT_EQ(out_type6, int32()); // ToString @@ -296,89 +234,63 @@ TEST(OutputType, Constructors) { } TEST(OutputType, Resolve) { - // Check shape promotion rules for FIXED kind OutputType ty1(int32()); - ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {})); - ASSERT_EQ(ValueDescr::Array(int32()), descr); + ASSERT_OK_AND_ASSIGN(TypeHolder result, ty1.Resolve(nullptr, {})); + ASSERT_EQ(result, int32()); - ASSERT_OK_AND_ASSIGN(descr, - ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR)})); - ASSERT_EQ(ValueDescr::Scalar(int32()), descr); + ASSERT_OK_AND_ASSIGN(result, ty1.Resolve(nullptr, {int8()})); + ASSERT_EQ(result, int32()); - ASSERT_OK_AND_ASSIGN(descr, - ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR), - ValueDescr(int8(), ValueDescr::ARRAY)})); - ASSERT_EQ(ValueDescr::Array(int32()), descr); + ASSERT_OK_AND_ASSIGN(result, ty1.Resolve(nullptr, {int8(), int8()})); + ASSERT_EQ(result, int32()); - OutputType ty2([](KernelContext*, const std::vector& args) { - return ValueDescr(args[0].type, GetBroadcastShape(args)); - }); + auto resolver = [](KernelContext*, + const std::vector& args) -> Result { + return args[0]; + }; + OutputType ty2(resolver); - ASSERT_OK_AND_ASSIGN(descr, ty2.Resolve(nullptr, {ValueDescr::Array(utf8())})); - ASSERT_EQ(ValueDescr::Array(utf8()), descr); + ASSERT_OK_AND_ASSIGN(result, ty2.Resolve(nullptr, {utf8()})); + ASSERT_EQ(result, utf8()); // Type resolver that returns an error OutputType ty3( - [](KernelContext* ctx, const std::vector& args) -> Result { + [](KernelContext* ctx, const std::vector& types) -> Result { // NB: checking the value types versus the function arity should be // validated elsewhere, so this is just for illustration purposes - if (args.size() == 0) { + if (types.size() == 0) { return Status::Invalid("Need at least one argument"); } - return ValueDescr(args[0]); + return types[0]; }); ASSERT_RAISES(Invalid, ty3.Resolve(nullptr, {})); - // Type resolver that returns ValueDescr::ANY and needs type promotion + // Type resolver that returns a fixed value OutputType ty4( - [](KernelContext* ctx, const std::vector& args) -> Result { + [](KernelContext* ctx, const std::vector& types) -> Result { return int32(); }); - ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Array(int8())})); - ASSERT_EQ(ValueDescr::Array(int32()), descr); - ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Scalar(int8())})); - ASSERT_EQ(ValueDescr::Scalar(int32()), descr); -} - -TEST(OutputType, ResolveDescr) { - ValueDescr d1 = ValueDescr::Scalar(int32()); - ValueDescr d2 = ValueDescr::Array(int32()); - - OutputType ty1(d1); - OutputType ty2(d2); - - ASSERT_EQ(ValueDescr::SCALAR, ty1.shape()); - ASSERT_EQ(ValueDescr::ARRAY, ty2.shape()); - - { - ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {})); - ASSERT_EQ(d1, descr); - } - - { - ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty2.Resolve(nullptr, {})); - ASSERT_EQ(d2, descr); - } + ASSERT_OK_AND_ASSIGN(result, ty4.Resolve(nullptr, {int8()})); + ASSERT_EQ(result, int32()); + ASSERT_OK_AND_ASSIGN(result, ty4.Resolve(nullptr, {int8()})); + ASSERT_EQ(result, int32()); } // ---------------------------------------------------------------------- // KernelSignature TEST(KernelSignature, Basics) { - // (any[int8], scalar[decimal]) -> utf8 - std::vector in_types({int8(), InputType(Type::DECIMAL, ValueDescr::SCALAR)}); + // (int8, 
decimal) -> utf8 + std::vector in_types({int8(), InputType(Type::DECIMAL)}); OutputType out_type(utf8()); KernelSignature sig(in_types, out_type); ASSERT_EQ(2, sig.in_types().size()); ASSERT_TRUE(sig.in_types()[0].type()->Equals(*int8())); - ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Scalar(int8()))); - ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Array(int8()))); - - ASSERT_TRUE(sig.in_types()[1].Matches(ValueDescr::Scalar(decimal(12, 2)))); - ASSERT_FALSE(sig.in_types()[1].Matches(ValueDescr::Array(decimal(12, 2)))); + ASSERT_TRUE(sig.in_types()[0].Matches(int8())); + ASSERT_TRUE(sig.in_types()[1].Matches(decimal(12, 2))); } TEST(KernelSignature, Equals) { @@ -393,10 +305,6 @@ TEST(KernelSignature, Equals) { KernelSignature sig4_copy({int8(), int16()}, utf8()); KernelSignature sig5({int8(), int16(), int32()}, utf8()); - // Differ in shape - KernelSignature sig6({ValueDescr::Scalar(int8())}, utf8()); - KernelSignature sig7({ValueDescr::Array(int8())}, utf8()); - ASSERT_EQ(sig1, sig1); ASSERT_EQ(sig2, sig3); @@ -408,8 +316,6 @@ TEST(KernelSignature, Equals) { // Match first 2 args, but not third ASSERT_NE(sig4, sig5); - - ASSERT_NE(sig6, sig7); } TEST(KernelSignature, VarArgsEquals) { @@ -441,40 +347,32 @@ TEST(KernelSignature, MatchesInputs) { ASSERT_TRUE(sig1.MatchesInputs({})); ASSERT_FALSE(sig1.MatchesInputs({int8()})); - // (any[int8], any[decimal]) -> boolean + // (int8, decimal) -> boolean KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, boolean()); ASSERT_FALSE(sig2.MatchesInputs({})); ASSERT_FALSE(sig2.MatchesInputs({int8()})); ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)})); - ASSERT_TRUE(sig2.MatchesInputs( - {ValueDescr::Scalar(int8()), ValueDescr::Scalar(decimal(12, 2))})); - ASSERT_TRUE( - sig2.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(decimal(12, 2))})); - // (scalar[int8], array[int32]) -> boolean - KernelSignature sig3({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())}, - boolean()); + // (int8, int32) -> boolean + KernelSignature sig3({int8(), int32()}, boolean()); ASSERT_FALSE(sig3.MatchesInputs({})); // Unqualified, these are ANY type and do not match because the kernel // requires a scalar and an array - ASSERT_FALSE(sig3.MatchesInputs({int8(), int32()})); - ASSERT_TRUE( - sig3.MatchesInputs({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())})); - ASSERT_FALSE( - sig3.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(int32())})); + ASSERT_TRUE(sig3.MatchesInputs({int8(), int32()})); + ASSERT_FALSE(sig3.MatchesInputs({int8(), int16()})); } TEST(KernelSignature, VarArgsMatchesInputs) { { KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); - std::vector args = {int8()}; + std::vector args = {int8()}; ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(ValueDescr::Scalar(int8())); - args.push_back(ValueDescr::Array(int8())); + args.push_back(int8()); + args.push_back(int8()); ASSERT_TRUE(sig.MatchesInputs(args)); args.push_back(int32()); ASSERT_FALSE(sig.MatchesInputs(args)); @@ -482,10 +380,10 @@ TEST(KernelSignature, VarArgsMatchesInputs) { { KernelSignature sig({int8(), utf8()}, utf8(), /*is_varargs=*/true); - std::vector args = {int8()}; + std::vector args = {int8()}; ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(ValueDescr::Scalar(utf8())); - args.push_back(ValueDescr::Array(utf8())); + args.push_back(utf8()); + args.push_back(utf8()); ASSERT_TRUE(sig.MatchesInputs(args)); args.push_back(int32()); ASSERT_FALSE(sig.MatchesInputs(args)); @@ -493,23 +391,25 @@ TEST(KernelSignature, 
 }
 
 TEST(KernelSignature, ToString) {
-  std::vector<InputType> in_types = {InputType(int8(), ValueDescr::SCALAR),
-                                     InputType(Type::DECIMAL, ValueDescr::ARRAY),
+  std::vector<InputType> in_types = {InputType(int8()), InputType(Type::DECIMAL),
                                      InputType(utf8())};
   KernelSignature sig(in_types, utf8());
-  ASSERT_EQ("(scalar[int8], array[Type::DECIMAL128], any[string]) -> string",
-            sig.ToString());
-
-  OutputType out_type([](KernelContext*, const std::vector<ValueDescr>& args) {
-    return Status::Invalid("NYI");
-  });
-  KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, out_type);
-  ASSERT_EQ("(any[int8], any[Type::DECIMAL128]) -> computed", sig2.ToString());
+  ASSERT_EQ("(int8, Type::DECIMAL128, string) -> string", sig.ToString());
+
+  OutputType out_type(
+      [](KernelContext*, const std::vector<TypeHolder>& args) -> Result<TypeHolder> {
+        return Status::Invalid("NYI");
+      });
+  KernelSignature sig2({int8(), Type::DECIMAL}, out_type);
+  ASSERT_EQ("(int8, Type::DECIMAL128) -> computed", sig2.ToString());
 }
 
 TEST(KernelSignature, VarArgsToString) {
   KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);
-  ASSERT_EQ("varargs[any[int8]] -> string", sig.ToString());
+  ASSERT_EQ("varargs[int8*] -> string", sig.ToString());
+
+  KernelSignature sig2({utf8(), int8()}, utf8(), /*is_varargs=*/true);
+  ASSERT_EQ("varargs[string, int8*] -> string", sig2.ToString());
 }
 
 }  // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 661b6a4edb1..57cee87f00d 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -195,7 +195,7 @@ Result> CountDistinctInit(KernelContext* ctx,
 
 template
 void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) {
-  AddAggKernel(KernelSignature::Make({type}, ValueDescr::Scalar(int64())),
+  AddAggKernel(KernelSignature::Make({type}, int64()),
               CountDistinctInit, func);
 }
 
@@ -252,7 +252,7 @@ struct MeanImplDefault : public MeanImpl {
 
 Result> SumInit(KernelContext* ctx, const KernelInitArgs& args) {
   SumLikeInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
      static_cast(*args.options));
   return visitor.Create();
 }
 
@@ -260,7 +260,7 @@ Result> SumInit(KernelContext* ctx,
 Result> MeanInit(KernelContext* ctx, const KernelInitArgs& args) {
   MeanKernelInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
      static_cast(*args.options));
   return visitor.Create();
 }
 
@@ -277,7 +277,7 @@ struct ProductImpl : public ScalarAggregator {
   using ProductType = typename TypeTraits::CType;
   using OutputType = typename TypeTraits::ScalarType;
 
-  explicit ProductImpl(const std::shared_ptr& out_type,
+  explicit ProductImpl(std::shared_ptr out_type,
                       const ScalarAggregateOptions& options)
      : out_type(out_type),
        options(options),
@@ -356,10 +356,10 @@ struct NullProductImpl : public NullImpl {
 
 struct ProductInit {
   std::unique_ptr state;
   KernelContext* ctx;
-  const std::shared_ptr& type;
+  std::shared_ptr type;
   const ScalarAggregateOptions& options;
 
-  ProductInit(KernelContext* ctx, const std::shared_ptr& type,
+  ProductInit(KernelContext* ctx, std::shared_ptr type,
              const ScalarAggregateOptions& options)
      : ctx(ctx), type(type), options(options) {}
 
@@ -402,7 +402,7 @@ struct ProductInit {
 
   static Result> Init(KernelContext* ctx, const KernelInitArgs& args) {
-    ProductInit visitor(ctx, args.inputs[0].type,
+    ProductInit visitor(ctx, args.inputs[0].GetSharedPtr(),
                        static_cast(*args.options));
    return visitor.Create();
  }
@@ -413,10 +413,10 @@
struct ProductInit { Result> MinMaxInit(KernelContext* ctx, const KernelInitArgs& args) { - ARROW_ASSIGN_OR_RAISE(auto out_type, + ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, args.kernel->signature->out_type().Resolve(ctx, args.inputs)); MinMaxInitState visitor( - ctx, *args.inputs[0].type, std::move(out_type.type), + ctx, *args.inputs[0], out_type.GetSharedPtr(), static_cast(*args.options)); return visitor.Create(); } @@ -425,14 +425,7 @@ Result> MinMaxInit(KernelContext* ctx, template void AddMinOrMaxAggKernel(ScalarAggregateFunction* func, ScalarAggregateFunction* min_max_func) { - auto sig = KernelSignature::Make( - {InputType(ValueDescr::ANY)}, - OutputType([](KernelContext*, - const std::vector& descrs) -> Result { - // any[T] -> scalar[T] - return ValueDescr::Scalar(descrs.front().type); - })); - + auto sig = KernelSignature::Make({InputType::Any()}, FirstType); auto init = [min_max_func]( KernelContext* ctx, const KernelInitArgs& args) -> Result> { @@ -775,8 +768,7 @@ void AddBasicAggKernels(KernelInit init, SimdLevel::type simd_level) { for (const auto& ty : types) { // array[InT] -> scalar[OutT] - auto sig = - KernelSignature::Make({InputType::Array(ty->id())}, ValueDescr::Scalar(out_ty)); + auto sig = KernelSignature::Make({ty->id()}, out_ty); AddAggKernel(std::move(sig), init, func, simd_level); } } @@ -786,9 +778,7 @@ void AddScalarAggKernels(KernelInit init, std::shared_ptr out_ty, ScalarAggregateFunction* func) { for (const auto& ty : types) { - // scalar[InT] -> scalar[OutT] - auto sig = - KernelSignature::Make({InputType::Scalar(ty->id())}, ValueDescr::Scalar(out_ty)); + auto sig = KernelSignature::Make({ty->id()}, out_ty); AddAggKernel(std::move(sig), init, func, SimdLevel::NONE); } } @@ -804,17 +794,17 @@ void AddArrayScalarAggKernels(KernelInit init, namespace { -Result MinMaxType(KernelContext*, const std::vector& descrs) { - // any[T] -> scalar[struct] - auto ty = descrs.front().type; - return ValueDescr::Scalar(struct_({field("min", ty), field("max", ty)})); +Result MinMaxType(KernelContext*, const std::vector& types) { + // T -> struct + auto ty = types.front().GetSharedPtr(); + return struct_({field("min", ty), field("max", ty)}); } } // namespace void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id, ScalarAggregateFunction* func, SimdLevel::type simd_level) { - auto sig = KernelSignature::Make({InputType(get_id.id)}, OutputType(MinMaxType)); + auto sig = KernelSignature::Make({InputType(get_id.id)}, MinMaxType); AddAggKernel(std::move(sig), init, func, simd_level); } @@ -828,13 +818,6 @@ void AddMinMaxKernels(KernelInit init, namespace { -Result ScalarFirstType(KernelContext*, - const std::vector& descrs) { - ValueDescr result = descrs.front(); - result.shape = ValueDescr::SCALAR; - return result; -} - const FunctionDoc count_doc{"Count the number of null / non-null values", ("By default, only non-null values are counted.\n" "This can be changed through CountOptions."), @@ -922,8 +905,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { // Takes any input, outputs int64 scalar InputType any_input; - AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())), CountInit, - func.get()); + AddAggKernel(KernelSignature::Make({any_input}, int64()), CountInit, func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared( @@ -935,12 +917,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { func = std::make_shared("sum", Arity::Unary(), sum_doc, 
&default_scalar_aggregate_options); AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get()); - AddAggKernel( - KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)), - SumInit, func.get(), SimdLevel::NONE); - AddAggKernel( - KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)), - SumInit, func.get(), SimdLevel::NONE); + AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), SumInit, func.get(), + SimdLevel::NONE); + AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), SumInit, func.get(), + SimdLevel::NONE); AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get()); AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get()); AddArrayScalarAggKernels(SumInit, FloatingPointTypes(), float64(), func.get()); @@ -965,12 +945,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { &default_scalar_aggregate_options); AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get()); AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get()); - AddAggKernel( - KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)), - MeanInit, func.get(), SimdLevel::NONE); - AddAggKernel( - KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)), - MeanInit, func.get(), SimdLevel::NONE); + AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), MeanInit, func.get(), + SimdLevel::NONE); + AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), MeanInit, func.get(), + SimdLevel::NONE); AddArrayScalarAggKernels(MeanInit, {null()}, float64(), func.get()); // Add the SIMD variants for mean #if defined(ARROW_HAVE_RUNTIME_AVX2) @@ -1028,12 +1006,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get()); AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(), func.get()); - AddAggKernel( - KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)), - ProductInit::Init, func.get(), SimdLevel::NONE); - AddAggKernel( - KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)), - ProductInit::Init, func.get(), SimdLevel::NONE); + AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), ProductInit::Init, + func.get(), SimdLevel::NONE); + AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), ProductInit::Init, + func.get(), SimdLevel::NONE); AddArrayScalarAggKernels(ProductInit::Init, {null()}, int64(), func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc index 00e3e2e5fd4..03b45107eec 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc @@ -37,7 +37,7 @@ struct MeanImplAvx2 : public MeanImpl { Result> SumInitAvx2(KernelContext* ctx, const KernelInitArgs& args) { SumLikeInit visitor( - ctx, args.inputs[0].type, + ctx, args.inputs[0].GetSharedPtr(), static_cast(*args.options)); return visitor.Create(); } @@ -45,7 +45,7 @@ Result> SumInitAvx2(KernelContext* ctx, Result> MeanInitAvx2(KernelContext* ctx, const KernelInitArgs& args) { SumLikeInit visitor( - ctx, args.inputs[0].type, + ctx, args.inputs[0].GetSharedPtr(), static_cast(*args.options)); return visitor.Create(); } @@ -55,10 +55,10 @@ Result> 
MeanInitAvx2(KernelContext* ctx, Result> MinMaxInitAvx2(KernelContext* ctx, const KernelInitArgs& args) { - ARROW_ASSIGN_OR_RAISE(auto out_type, + ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, args.kernel->signature->out_type().Resolve(ctx, args.inputs)); MinMaxInitState visitor( - ctx, *args.inputs[0].type, std::move(out_type.type), + ctx, *args.inputs[0], out_type.GetSharedPtr(), static_cast(*args.options)); return visitor.Create(); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc index 8c10eb19b07..0d66ed2ec3e 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc @@ -37,7 +37,7 @@ struct MeanImplAvx512 : public MeanImpl { Result> SumInitAvx512(KernelContext* ctx, const KernelInitArgs& args) { SumLikeInit visitor( - ctx, args.inputs[0].type, + ctx, args.inputs[0].GetSharedPtr(), static_cast(*args.options)); return visitor.Create(); } @@ -45,7 +45,7 @@ Result> SumInitAvx512(KernelContext* ctx, Result> MeanInitAvx512(KernelContext* ctx, const KernelInitArgs& args) { SumLikeInit visitor( - ctx, args.inputs[0].type, + ctx, args.inputs[0].GetSharedPtr(), static_cast(*args.options)); return visitor.Create(); } @@ -55,10 +55,10 @@ Result> MeanInitAvx512(KernelContext* ctx, Result> MinMaxInitAvx512(KernelContext* ctx, const KernelInitArgs& args) { - ARROW_ASSIGN_OR_RAISE(auto out_type, + ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, args.kernel->signature->out_type().Resolve(ctx, args.inputs)); MinMaxInitState visitor( - ctx, *args.inputs[0].type, std::move(out_type.type), + ctx, *args.inputs[0], out_type.GetSharedPtr(), static_cast(*args.options)); return visitor.Create(); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index a5b473793a9..6645e1a76bc 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -65,8 +65,7 @@ struct SumImpl : public ScalarAggregator { using SumCType = typename TypeTraits::CType; using OutputType = typename TypeTraits::ScalarType; - SumImpl(const std::shared_ptr& out_type, - const ScalarAggregateOptions& options_) + SumImpl(std::shared_ptr out_type, const ScalarAggregateOptions& options_) : out_type(out_type), options(options_) {} Status Consume(KernelContext*, const ExecBatch& batch) override { @@ -216,10 +215,10 @@ template