diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index e08a4d28d7b..581a4c679f0 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -40,16 +40,11 @@ garrow_numeric_array_sum(GArrowArrayType array, typename ArrowType::c_type default_value) { auto arrow_array = garrow_array_get_raw(GARROW_ARRAY(array)); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum sum_datum; - auto status = arrow::compute::Sum(&context, - arrow_array, - &sum_datum); - if (garrow_error_check(error, status, tag)) { + auto arrow_sum_datum = arrow::compute::Sum(arrow_array); + if (garrow::check(error, arrow_sum_datum, tag)) { using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType; auto arrow_numeric_scalar = - std::dynamic_pointer_cast<ScalarType>(sum_datum.scalar()); + std::dynamic_pointer_cast<ScalarType>((*arrow_sum_datum).scalar()); if (arrow_numeric_scalar->is_valid) { return arrow_numeric_scalar->value; } else { @@ -69,17 +64,12 @@ garrow_numeric_array_compare(GArrowArrayType array, const gchar *tag) { auto arrow_array = garrow_array_get_raw(GARROW_ARRAY(array)); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum compared_datum; auto arrow_options = garrow_compare_options_get_raw(options); - auto status = arrow::compute::Compare(&context, - arrow_array, - arrow::compute::Datum(value), - *arrow_options, - &compared_datum); - if (garrow_error_check(error, status, tag)) { - auto arrow_compared_array = compared_datum.make_array(); + auto arrow_compared_datum = arrow::compute::Compare(arrow_array, + arrow::Datum(value), + *arrow_options); + if (garrow::check(error, arrow_compared_datum, tag)) { + auto arrow_compared_array = (*arrow_compared_datum).make_array(); return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_compared_array)); } else { return NULL; @@ -676,39 +666,32 @@ garrow_array_cast(GArrowArray *array, { auto arrow_array = garrow_array_get_raw(array); auto arrow_array_raw = arrow_array.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); auto arrow_target_data_type = garrow_data_type_get_raw(target_data_type); - std::shared_ptr<arrow::Array> arrow_casted_array; - arrow::Status status; + arrow::Result<std::shared_ptr<arrow::Array>> arrow_casted_array; if (options) { auto arrow_options = garrow_cast_options_get_raw(options); - status = arrow::compute::Cast(&context, - *arrow_array_raw, - arrow_target_data_type, - *arrow_options, - &arrow_casted_array); + arrow_casted_array = arrow::compute::Cast(*arrow_array_raw, + arrow_target_data_type, + *arrow_options); } else { - arrow::compute::CastOptions arrow_options; - status = arrow::compute::Cast(&context, - *arrow_array_raw, - arrow_target_data_type, - arrow_options, - &arrow_casted_array); + arrow_casted_array = arrow::compute::Cast(*arrow_array_raw, + arrow_target_data_type); } - - if (!status.ok()) { - std::stringstream message; - message << "[array][cast] <"; - message << arrow_array->type()->ToString(); - message << "> -> <"; - message << arrow_target_data_type->ToString(); - message << ">"; - garrow_error_check(error, status, message.str().c_str()); + if (garrow::check(error, + arrow_casted_array, + [&]() { + std::stringstream message; + message << "[array][cast] <"; + message << arrow_array->type()->ToString(); + message << "> -> <"; + message << arrow_target_data_type->ToString(); + message << ">"; + return message.str(); + })) { + return 
garrow_array_new_raw(&(*arrow_casted_array)); + } else { return NULL; } - - return garrow_array_new_raw(&arrow_casted_array); } /** @@ -726,22 +709,20 @@ garrow_array_unique(GArrowArray *array, GError **error) { auto arrow_array = garrow_array_get_raw(array); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr<arrow::Array> arrow_unique_array; - auto status = arrow::compute::Unique(&context, - arrow::compute::Datum(arrow_array), - &arrow_unique_array); - if (!status.ok()) { - std::stringstream message; - message << "[array][unique] <"; - message << arrow_array->type()->ToString(); - message << ">"; - garrow_error_check(error, status, message.str().c_str()); + auto arrow_unique_array = arrow::compute::Unique(arrow_array); + if (garrow::check(error, + arrow_unique_array, + [&]() { + std::stringstream message; + message << "[array][unique] <"; + message << arrow_array->type()->ToString(); + message << ">"; + return message.str(); + })) { + return garrow_array_new_raw(&(*arrow_unique_array)); + } else { return NULL; } - - return garrow_array_new_raw(&arrow_unique_array); } /** @@ -760,27 +741,25 @@ garrow_array_dictionary_encode(GArrowArray *array, GError **error) { auto arrow_array = garrow_array_get_raw(array); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum dictionary_encoded_datum; - auto status = - arrow::compute::DictionaryEncode(&context, - arrow::compute::Datum(arrow_array), - &dictionary_encoded_datum); - if (!status.ok()) { - std::stringstream message; - message << "[array][dictionary-encode] <"; - message << arrow_array->type()->ToString(); - message << ">"; - garrow_error_check(error, status, message.str().c_str()); + auto arrow_dictionary_encoded_datum = + arrow::compute::DictionaryEncode(arrow_array); + if (garrow::check(error, + arrow_dictionary_encoded_datum, + [&]() { + std::stringstream message; + message << "[array][dictionary-encode] <"; + message << arrow_array->type()->ToString(); + message << ">"; + return message.str(); + })) { + auto arrow_dictionary_encoded_array = + (*arrow_dictionary_encoded_datum).make_array(); + auto dictionary_encoded_array = + garrow_array_new_raw(&arrow_dictionary_encoded_array); + return GARROW_DICTIONARY_ARRAY(dictionary_encoded_array); + } else { return NULL; } - - auto arrow_dictionary_encoded_array = - arrow::MakeArray(dictionary_encoded_datum.array()); - auto dictionary_encoded_array = - garrow_array_new_raw(&arrow_dictionary_encoded_array); - return GARROW_DICTIONARY_ARRAY(dictionary_encoded_array); } /** @@ -801,28 +780,19 @@ garrow_array_count(GArrowArray *array, { auto arrow_array = garrow_array_get_raw(array); auto arrow_array_raw = arrow_array.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum counted_datum; - arrow::Status status; + arrow::Result<arrow::Datum> arrow_counted_datum; if (options) { auto arrow_options = garrow_count_options_get_raw(options); - status = arrow::compute::Count(&context, - *arrow_options, - *arrow_array_raw, - &counted_datum); + arrow_counted_datum = + arrow::compute::Count(*arrow_array_raw, *arrow_options); } else { - arrow::compute::CountOptions arrow_options(arrow::compute::CountOptions::COUNT_ALL); - status = arrow::compute::Count(&context, - arrow_options, - *arrow_array_raw, - &counted_datum); + arrow_counted_datum = arrow::compute::Count(*arrow_array_raw); } - - if 
(garrow_error_check(error, status, "[array][count]")) { + if (garrow::check(error, arrow_counted_datum, "[array][count]")) { using ScalarType = typename arrow::TypeTraits<arrow::Int64Type>::ScalarType; - auto counted_scalar = std::dynamic_pointer_cast<ScalarType>(counted_datum.scalar()); - return counted_scalar->value; + auto arrow_counted_scalar = + std::dynamic_pointer_cast<ScalarType>((*arrow_counted_datum).scalar()); + return arrow_counted_scalar->value; } else { return 0; } @@ -844,14 +814,9 @@ garrow_array_count_values(GArrowArray *array, GError **error) { auto arrow_array = garrow_array_get_raw(array); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr<arrow::Array> arrow_counted_values; - auto status = arrow::compute::ValueCounts(&context, - arrow::compute::Datum(arrow_array), - &arrow_counted_values); - if (garrow_error_check(error, status, "[array][count-values]")) { - return GARROW_STRUCT_ARRAY(garrow_array_new_raw(&arrow_counted_values)); + auto arrow_counted_values = arrow::compute::ValueCounts(arrow_array); + if (garrow::check(error, arrow_counted_values, "[array][count-values]")) { + return GARROW_STRUCT_ARRAY(garrow_array_new_raw(&(*arrow_counted_values))); } else { return NULL; } @@ -874,13 +839,9 @@ garrow_boolean_array_invert(GArrowBooleanArray *array, GError **error) { auto arrow_array = garrow_array_get_raw(GARROW_ARRAY(array)); - auto datum = arrow::compute::Datum(arrow_array); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum inverted_datum; - auto status = arrow::compute::Invert(&context, datum, &inverted_datum); - if (garrow_error_check(error, status, "[boolean-array][invert]")) { - auto arrow_inverted_array = inverted_datum.make_array(); + auto arrow_inverted_datum = arrow::compute::Invert(arrow_array); + if (garrow::check(error, arrow_inverted_datum, "[boolean-array][invert]")) { + auto arrow_inverted_array = (*arrow_inverted_datum).make_array(); return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_inverted_array)); } else { return NULL; @@ -905,18 +866,10 @@ garrow_boolean_array_and(GArrowBooleanArray *left, GError **error) { auto arrow_left = garrow_array_get_raw(GARROW_ARRAY(left)); - auto left_datum = arrow::compute::Datum(arrow_left); auto arrow_right = garrow_array_get_raw(GARROW_ARRAY(right)); - auto right_datum = arrow::compute::Datum(arrow_right); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum operated_datum; - auto status = arrow::compute::And(&context, - left_datum, - right_datum, - &operated_datum); - if (garrow_error_check(error, status, "[boolean-array][and]")) { - auto arrow_operated_array = operated_datum.make_array(); + auto arrow_operated_datum = arrow::compute::And(arrow_left, arrow_right); + if (garrow::check(error, arrow_operated_datum, "[boolean-array][and]")) { + auto arrow_operated_array = (*arrow_operated_datum).make_array(); return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_operated_array)); } else { return NULL; @@ -941,18 +894,10 @@ garrow_boolean_array_or(GArrowBooleanArray *left, GError **error) { auto arrow_left = garrow_array_get_raw(GARROW_ARRAY(left)); - auto left_datum = arrow::compute::Datum(arrow_left); auto arrow_right = garrow_array_get_raw(GARROW_ARRAY(right)); - auto right_datum = arrow::compute::Datum(arrow_right); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - 
arrow::compute::Datum operated_datum; - auto status = arrow::compute::Or(&context, - left_datum, - right_datum, - &operated_datum); - if (garrow_error_check(error, status, "[boolean-array][or]")) { - auto arrow_operated_array = operated_datum.make_array(); + auto arrow_operated_datum = arrow::compute::Or(arrow_left, arrow_right); + if (garrow::check(error, arrow_operated_datum, "[boolean-array][or]")) { + auto arrow_operated_array = (*arrow_operated_datum).make_array(); return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_operated_array)); } else { return NULL; @@ -977,18 +922,10 @@ garrow_boolean_array_xor(GArrowBooleanArray *left, GError **error) { auto arrow_left = garrow_array_get_raw(GARROW_ARRAY(left)); - auto left_datum = arrow::compute::Datum(arrow_left); auto arrow_right = garrow_array_get_raw(GARROW_ARRAY(right)); - auto right_datum = arrow::compute::Datum(arrow_right); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum operated_datum; - auto status = arrow::compute::Xor(&context, - left_datum, - right_datum, - &operated_datum); - if (garrow_error_check(error, status, "[boolean-array][xor]")) { - auto arrow_operated_array = operated_datum.make_array(); + auto arrow_operated_datum = arrow::compute::Xor(arrow_left, arrow_right); + if (garrow::check(error, arrow_operated_datum, "[boolean-array][xor]")) { + auto arrow_operated_array = (*arrow_operated_datum).make_array(); return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_operated_array)); } else { return NULL; @@ -1010,14 +947,11 @@ garrow_numeric_array_mean(GArrowNumericArray *array, GError **error) { auto arrow_array = garrow_array_get_raw(GARROW_ARRAY(array)); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum mean_datum; - auto status = arrow::compute::Mean(&context, arrow_array, &mean_datum); - if (garrow_error_check(error, status, "[numeric-array][mean]")) { + auto arrow_mean_datum = arrow::compute::Mean(arrow_array); + if (garrow::check(error, arrow_mean_datum, "[numeric-array][mean]")) { using ScalarType = typename arrow::TypeTraits<arrow::DoubleType>::ScalarType; auto arrow_numeric_scalar = - std::dynamic_pointer_cast<ScalarType>(mean_datum.scalar()); + std::dynamic_pointer_cast<ScalarType>((*arrow_mean_datum).scalar()); if (arrow_numeric_scalar->is_valid) { return arrow_numeric_scalar->value; } else { @@ -1251,28 +1185,18 @@ garrow_array_take(GArrowArray *array, auto arrow_array_raw = arrow_array.get(); auto arrow_indices = garrow_array_get_raw(indices); auto arrow_indices_raw = arrow_indices.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr<arrow::Array> taken_array; - arrow::Status status; + arrow::Result<std::shared_ptr<arrow::Array>> arrow_taken_array; if (options) { auto arrow_options = garrow_take_options_get_raw(options); - status = arrow::compute::Take(&context, - *arrow_array_raw, - *arrow_indices_raw, - *arrow_options, - &taken_array); + arrow_taken_array = arrow::compute::Take(*arrow_array_raw, + *arrow_indices_raw, + *arrow_options); } else { - arrow::compute::TakeOptions arrow_options; - status = arrow::compute::Take(&context, - *arrow_array_raw, - *arrow_indices_raw, - arrow_options, - &taken_array); + arrow_taken_array = arrow::compute::Take(*arrow_array_raw, + *arrow_indices_raw); } - - if (garrow_error_check(error, status, "[array][take]")) { - return garrow_array_new_raw(&taken_array); + if (garrow::check(error, arrow_taken_array, 
"[array][take]")) { + return garrow_array_new_raw(&(*arrow_taken_array)); } else { return NULL; } @@ -1300,28 +1224,20 @@ garrow_array_take_chunked_array(GArrowArray *array, auto arrow_array_raw = arrow_array.get(); auto arrow_indices = garrow_chunked_array_get_raw(indices); auto arrow_indices_raw = arrow_indices.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr taken_chunked_array; - arrow::Status status; + arrow::Result> arrow_taken_chunked_array; if (options) { auto arrow_options = garrow_take_options_get_raw(options); - status = arrow::compute::Take(&context, - *arrow_array_raw, - *arrow_indices_raw, - *arrow_options, - &taken_chunked_array); + arrow_taken_chunked_array = arrow::compute::Take(*arrow_array_raw, + *arrow_indices_raw, + *arrow_options); } else { - arrow::compute::TakeOptions arrow_options; - status = arrow::compute::Take(&context, - *arrow_array_raw, - *arrow_indices_raw, - arrow_options, - &taken_chunked_array); + arrow_taken_chunked_array = arrow::compute::Take(*arrow_array_raw, + *arrow_indices_raw); } - - if (garrow_error_check(error, status, "[array][take][chunked-array]")) { - return garrow_chunked_array_new_raw(&taken_chunked_array); + if (garrow::check(error, + arrow_taken_chunked_array, + "[array][take][chunked-array]")) { + return garrow_chunked_array_new_raw(&(*arrow_taken_chunked_array)); } else { return NULL; } @@ -1349,28 +1265,18 @@ garrow_table_take(GArrowTable *table, auto arrow_table_raw = arrow_table.get(); auto arrow_indices = garrow_array_get_raw(indices); auto arrow_indices_raw = arrow_indices.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr taken_table; - arrow::Status status; + arrow::Result> arrow_taken_table; if (options) { auto arrow_options = garrow_take_options_get_raw(options); - status = arrow::compute::Take(&context, - *arrow_table_raw, - *arrow_indices_raw, - *arrow_options, - &taken_table); + arrow_taken_table = arrow::compute::Take(*arrow_table_raw, + *arrow_indices_raw, + *arrow_options); } else { - arrow::compute::TakeOptions arrow_options; - status = arrow::compute::Take(&context, - *arrow_table_raw, - *arrow_indices_raw, - arrow_options, - &taken_table); + arrow_taken_table = arrow::compute::Take(*arrow_table_raw, + *arrow_indices_raw); } - - if (garrow_error_check(error, status, "[table][take]")) { - return garrow_table_new_raw(&taken_table); + if (garrow::check(error, arrow_taken_table, "[table][take]")) { + return garrow_table_new_raw(&(*arrow_taken_table)); } else { return NULL; } @@ -1398,28 +1304,18 @@ garrow_table_take_chunked_array(GArrowTable *table, auto arrow_table_raw = arrow_table.get(); auto arrow_indices = garrow_chunked_array_get_raw(indices); auto arrow_indices_raw = arrow_indices.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr taken_table; - arrow::Status status; + arrow::Result> arrow_taken_table; if (options) { auto arrow_options = garrow_take_options_get_raw(options); - status = arrow::compute::Take(&context, - *arrow_table_raw, - *arrow_indices_raw, - *arrow_options, - &taken_table); + arrow_taken_table = arrow::compute::Take(*arrow_table_raw, + *arrow_indices_raw, + *arrow_options); } else { - arrow::compute::TakeOptions arrow_options; - status = arrow::compute::Take(&context, - *arrow_table_raw, - *arrow_indices_raw, - arrow_options, - &taken_table); + 
arrow_taken_table = arrow::compute::Take(*arrow_table_raw, + *arrow_indices_raw); } - - if (garrow_error_check(error, status, "[table][take][chunked-array]")) { - return garrow_table_new_raw(&taken_table); + if (garrow::check(error, arrow_taken_table, "[table][take][chunked-array]")) { + return garrow_table_new_raw(&(*arrow_taken_table)); } else { return NULL; } @@ -1447,28 +1343,18 @@ garrow_chunked_array_take(GArrowChunkedArray *chunked_array, auto arrow_chunked_array_raw = arrow_chunked_array.get(); auto arrow_indices = garrow_array_get_raw(indices); auto arrow_indices_raw = arrow_indices.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr<arrow::ChunkedArray> taken_chunked_array; - arrow::Status status; + arrow::Result<std::shared_ptr<arrow::ChunkedArray>> arrow_taken_chunked_array; if (options) { auto arrow_options = garrow_take_options_get_raw(options); - status = arrow::compute::Take(&context, - *arrow_chunked_array_raw, - *arrow_indices_raw, - *arrow_options, - &taken_chunked_array); + arrow_taken_chunked_array = arrow::compute::Take(*arrow_chunked_array_raw, + *arrow_indices_raw, + *arrow_options); } else { - arrow::compute::TakeOptions arrow_options; - status = arrow::compute::Take(&context, - *arrow_chunked_array_raw, - *arrow_indices_raw, - arrow_options, - &taken_chunked_array); + arrow_taken_chunked_array = arrow::compute::Take(*arrow_chunked_array_raw, + *arrow_indices_raw); } - - if (garrow_error_check(error, status, "[chunked-array][take]")) { - return garrow_chunked_array_new_raw(&taken_chunked_array); + if (garrow::check(error, arrow_taken_chunked_array, "[chunked-array][take]")) { + return garrow_chunked_array_new_raw(&(*arrow_taken_chunked_array)); } else { return NULL; } @@ -1496,28 +1382,20 @@ garrow_chunked_array_take_chunked_array(GArrowChunkedArray *chunked_array, auto arrow_chunked_array_raw = arrow_chunked_array.get(); auto arrow_indices = garrow_chunked_array_get_raw(indices); auto arrow_indices_raw = arrow_indices.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr<arrow::ChunkedArray> taken_chunked_array; - arrow::Status status; + arrow::Result<std::shared_ptr<arrow::ChunkedArray>> arrow_taken_chunked_array; if (options) { auto arrow_options = garrow_take_options_get_raw(options); - status = arrow::compute::Take(&context, - *arrow_chunked_array_raw, - *arrow_indices_raw, - *arrow_options, - &taken_chunked_array); + arrow_taken_chunked_array = arrow::compute::Take(*arrow_chunked_array_raw, + *arrow_indices_raw, + *arrow_options); } else { - arrow::compute::TakeOptions arrow_options; - status = arrow::compute::Take(&context, - *arrow_chunked_array_raw, - *arrow_indices_raw, - arrow_options, - &taken_chunked_array); + arrow_taken_chunked_array = arrow::compute::Take(*arrow_chunked_array_raw, + *arrow_indices_raw); } - - if (garrow_error_check(error, status, "[chunked-array][take][chunked-array]")) { - return garrow_chunked_array_new_raw(&taken_chunked_array); + if (garrow::check(error, + arrow_taken_chunked_array, + "[chunked-array][take][chunked-array]")) { + return garrow_chunked_array_new_raw(&(*arrow_taken_chunked_array)); } else { return NULL; } @@ -1541,33 +1419,23 @@ garrow_record_batch_take(GArrowRecordBatch *record_batch, GArrowTakeOptions *options, GError **error) { - auto arrow_record_batch = - garrow_record_batch_get_raw(record_batch); + auto arrow_record_batch = garrow_record_batch_get_raw(record_batch); auto arrow_record_batch_raw = arrow_record_batch.get(); auto arrow_indices = 
garrow_array_get_raw(indices); auto arrow_indices_raw = arrow_indices.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr<arrow::RecordBatch> taken_record_batch; - arrow::Status status; + arrow::Result<std::shared_ptr<arrow::RecordBatch>> arrow_taken_record_batch; if (options) { auto arrow_options = garrow_take_options_get_raw(options); - status = arrow::compute::Take(&context, - *arrow_record_batch_raw, - *arrow_indices_raw, - *arrow_options, - &taken_record_batch); + arrow_taken_record_batch = arrow::compute::Take(*arrow_record_batch_raw, + *arrow_indices_raw, + *arrow_options); } else { - arrow::compute::TakeOptions arrow_options; - status = arrow::compute::Take(&context, - *arrow_record_batch_raw, - *arrow_indices_raw, - arrow_options, - &taken_record_batch); + arrow_taken_record_batch = arrow::compute::Take(*arrow_record_batch_raw, + *arrow_indices_raw); } - if (garrow_error_check(error, status, "[record-batch][take]")) { - return garrow_record_batch_new_raw(&taken_record_batch); + if (garrow::check(error, arrow_taken_record_batch, "[record-batch][take]")) { + return garrow_record_batch_new_raw(&(*arrow_taken_record_batch)); } else { return NULL; } @@ -1855,27 +1723,18 @@ garrow_array_filter(GArrowArray *array, { auto arrow_array = garrow_array_get_raw(array); auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter)); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_filtered; - arrow::Status status; + arrow::Result<arrow::Datum> arrow_filtered_datum; if (options) { auto arrow_options = garrow_filter_options_get_raw(options); - status = arrow::compute::Filter(&context, - arrow_array, - arrow_filter, - *arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_array, + arrow_filter, + *arrow_options); } else { - arrow::compute::FilterOptions arrow_options; - status = arrow::compute::Filter(&context, - arrow_array, - arrow_filter, - arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_array, + arrow_filter); } - if (garrow_error_check(error, status, "[array][filter]")) { - auto arrow_filtered_array = arrow_filtered.make_array(); + if (garrow::check(error, arrow_filtered_datum, "[array][filter]")) { + auto arrow_filtered_array = (*arrow_filtered_datum).make_array(); return garrow_array_new_raw(&arrow_filtered_array); } else { return NULL; @@ -1900,19 +1759,11 @@ garrow_array_is_in(GArrowArray *left, GError **error) { auto arrow_left = garrow_array_get_raw(left); - auto arrow_left_datum = arrow::compute::Datum(arrow_left); auto arrow_right = garrow_array_get_raw(right); - auto arrow_right_datum = arrow::compute::Datum(arrow_right); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_datum; - auto status = arrow::compute::IsIn(&context, - arrow_left_datum, - arrow_right_datum, - &arrow_datum); - if (garrow_error_check(error, status, "[array][is-in]")) { - auto arrow_array = arrow_datum.make_array(); - return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_array)); + auto arrow_is_in_datum = arrow::compute::IsIn(arrow_left, arrow_right); + if (garrow::check(error, arrow_is_in_datum, "[array][is-in]")) { + auto arrow_is_in_array = (*arrow_is_in_datum).make_array(); + return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_is_in_array)); } else { return NULL; } @@ -1936,19 +1787,13 @@ garrow_array_is_in_chunked_array(GArrowArray *left, GError 
**error) { auto arrow_left = garrow_array_get_raw(left); - auto arrow_left_datum = arrow::compute::Datum(arrow_left); auto arrow_right = garrow_chunked_array_get_raw(right); - auto arrow_right_datum = arrow::compute::Datum(arrow_right); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_datum; - auto status = arrow::compute::IsIn(&context, - arrow_left_datum, - arrow_right_datum, - &arrow_datum); - if (garrow_error_check(error, status, "[array][is-in][chunked-array]")) { - auto arrow_array = arrow_datum.make_array(); - return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_array)); + auto arrow_is_in_datum = arrow::compute::IsIn(arrow_left, arrow_right); + if (garrow::check(error, + arrow_is_in_datum, + "[array][is-in][chunked-array]")) { + auto arrow_is_in_array = (*arrow_is_in_datum).make_array(); + return GARROW_BOOLEAN_ARRAY(garrow_array_new_raw(&arrow_is_in_array)); } else { return NULL; } @@ -1970,14 +1815,9 @@ garrow_array_sort_to_indices(GArrowArray *array, { auto arrow_array = garrow_array_get_raw(array); auto arrow_array_raw = arrow_array.get(); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - std::shared_ptr<arrow::Array> arrow_indices; - auto status = arrow::compute::SortToIndices(&context, - *arrow_array_raw, - &arrow_indices); - if (garrow_error_check(error, status, "[array][sort-to-indices]")) { - return GARROW_UINT64_ARRAY(garrow_array_new_raw(&arrow_indices)); + auto arrow_indices_array = arrow::compute::SortToIndices(*arrow_array_raw); + if (garrow::check(error, arrow_indices_array, "[array][sort-to-indices]")) { + return GARROW_UINT64_ARRAY(garrow_array_new_raw(&(*arrow_indices_array))); } else { return NULL; } @@ -2004,27 +1844,18 @@ garrow_table_filter(GArrowTable *table, { auto arrow_table = garrow_table_get_raw(table); auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter)); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_filtered; - arrow::Status status; + arrow::Result<arrow::Datum> arrow_filtered_datum; if (options) { auto arrow_options = garrow_filter_options_get_raw(options); - status = arrow::compute::Filter(&context, - arrow_table, - arrow_filter, - *arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_table, + arrow_filter, + *arrow_options); } else { - arrow::compute::FilterOptions arrow_options; - status = arrow::compute::Filter(&context, - arrow_table, - arrow_filter, - arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_table, + arrow_filter); } - if (garrow_error_check(error, status, "[table][filter]")) { - auto arrow_filtered_table = arrow_filtered.table(); + if (garrow::check(error, arrow_filtered_datum, "[table][filter]")) { + auto arrow_filtered_table = (*arrow_filtered_datum).table(); return garrow_table_new_raw(&arrow_filtered_table); } else { return NULL; @@ -2052,27 +1883,20 @@ garrow_table_filter_chunked_array(GArrowTable *table, { auto arrow_table = garrow_table_get_raw(table); auto arrow_filter = garrow_chunked_array_get_raw(filter); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_filtered; - arrow::Status status; + arrow::Result<arrow::Datum> arrow_filtered_datum; if (options) { auto arrow_options = garrow_filter_options_get_raw(options); - status = arrow::compute::Filter(&context, - 
arrow_table, - arrow_filter, - *arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_table, + arrow_filter, + *arrow_options); } else { - arrow::compute::FilterOptions arrow_options; - status = arrow::compute::Filter(&context, - arrow_table, - arrow_filter, - arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_table, + arrow_filter); } - if (garrow_error_check(error, status, "[table][filter][chunked-array]")) { - auto arrow_filtered_table = arrow_filtered.table(); + if (garrow::check(error, + arrow_filtered_datum, + "[table][filter][chunked-array]")) { + auto arrow_filtered_table = (*arrow_filtered_datum).table(); return garrow_table_new_raw(&arrow_filtered_table); } else { return NULL; @@ -2098,30 +1922,20 @@ garrow_chunked_array_filter(GArrowChunkedArray *chunked_array, GArrowFilterOptions *options, GError **error) { - auto arrow_chunked_array = - garrow_chunked_array_get_raw(chunked_array); + auto arrow_chunked_array = garrow_chunked_array_get_raw(chunked_array); auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter)); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_filtered; - arrow::Status status; + arrow::Result<arrow::Datum> arrow_filtered_datum; if (options) { auto arrow_options = garrow_filter_options_get_raw(options); - status = arrow::compute::Filter(&context, - arrow_chunked_array, - arrow_filter, - *arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_chunked_array, + arrow_filter, + *arrow_options); } else { - arrow::compute::FilterOptions arrow_options; - status = arrow::compute::Filter(&context, - arrow_chunked_array, - arrow_filter, - arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_chunked_array, + arrow_filter); } - if (garrow_error_check(error, status, "[chunked-array][filter]")) { - auto arrow_filtered_chunked_array = arrow_filtered.chunked_array(); + if (garrow::check(error, arrow_filtered_datum, "[chunked-array][filter]")) { + auto arrow_filtered_chunked_array = (*arrow_filtered_datum).chunked_array(); return garrow_chunked_array_new_raw(&arrow_filtered_chunked_array); } else { return NULL; @@ -2147,30 +1961,22 @@ garrow_chunked_array_filter_chunked_array(GArrowChunkedArray *chunked_array, GArrowFilterOptions *options, GError **error) { - auto arrow_chunked_array = - garrow_chunked_array_get_raw(chunked_array); + auto arrow_chunked_array = garrow_chunked_array_get_raw(chunked_array); auto arrow_filter = garrow_chunked_array_get_raw(filter); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_filtered; - arrow::Status status; + arrow::Result<arrow::Datum> arrow_filtered_datum; if (options) { auto arrow_options = garrow_filter_options_get_raw(options); - status = arrow::compute::Filter(&context, - arrow_chunked_array, - arrow_filter, - *arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_chunked_array, + arrow_filter, + *arrow_options); } else { - arrow::compute::FilterOptions arrow_options; - status = arrow::compute::Filter(&context, - arrow_chunked_array, - arrow_filter, - arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_chunked_array, + arrow_filter); } - if (garrow_error_check(error, status, "[chunked-array][filter][chunked-array]")) { - auto arrow_filtered_chunked_array = 
arrow_filtered.chunked_array(); + if (garrow::check(error, + arrow_filtered_datum, + "[chunked-array][filter][chunked-array]")) { + auto arrow_filtered_chunked_array = (*arrow_filtered_datum).chunked_array(); return garrow_chunked_array_new_raw(&arrow_filtered_chunked_array); } else { return NULL; @@ -2196,30 +2002,20 @@ garrow_record_batch_filter(GArrowRecordBatch *record_batch, GArrowFilterOptions *options, GError **error) { - auto arrow_record_batch = - garrow_record_batch_get_raw(record_batch); + auto arrow_record_batch = garrow_record_batch_get_raw(record_batch); auto arrow_filter = garrow_array_get_raw(GARROW_ARRAY(filter)); - auto memory_pool = arrow::default_memory_pool(); - arrow::compute::FunctionContext context(memory_pool); - arrow::compute::Datum arrow_filtered; - arrow::Status status; + arrow::Result<arrow::Datum> arrow_filtered_datum; if (options) { auto arrow_options = garrow_filter_options_get_raw(options); - status = arrow::compute::Filter(&context, - arrow_record_batch, - arrow_filter, - *arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_record_batch, + arrow_filter, + *arrow_options); } else { - arrow::compute::FilterOptions arrow_options; - status = arrow::compute::Filter(&context, - arrow_record_batch, - arrow_filter, - arrow_options, - &arrow_filtered); + arrow_filtered_datum = arrow::compute::Filter(arrow_record_batch, + arrow_filter); } - if (garrow_error_check(error, status, "[record-batch][filter]")) { - auto arrow_filtered_record_batch = arrow_filtered.record_batch(); + if (garrow::check(error, arrow_filtered_datum, "[record-batch][filter]")) { + auto arrow_filtered_record_batch = (*arrow_filtered_datum).record_batch(); return garrow_record_batch_new_raw(&arrow_filtered_record_batch); } else { return NULL;
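The conversion pattern applied throughout compute.cpp above is mechanical; a minimal side-by-side sketch of the two calling conventions (illustrative variable names, not part of the patch):

    // Before: a FunctionContext plus a Status and an output parameter.
    arrow::compute::FunctionContext context(arrow::default_memory_pool());
    arrow::compute::Datum sum_datum;
    arrow::Status status = arrow::compute::Sum(&context, arrow_array, &sum_datum);

    // After: the eager API returns arrow::Result<arrow::Datum> directly;
    // dereferencing the Result yields the Datum on success.
    arrow::Result<arrow::Datum> arrow_sum_datum = arrow::compute::Sum(arrow_array);
    if (arrow_sum_datum.ok()) {
      std::shared_ptr<arrow::Scalar> sum_scalar = (*arrow_sum_datum).scalar();
    }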
diff --git a/c_glib/arrow-glib/error.cpp b/c_glib/arrow-glib/error.cpp index 211ebefe4b5..b03edefba7d 100644 --- a/c_glib/arrow-glib/error.cpp +++ b/c_glib/arrow-glib/error.cpp @@ -39,8 +39,18 @@ G_BEGIN_DECLS G_DEFINE_QUARK(garrow-error-quark, garrow_error) -static GArrowError -garrow_error_code(const arrow::Status &status) +G_END_DECLS + +gboolean +garrow_error_check(GError **error, + const arrow::Status &status, + const char *context) +{ + return garrow::check(error, status, context); +} + +GArrowError +garrow_error_from_status(const arrow::Status &status) { switch (status.code()) { case arrow::StatusCode::OK: @@ -73,25 +83,34 @@ garrow_error_code(const arrow::Status &status) return GARROW_ERROR_EXECUTION; case arrow::StatusCode::AlreadyExists: return GARROW_ERROR_ALREADY_EXISTS; - default: return GARROW_ERROR_UNKNOWN; } } -G_END_DECLS +arrow::Status +garrow_error_to_status(GError *error, + arrow::StatusCode code, + const char *context) +{ + std::stringstream message; + message << context << ": " << g_quark_to_string(error->domain); + message << "(" << error->code << "): "; + message << error->message; + g_error_free(error); + return arrow::Status(code, message.str()); +} namespace garrow { - gboolean - check(GError **error, - const arrow::Status &status, - const char *context) { + gboolean check(GError **error, + const arrow::Status &status, + const char *context) { if (status.ok()) { return TRUE; } else { g_set_error(error, GARROW_ERROR, - garrow_error_code(status), + garrow_error_from_status(status), "%s: %s", context, status.ToString().c_str()); @@ -99,24 +118,3 @@ namespace garrow { } } } - -gboolean -garrow_error_check(GError **error, - const arrow::Status &status, - const char *context) -{ - return garrow::check(error, status, context); -} - -arrow::Status -garrow_error_to_status(GError *error, - arrow::StatusCode code, - const char *context) -{ - std::stringstream message; - message << context << ": " << g_quark_to_string(error->domain); - message << "(" << error->code << "): "; - message << error->message; - g_error_free(error); - return arrow::Status(code, message.str()); -} diff --git a/c_glib/arrow-glib/error.hpp b/c_glib/arrow-glib/error.hpp index 735c67a60d1..d7ab1515c56 100644 --- a/c_glib/arrow-glib/error.hpp +++ b/c_glib/arrow-glib/error.hpp @@ -23,26 +23,48 @@ #include +gboolean garrow_error_check(GError **error, + const arrow::Status &status, + const char *context); +GArrowError garrow_error_from_status(const arrow::Status &status); +arrow::Status garrow_error_to_status(GError *error, + arrow::StatusCode code, + const char *context); + namespace garrow { gboolean check(GError **error, const arrow::Status &status, const char *context); - template <typename TYPE> + template <typename CONTEXT_FUNC> gboolean check(GError **error, - const arrow::Result<TYPE> &result, - const char *context) { - if (result.ok()) { + const arrow::Status &status, + CONTEXT_FUNC &&context_func) { + if (status.ok()) { return TRUE; } else { - return check(error, result.status(), context); + std::string context = std::move(context_func()); + g_set_error(error, + GARROW_ERROR, + garrow_error_from_status(status), + "%s: %s", + context.c_str(), + status.ToString().c_str()); + return FALSE; } } -} -gboolean garrow_error_check(GError **error, - const arrow::Status &status, - const char *context); -arrow::Status garrow_error_to_status(GError *error, - arrow::StatusCode code, - const char *context); + template <typename TYPE> + gboolean check(GError **error, + const arrow::Result<TYPE> &result, + const char *context) { + return check(error, result.status(), context); + } + + template <typename TYPE, typename CONTEXT_FUNC> + gboolean check(GError **error, + const arrow::Result<TYPE> &result, + CONTEXT_FUNC &&context_func) { + return check(error, result.status(), context_func); + } +}
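error.hpp now offers garrow::check() overloads for both arrow::Status and arrow::Result<T>, and for both a plain const char * tag and a callable, so the std::stringstream message is built only on failure. A minimal usage sketch (hypothetical caller, mirroring garrow_array_unique above):

    auto arrow_unique_array = arrow::compute::Unique(arrow_array);
    if (garrow::check(error,
                      arrow_unique_array,  // an arrow::Result<...>; checked via .status()
                      [&]() {
                        std::stringstream message;
                        message << "[array][unique] <"
                                << arrow_array->type()->ToString() << ">";
                        return message.str();  // invoked only when the status is not OK
                      })) {
      return garrow_array_new_raw(&(*arrow_unique_array));
    }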
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index b06147f2247..30c4c737081 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -121,7 +121,6 @@ endfunction() set(ARROW_SRCS array.cc - builder.cc array/builder_adaptive.cc array/builder_base.cc array/builder_binary.cc @@ -134,8 +133,10 @@ set(ARROW_SRCS array/dict_internal.cc array/diff.cc array/validate.cc + builder.cc buffer.cc compare.cc + datum.cc device.cc extension_type.cc memory_pool.cc @@ -319,24 +320,30 @@ endif() if(ARROW_COMPUTE) list(APPEND ARROW_SRCS - compute/context.cc - compute/kernels/aggregate.cc - compute/kernels/boolean.cc - compute/kernels/cast.cc - compute/kernels/compare.cc - compute/kernels/count.cc - compute/kernels/hash.cc - compute/kernels/filter.cc - compute/kernels/mean.cc - compute/kernels/minmax.cc - compute/kernels/sort_to_indices.cc - compute/kernels/nth_to_indices.cc - compute/kernels/sum.cc - compute/kernels/add.cc - compute/kernels/take.cc - compute/kernels/isin.cc - compute/kernels/match.cc - compute/kernels/util_internal.cc) + compute/api_aggregate.cc + compute/api_scalar.cc + compute/api_vector.cc + compute/cast.cc + compute/exec.cc + compute/function.cc + compute/kernel.cc + compute/registry.cc + compute/kernels/aggregate_basic.cc + compute/kernels/codegen_internal.cc + compute/kernels/scalar_arithmetic.cc + compute/kernels/scalar_boolean.cc + compute/kernels/scalar_cast_boolean.cc + compute/kernels/scalar_cast_internal.cc + compute/kernels/scalar_cast_nested.cc + compute/kernels/scalar_cast_numeric.cc + compute/kernels/scalar_cast_string.cc + compute/kernels/scalar_cast_temporal.cc + compute/kernels/scalar_compare.cc + compute/kernels/scalar_set_lookup.cc + compute/kernels/vector_filter.cc + compute/kernels/vector_hash.cc + compute/kernels/vector_sort.cc + compute/kernels/vector_take.cc) endif() if(ARROW_FILESYSTEM) @@ -524,6 +531,7 @@ endif() add_arrow_test(misc_test SOURCES + datum_test.cc memory_pool_test.cc result_test.cc pretty_print_test.cc diff --git a/cpp/src/arrow/adapters/orc/CMakeLists.txt b/cpp/src/arrow/adapters/orc/CMakeLists.txt index 20501dccf7d..7a3681968fd 100644 --- a/cpp/src/arrow/adapters/orc/CMakeLists.txt +++ b/cpp/src/arrow/adapters/orc/CMakeLists.txt @@ -44,7 +44,7 @@ elseif(NOT MSVC) set(ORC_MIN_TEST_LIBS ${ORC_MIN_TEST_LIBS} pthread ${CMAKE_DL_LIBS}) endif() -set(ORC_STATIC_TEST_LINK_LIBS ${ORC_MIN_TEST_LIBS} ${ARROW_LIBRARIES_FOR_STATIC_TESTS} +set(ORC_STATIC_TEST_LINK_LIBS ${ARROW_LIBRARIES_FOR_STATIC_TESTS} ${ORC_MIN_TEST_LIBS} orc::liborc) add_arrow_test(adapter_test diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc index 0e9ccc40504..4917d4524d1 100644 --- a/cpp/src/arrow/array/diff_test.cc +++ b/cpp/src/arrow/array/diff_test.cc @@ -33,8 +33,7 @@ #include "arrow/array/diff.h" #include "arrow/buffer.h" #include "arrow/builder.h" -#include "arrow/compute/context.h" -#include "arrow/compute/kernels/filter.h" +#include "arrow/compute/api.h" #include "arrow/status.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/random.h" @@ -119,20 +118,19 @@ class DiffTest : public ::testing::Test { void BaseAndTargetFromRandomFilter(std::shared_ptr<Array> values, double filter_probability) { - compute::Datum out_datum, base_filter, target_filter; + std::shared_ptr<Array> base_filter, target_filter; do { base_filter = this->rng_.Boolean(values->length(), filter_probability, 0.0); target_filter = this->rng_.Boolean(values->length(), filter_probability, 0.0); - } while (base_filter.Equals(target_filter)); + } while (base_filter->Equals(target_filter)); - ASSERT_OK(compute::Filter(&ctx_, values, base_filter, {}, &out_datum)); + ASSERT_OK_AND_ASSIGN(Datum out_datum, compute::Filter(values, base_filter)); base_ = out_datum.make_array(); - ASSERT_OK(compute::Filter(&ctx_, values, target_filter, {}, &out_datum)); + ASSERT_OK_AND_ASSIGN(out_datum, compute::Filter(values, target_filter)); target_ = out_datum.make_array(); } - compute::FunctionContext ctx_; random::RandomArrayGenerator rng_; std::shared_ptr<Array> edits_; std::shared_ptr<Array> base_, target_; @@ -616,7 +614,6 @@ void MakeSameLength(std::shared_ptr<Array>* a, std::shared_ptr<Array>* b) { } TEST_F(DiffTest, CompareRandomStruct) { - compute::FunctionContext ctx; for (auto null_probability : {0.0, 0.25}) { constexpr auto length = 1 << 10; auto int32_values = this->rng_.Int32(length, 0, 127, null_probability); diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt index 495a4a3f944..8ee87047a3d 100644 --- a/cpp/src/arrow/compute/CMakeLists.txt +++ b/cpp/src/arrow/compute/CMakeLists.txt @@ -58,7 +58,12 @@ function(ADD_ARROW_COMPUTE_TEST REL_TEST_NAME) ${ARG_UNPARSED_ARGUMENTS}) endfunction() -add_arrow_compute_test(compute_test) -add_arrow_benchmark(compute_benchmark) +add_arrow_compute_test(internals_test + SOURCES + function_test.cc + exec_test.cc + kernel_test.cc + registry_test.cc) +add_arrow_compute_test(exec_test) add_subdirectory(kernels) diff --git a/cpp/src/arrow/compute/README.md b/cpp/src/arrow/compute/README.md new file mode 100644 index 00000000000..80d8918e3d9 --- 
/dev/null +++ b/cpp/src/arrow/compute/README.md @@ -0,0 +1,58 @@ + + +## Apache Arrow C++ Compute Functions + +This submodule contains analytical functions that process primarily Arrow +columnar data; some functions can process scalar or Arrow-based array +inputs. These are intended for use inside query engines, data frame libraries, +etc. + +Many functions have SQL-like semantics in that they perform elementwise or +scalar operations on whole arrays at a time. Other functions are not SQL-like +and compute results that may be a different length or whose results depend on +the order of the values. + +Some basic terminology: + +* We use the term "function" to refer to a particular general operation that may + have many different implementations corresponding to different combinations + of types or function behavior options. +* We call a specific implementation of a function a "kernel". When executing a + function on inputs, we must first select a suitable kernel (kernel selection + is called "dispatching") corresponding to the value types of the inputs. +* Functions along with their kernel implementations are collected in a + "function registry". Given a function name and argument types, we can look up + that function and dispatch to a compatible kernel. + +Types of functions: + +* Scalar functions: elementwise functions that perform scalar operations in a + vectorized manner. These functions are generally valid in a SQL-like + context. These are called "scalar" in that the functions executed consider + each value in an array independently, and the output array or arrays have the + same length as the input arrays. The result for each array cell is generally + independent of its position in the array. +* Vector functions, which produce a result that is generally dependent + on the entire contents of the input arrays. These functions **are generally + not valid** for SQL-like processing because the output size may be different + from the input size, and the result may change based on the order of the + values in the array. This includes things like array subselection, sorting, + hashing, and more. +* Scalar aggregate functions, which can be used in a SQL-like context \ No newline at end of file
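The dispatch flow described in the README can also be driven by name. A minimal sketch (hypothetical caller code; the CallFunction(ctx, name, args, options) form is the one used by the eager wrappers in api_aggregate.cc below, and a null ExecContext is assumed to select default execution state):

    // Given a numeric input (assumed): std::shared_ptr<arrow::Array> array;
    // look up "sum" in the default function registry, dispatch to the kernel
    // matching the argument type, and execute it.
    arrow::Result<arrow::Datum> maybe_sum =
        arrow::compute::CallFunction(/*ctx=*/nullptr, "sum", {arrow::Datum(array)});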
diff --git a/cpp/src/arrow/compute/api.h b/cpp/src/arrow/compute/api.h index 8e60a39a0fd..3fc6e22b4be 100644 --- a/cpp/src/arrow/compute/api.h +++ b/cpp/src/arrow/compute/api.h @@ -15,20 +15,17 @@ // specific language governing permissions and limitations // under the License. -#pragma once +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle -#include "arrow/compute/context.h" // IWYU pragma: export -#include "arrow/compute/kernel.h" // IWYU pragma: export +#pragma once -#include "arrow/compute/kernels/boolean.h" // IWYU pragma: export -#include "arrow/compute/kernels/cast.h" // IWYU pragma: export -#include "arrow/compute/kernels/compare.h" // IWYU pragma: export -#include "arrow/compute/kernels/count.h" // IWYU pragma: export -#include "arrow/compute/kernels/filter.h" // IWYU pragma: export -#include "arrow/compute/kernels/hash.h" // IWYU pragma: export -#include "arrow/compute/kernels/isin.h" // IWYU pragma: export -#include "arrow/compute/kernels/mean.h" // IWYU pragma: export -#include "arrow/compute/kernels/nth_to_indices.h" // IWYU pragma: export -#include "arrow/compute/kernels/sort_to_indices.h" // IWYU pragma: export -#include "arrow/compute/kernels/sum.h" // IWYU pragma: export -#include "arrow/compute/kernels/take.h" // IWYU pragma: export +#include "arrow/compute/api_aggregate.h" // IWYU pragma: export +#include "arrow/compute/api_scalar.h" // IWYU pragma: export +#include "arrow/compute/api_vector.h" // IWYU pragma: export +#include "arrow/compute/cast.h" // IWYU pragma: export +#include "arrow/compute/exec.h" // IWYU pragma: export +#include "arrow/compute/function.h" // IWYU pragma: export +#include "arrow/compute/kernel.h" // IWYU pragma: export +#include "arrow/compute/registry.h" // IWYU pragma: export +#include "arrow/datum.h" // IWYU pragma: export diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc new file mode 100644 index 00000000000..0a41e4c11f0 --- /dev/null +++ b/cpp/src/arrow/compute/api_aggregate.cc @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/api_aggregate.h" + +#include "arrow/compute/exec.h" + +namespace arrow { +namespace compute { + +// ---------------------------------------------------------------------- +// Scalar aggregates + +Result<Datum> Count(const Datum& value, CountOptions options, ExecContext* ctx) { + return CallFunction(ctx, "count", {value}, &options); +} + +Result<Datum> Mean(const Datum& value, ExecContext* ctx) { + return CallFunction(ctx, "mean", {value}); +} + +Result<Datum> Sum(const Datum& value, ExecContext* ctx) { + return CallFunction(ctx, "sum", {value}); +} + +Result<Datum> MinMax(const Datum& value, const MinMaxOptions& options, ExecContext* ctx) { + return CallFunction(ctx, "minmax", {value}, &options); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h new file mode 100644 index 00000000000..82a4ebf76b6 --- /dev/null +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Eager evaluation convenience APIs for invoking common functions, including +// necessary memory allocations + +#pragma once + +#include "arrow/compute/function.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class ExecContext; + +// ---------------------------------------------------------------------- +// Aggregate functions + +/// \class CountOptions +/// +/// The user controls the Count kernel behavior with this class. By default, +/// it will count all non-null values. +struct ARROW_EXPORT CountOptions : public FunctionOptions { + enum mode { + // Count all non-null values. + COUNT_ALL = 0, + // Count all null values. + COUNT_NULL, + }; + + explicit CountOptions(enum mode count_mode) : count_mode(count_mode) {} + + static CountOptions Defaults() { return CountOptions(COUNT_ALL); } + + enum mode count_mode = COUNT_ALL; +}; + +/// \brief Count non-null (or null) values in an array. +/// +/// \param[in] datum datum to count +/// \param[in] options counting options, see CountOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Count(const Datum& datum, CountOptions options = CountOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the mean of a numeric array. 
+/// +/// \param[in] value datum to compute the mean, expecting Array +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed mean as a DoubleScalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Mean(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Sum values of a numeric array. +/// +/// \param[in] value datum to sum, expecting Array or ChunkedArray +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed sum as a Scalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Sum(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \class MinMaxOptions +/// +/// The user can control the MinMax kernel behavior with this class. By default, +/// it will skip null values. +struct ARROW_EXPORT MinMaxOptions : public FunctionOptions { + enum mode { + /// skip null values + SKIP = 0, + /// any nulls will result in null output + OUTPUT_NULL + }; + + explicit MinMaxOptions(enum mode null_handling = SKIP) : null_handling(null_handling) {} + + static MinMaxOptions Defaults() { return MinMaxOptions{}; } + + enum mode null_handling = SKIP; +}; + +/// \brief Calculate the min / max of a numeric array +/// +/// This function returns both the min and max as a struct scalar, with type +/// struct<min: T, max: T>, where T is the input type +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see MinMaxOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as a struct scalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> MinMax(const Datum& value, + const MinMaxOptions& options = MinMaxOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the min / max of a numeric array. +/// +/// This function returns both the min and max as a collection. The resulting +/// datum thus consists of two scalar datums: {Datum(min), Datum(max)} +/// +/// \param[in] array input array +/// \param[in] options see MinMaxOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum containing a {min, max} collection +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> MinMax(const Array& array, + const MinMaxOptions& options = MinMaxOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +} // namespace compute +} // namespace arrow
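A short usage sketch of the eager aggregate wrappers declared above (assuming arr is a std::shared_ptr<arrow::Array> with a numeric type; Datum is constructible from it):

    arrow::Result<arrow::Datum> counted =
        arrow::compute::Count(arr, arrow::compute::CountOptions::Defaults());
    arrow::Result<arrow::Datum> summed = arrow::compute::Sum(arr);
    if (summed.ok()) {
      // The sum arrives as a Scalar datum, which is how garrow_numeric_array_sum
      // at the top of this patch consumes it.
      std::shared_ptr<arrow::Scalar> sum_scalar = (*summed).scalar();
    }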
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc new file mode 100644 index 00000000000..07064395b68 --- /dev/null +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_scalar.h" + +#include <memory> +#include <sstream> +#include <string> + +#include "arrow/compute/exec.h" + +namespace arrow { +namespace compute { + +#define SCALAR_EAGER_UNARY(NAME, REGISTRY_NAME) \ + Result<Datum> NAME(const Datum& value, ExecContext* ctx) { \ + return CallFunction(ctx, REGISTRY_NAME, {value}); \ + } + +#define SCALAR_EAGER_BINARY(NAME, REGISTRY_NAME) \ + Result<Datum> NAME(const Datum& left, const Datum& right, ExecContext* ctx) { \ + return CallFunction(ctx, REGISTRY_NAME, {left, right}); \ + } + +// ---------------------------------------------------------------------- +// Arithmetic + +SCALAR_EAGER_BINARY(Add, "add") + +// ---------------------------------------------------------------------- +// Set-related operations + +static Result<Datum> ExecSetLookup(const std::string& func_name, const Datum& data, + const Datum& value_set, bool add_nulls_to_hash_table, + ExecContext* ctx) { + if (!value_set.is_arraylike()) { + return Status::Invalid("Set lookup value set must be Array or ChunkedArray"); + } + + if (value_set.length() > 0 && !data.type()->Equals(value_set.type())) { + std::stringstream ss; + ss << "Array type didn't match type of values set: " << data.type()->ToString() + << " vs " << value_set.type()->ToString(); + return Status::Invalid(ss.str()); + } + SetLookupOptions options(value_set, !add_nulls_to_hash_table); + return CallFunction(ctx, func_name, {data}, &options); +} + +Result<Datum> IsIn(const Datum& values, const Datum& value_set, ExecContext* ctx) { + return ExecSetLookup("isin", values, value_set, + /*add_nulls_to_hash_table=*/false, ctx); +} + +Result<Datum> Match(const Datum& values, const Datum& value_set, ExecContext* ctx) { + return ExecSetLookup("match", values, value_set, + /*add_nulls_to_hash_table=*/true, ctx); +} + +// ---------------------------------------------------------------------- +// Boolean functions + +SCALAR_EAGER_UNARY(Invert, "invert") +SCALAR_EAGER_BINARY(And, "and") +SCALAR_EAGER_BINARY(KleeneAnd, "and_kleene") +SCALAR_EAGER_BINARY(Or, "or") +SCALAR_EAGER_BINARY(KleeneOr, "or_kleene") +SCALAR_EAGER_BINARY(Xor, "xor") + +// ---------------------------------------------------------------------- + +Result<Datum> Compare(const Datum& left, const Datum& right, CompareOptions options, + ExecContext* ctx) { + std::string func_name; + switch (options.op) { + case CompareOperator::EQUAL: + func_name = "equal"; + break; + case CompareOperator::NOT_EQUAL: + func_name = "not_equal"; + break; + case CompareOperator::GREATER: + func_name = "greater"; + break; + case CompareOperator::GREATER_EQUAL: + func_name = "greater_equal"; + break; + case CompareOperator::LESS: + func_name = "less"; + break; + case CompareOperator::LESS_EQUAL: + func_name = "less_equal"; + break; + } + return CallFunction(ctx, func_name, {left, right}, &options); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h new file mode 100644 index 00000000000..e001a74a067 --- /dev/null +++ b/cpp/src/arrow/compute/api_scalar.h @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h new file mode 100644 index 00000000000..e001a74a067 --- /dev/null +++ b/cpp/src/arrow/compute/api_scalar.h @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Eager evaluation convenience APIs for invoking common functions, including +// necessary memory allocations + +#pragma once + +#include <string> +#include <utility> + +#include "arrow/compute/exec.h" // IWYU pragma: keep +#include "arrow/compute/function.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +// ---------------------------------------------------------------------- + +/// \brief Add two values together. Array values must be the same length. If a +/// value is null in either addend, the result is null +/// +/// \param[in] left the first value +/// \param[in] right the second value +/// \param[in] ctx the function execution context, optional +/// \return the elementwise addition of the values +ARROW_EXPORT +Result<Datum> Add(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +enum CompareOperator { + EQUAL, + NOT_EQUAL, + GREATER, + GREATER_EQUAL, + LESS, + LESS_EQUAL, +}; + +struct CompareOptions : public FunctionOptions { + explicit CompareOptions(CompareOperator op) : op(op) {} + + enum CompareOperator op; +}; + +/// \brief Compare a numeric array with a scalar. +/// +/// \param[in] left datum to compare, must be an Array +/// \param[in] right datum to compare, must be a Scalar of the same type as the +/// left Datum. +/// \param[in] options compare options +/// \param[in] ctx the function execution context, optional +/// \return resulting datum +/// +/// Note: on floating point arrays, this uses IEEE 754 compare semantics. +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Compare(const Datum& left, const Datum& right, + struct CompareOptions options, ExecContext* ctx = NULLPTR); + +/// \brief Invert the values of a boolean datum +/// \param[in] value datum to invert +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Invert(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums which always propagates nulls +/// (null and false is null). +/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> And(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums with a Kleene truth table +/// (null and false is false). +/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> KleeneAnd(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR);
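The difference between the two AND variants is easiest to see side by side. A small sketch, not from this patch; the inputs are assumed boolean Datums and the surrounding function is hypothetical:

```cpp
#include "arrow/compute/api.h"  // assumed umbrella header

// Assume left = [true, true, null] and right = [false, null, false].
arrow::Status CompareAndVariants(const arrow::Datum& left, const arrow::Datum& right) {
  ARROW_ASSIGN_OR_RAISE(arrow::Datum strict, arrow::compute::And(left, right));
  // Strict AND always propagates nulls:   [false, null, null]
  ARROW_ASSIGN_OR_RAISE(arrow::Datum kleene, arrow::compute::KleeneAnd(left, right));
  // Kleene AND treats false as absorbing: [false, null, false]
  (void)strict;
  (void)kleene;
  return arrow::Status::OK();
}
```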
+/// \brief Element-wise OR of two boolean datums which always propagates nulls +/// (null or true is null). +/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Or(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise OR of two boolean datums with a Kleene truth table +/// (null or true is true). +/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> KleeneOr(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise XOR of two boolean datums +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Xor(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// For set lookup operations like IsIn, Match +struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { + explicit SetLookupOptions(Datum value_set, bool skip_nulls) + : value_set(std::move(value_set)), skip_nulls(skip_nulls) {} + + Datum value_set; + bool skip_nulls; +}; + +/// \brief IsIn returns true for each element of `values` that is contained in +/// `value_set` +/// +/// If a null occurs in `values`: when `value_set` contains at least one null +/// (its null count is not 0), the corresponding output is true; otherwise it +/// is null. +/// +/// \param[in] values array-like input to look up in value_set +/// \param[in] value_set either Array or ChunkedArray +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> IsIn(const Datum& values, const Datum& value_set, + ExecContext* ctx = NULLPTR); + +/// \brief Match examines each slot in the values against a value_set array. +/// If the value is not found in value_set, null will be output. +/// If found, the index of occurrence within value_set (ignoring duplicates) +/// will be output. +/// +/// For example given values = [99, 42, 3, null] and +/// value_set = [3, 3, 99], the output will be = [1, null, 0, null] +/// +/// Note: Null in the values is considered to match +/// a null in the value_set array. For example given +/// values = [99, 42, 3, null] and value_set = [3, 99, null], +/// the output will be = [1, null, 0, 2] +/// +/// \param[in] values array-like input +/// \param[in] value_set either Array or ChunkedArray +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Match(const Datum& values, const Datum& value_set, + ExecContext* ctx = NULLPTR); + +} // namespace compute +} // namespace arrow
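A short sketch of the two set-lookup helpers, reusing the example values from the Match documentation above; the surrounding function is hypothetical.

```cpp
#include "arrow/compute/api.h"  // assumed umbrella header

// Assume values = [99, 42, 3, null] and value_set = [3, 3, 99].
arrow::Status SetLookupExample(const arrow::Datum& values,
                               const arrow::Datum& value_set) {
  ARROW_ASSIGN_OR_RAISE(arrow::Datum mask, arrow::compute::IsIn(values, value_set));
  // mask -> [true, false, true, null]: the trailing slot stays null because
  // value_set contains no nulls (see the IsIn null semantics above).
  ARROW_ASSIGN_OR_RAISE(arrow::Datum indices, arrow::compute::Match(values, value_set));
  // indices -> [1, null, 0, null], ignoring the duplicate 3 in value_set.
  (void)mask;
  (void)indices;
  return arrow::Status::OK();
}
```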
diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc new file mode 100644 index 00000000000..6b28b02fa21 --- /dev/null +++ b/cpp/src/arrow/compute/api_vector.cc @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_vector.h" + +#include <memory> +#include <utility> +#include <vector> + +#include "arrow/array/concatenate.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/kernels/vector_selection_internal.h" +#include "arrow/datum.h" +#include "arrow/result.h" + +namespace arrow { +namespace compute { + +// ---------------------------------------------------------------------- +// Direct exec interface to kernels + +Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n, + ExecContext* ctx) { + PartitionOptions options(/*pivot=*/n); + ARROW_ASSIGN_OR_RAISE( + Datum result, CallFunction(ctx, "partition_indices", {Datum(values)}, &options)); + return result.make_array(); +} + +Result<std::shared_ptr<Array>> SortToIndices(const Array& values, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction(ctx, "sort_indices", {Datum(values)})); + return result.make_array(); +} + +Result<Datum> Take(const Datum& values, const Datum& indices, const TakeOptions& options, + ExecContext* ctx) { + return CallFunction(ctx, "take", {values, indices}, &options); +} + +Result<std::shared_ptr<Array>> Unique(const Datum& value, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction(ctx, "unique", {value})); + return result.make_array(); +} + +Result<Datum> DictionaryEncode(const Datum& value, ExecContext* ctx) { + return CallFunction(ctx, "dictionary_encode", {value}); +} + +const char kValuesFieldName[] = "values"; +const char kCountsFieldName[] = "counts"; +const int32_t kValuesFieldIndex = 0; +const int32_t kCountsFieldIndex = 1; + +Result<std::shared_ptr<Array>> ValueCounts(const Datum& value, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction(ctx, "value_counts", {value})); + return result.make_array(); +}
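Unlike the Datum-returning scalar wrappers, these vector helpers unbox their results into concrete arrays. A minimal sketch with a hypothetical surrounding function:

```cpp
#include "arrow/compute/api.h"  // assumed umbrella header

arrow::Status UniqueExample(const std::shared_ptr<arrow::Array>& input) {
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> uniques,
                        arrow::compute::Unique(input));
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> counts,
                        arrow::compute::ValueCounts(input));
  // `counts` is a struct array whose fields can be addressed with the exported
  // kValuesFieldName / kCountsFieldName constants defined above.
  (void)uniques;
  (void)counts;
  return arrow::Status::OK();
}
```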
// ---------------------------------------------------------------------- +// Filter with conveniences to filter RecordBatch, Table + +Result<std::shared_ptr<RecordBatch>> FilterRecordBatch(const RecordBatch& batch, + const Datum& filter, + FilterOptions options, + ExecContext* ctx) { + if (!filter.is_array()) { + return Status::Invalid("Cannot filter a RecordBatch with a filter of kind ", + filter.kind()); + } + + // TODO: Rewrite this to convert to selection vector and use Take + std::vector<std::shared_ptr<Array>> columns(batch.num_columns()); + for (int i = 0; i < batch.num_columns(); ++i) { + ARROW_ASSIGN_OR_RAISE(Datum out, + Filter(batch.column(i)->data(), filter, options, ctx)); + columns[i] = out.make_array(); + } + + int64_t out_length; + if (columns.size() == 0) { + out_length = + internal::FilterOutputSize(options.null_selection_behavior, *filter.make_array()); + } else { + out_length = columns[0]->length(); + } + return RecordBatch::Make(batch.schema(), out_length, columns); +} + +Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filter, + FilterOptions options, ExecContext* ctx) { + auto new_columns = table.columns(); + for (auto& column : new_columns) { + ARROW_ASSIGN_OR_RAISE(Datum out_column, Filter(column, filter, options, ctx)); + column = out_column.chunked_array(); + } + return Table::Make(table.schema(), std::move(new_columns)); +} + +Result<Datum> Filter(const Datum& values, const Datum& filter, FilterOptions options, + ExecContext* ctx) { + if (values.kind() == Datum::RECORD_BATCH) { + auto values_batch = values.record_batch(); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> out_batch, + FilterRecordBatch(*values_batch, filter, options, ctx)); + return Datum(out_batch); + } else if (values.kind() == Datum::TABLE) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> out_table, + FilterTable(*values.table(), filter, options, ctx)); + return Datum(out_table); + } else { + return CallFunction(ctx, "filter", {values, filter}, &options); + } +}
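The null_selection_behavior distinction is easiest to see with the example values used in the header documentation further below; a sketch, with hypothetical inputs:

```cpp
#include "arrow/compute/api.h"  // assumed umbrella header

// Assume values = ["a", "b", "c", null, "e", "f"] and filter = [0, 1, 1, 0, null, 1].
arrow::Status FilterExample(const arrow::Datum& values, const arrow::Datum& filter) {
  ARROW_ASSIGN_OR_RAISE(arrow::Datum dropped, arrow::compute::Filter(values, filter));
  // Default DROP behavior      -> ["b", "c", "f"]

  arrow::compute::FilterOptions emit_null;
  emit_null.null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL;
  ARROW_ASSIGN_OR_RAISE(arrow::Datum kept,
                        arrow::compute::Filter(values, filter, emit_null));
  // EMIT_NULL behavior         -> ["b", "c", null, "f"]
  (void)dropped;
  (void)kept;
  return arrow::Status::OK();
}
```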
columns); +} + +Result> Take(const Table& table, const Array& indices, + const TakeOptions& options, ExecContext* ctx) { + auto ncols = table.num_columns(); + std::vector> columns(ncols); + + for (int j = 0; j < ncols; j++) { + ARROW_ASSIGN_OR_RAISE(columns[j], Take(*table.column(j), indices, options, ctx)); + } + return Table::Make(table.schema(), columns); +} + +Result> Take(const Table& table, const ChunkedArray& indices, + const TakeOptions& options, ExecContext* ctx) { + auto ncols = table.num_columns(); + std::vector> columns(ncols); + for (int j = 0; j < ncols; j++) { + ARROW_ASSIGN_OR_RAISE(columns[j], Take(*table.column(j), indices, options, ctx)); + } + return Table::Make(table.schema(), columns); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h new file mode 100644 index 00000000000..9e8ffacaf0e --- /dev/null +++ b/cpp/src/arrow/compute/api_vector.h @@ -0,0 +1,308 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/function.h" +#include "arrow/datum.h" +#include "arrow/result.h" + +namespace arrow { +namespace compute { + +class ExecContext; + +struct FilterOptions : public FunctionOptions { + /// Configure the action taken when a slot of the selection mask is null + enum NullSelectionBehavior { + /// the corresponding filtered value will be removed in the output + DROP, + /// the corresponding filtered value will be null in the output + EMIT_NULL, + }; + + static FilterOptions Defaults() { return FilterOptions{}; } + + NullSelectionBehavior null_selection_behavior = DROP; +}; + +/// \brief Filter with a boolean selection filter +/// +/// The output will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will be handled +/// based on options.null_selection_behavior. 
diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h new file mode 100644 index 00000000000..9e8ffacaf0e --- /dev/null +++ b/cpp/src/arrow/compute/api_vector.h @@ -0,0 +1,308 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> + +#include "arrow/compute/function.h" +#include "arrow/datum.h" +#include "arrow/result.h" + +namespace arrow { +namespace compute { + +class ExecContext; + +struct FilterOptions : public FunctionOptions { + /// Configure the action taken when a slot of the selection mask is null + enum NullSelectionBehavior { + /// the corresponding filtered value will be removed in the output + DROP, + /// the corresponding filtered value will be null in the output + EMIT_NULL, + }; + + static FilterOptions Defaults() { return FilterOptions{}; } + + NullSelectionBehavior null_selection_behavior = DROP; +}; + +/// \brief Filter with a boolean selection filter +/// +/// The output will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will be handled +/// based on options.null_selection_behavior. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// (null_selection_behavior == DROP) = ["b", "c", "f"] +/// (null_selection_behavior == EMIT_NULL) = ["b", "c", null, "f"] +/// +/// \param[in] values array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[in] options configures null_selection_behavior +/// \param[in] context the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result<Datum> Filter(const Datum& values, const Datum& filter, + FilterOptions options = FilterOptions::Defaults(), + ExecContext* context = NULLPTR); + +struct ARROW_EXPORT TakeOptions : public FunctionOptions { + static TakeOptions Defaults() { return TakeOptions{}; } +}; + +/// \brief Take from an array of values at indices in another array +/// +/// \param[in] values datum from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result<Datum> Take(const Datum& values, const Datum& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* context = NULLPTR); + +/// \brief Take from an array of values at indices in another array +/// +/// The output array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] values array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting array +ARROW_EXPORT +Result<std::shared_ptr<Array>> Take(const Array& values, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* context = NULLPTR); + +/// \brief Take from a chunked array of values at indices in another array +/// +/// The output chunked array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] values chunked array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Result<std::shared_ptr<ChunkedArray>> Take( + const ChunkedArray& values, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +/// \brief Take from a chunked array of values at indices in a chunked array +/// +/// The output chunked array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// The chunks in the output array will align with the chunks in the indices.
+/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] values chunked array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Result<std::shared_ptr<ChunkedArray>> Take( + const ChunkedArray& values, const ChunkedArray& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +/// \brief Take from an array of values at indices in a chunked array +/// +/// The output chunked array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// The chunks in the output array will align with the chunks in the indices. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] values array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting chunked array +/// NOTE: Experimental API +ARROW_EXPORT +Result<std::shared_ptr<ChunkedArray>> Take( + const Array& values, const ChunkedArray& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +/// \brief Take from a record batch at indices in another array +/// +/// The output batch will have the same schema as the input batch, +/// with rows taken from the columns in the batch at the given +/// indices. If an index is null then the taken element will be null. +/// +/// \param[in] batch record batch from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting record batch +/// NOTE: Experimental API +ARROW_EXPORT +Result<std::shared_ptr<RecordBatch>> Take( + const RecordBatch& batch, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), ExecContext* context = NULLPTR); + +/// \brief Take from a table at indices in an array +/// +/// The output table will have the same schema as the input table, +/// with rows taken from the columns in the table at the given +/// indices. If an index is null then the taken element will be null. +/// +/// \param[in] table table from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting table +/// NOTE: Experimental API +ARROW_EXPORT +Result<std::shared_ptr<Table>> Take(const Table& table, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* context = NULLPTR);
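A one-line sketch of the Table convenience declared above; `table` and `indices` are assumed inputs and the surrounding function is hypothetical:

```cpp
#include "arrow/compute/api.h"  // assumed umbrella header

arrow::Status TableTakeExample(const arrow::Table& table,
                               const arrow::Array& indices) {
  // Rows are gathered column by column; null indices produce null slots.
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Table> picked,
                        arrow::compute::Take(table, indices));
  (void)picked;
  return arrow::Status::OK();
}
```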
/// \brief Take from a table at indices in a chunked array +/// +/// The output table will have the same schema as the input table, +/// with rows taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// +/// \param[in] table table from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] context the function execution context, optional +/// \return the resulting table +/// NOTE: Experimental API +ARROW_EXPORT +Result<std::shared_ptr<Table>> Take(const Table& table, const ChunkedArray& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* context = NULLPTR); + +struct PartitionOptions : public FunctionOptions { + explicit PartitionOptions(int64_t pivot) : pivot(pivot) {} + int64_t pivot; +}; + +/// \brief Returns indices that partition an array around n-th +/// sorted element. +/// +/// Find the index of the n-th (0-based) smallest value and perform an indirect +/// partition of the array around that element. Output indices[0 ~ n-1] +/// holds values no greater than the n-th element, and indices[n+1 ~ end] +/// holds values no less than the n-th element. Elements in each partition +/// are not sorted. Nulls will be partitioned to the end of the output. +/// Output is not guaranteed to be stable. +/// +/// \param[in] values array to be partitioned +/// \param[in] n pivot array around sorted n-th element +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would partition an array +ARROW_EXPORT +Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n, + ExecContext* ctx = NULLPTR); + +/// \brief Returns the indices that would sort an array. +/// +/// Perform an indirect sort of array. The output array will contain +/// indices that would sort an array, which would be the same length +/// as input. Nulls will be stably partitioned to the end of the output. +/// +/// For example given values = [null, 1, 3.3, null, 2, 5.3], the output +/// will be [1, 4, 2, 5, 0, 3] +/// +/// \param[in] values array to sort +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result<std::shared_ptr<Array>> SortToIndices(const Array& values, + ExecContext* ctx = NULLPTR);
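Both sorting-related helpers return index arrays rather than reordered values; a sketch using the documented example values, with a hypothetical surrounding function:

```cpp
#include "arrow/compute/api.h"  // assumed umbrella header

// Assume values = [null, 1, 3.3, null, 2, 5.3].
arrow::Status SortExample(const arrow::Array& values) {
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> sort_indices,
                        arrow::compute::SortToIndices(values));
  // sort_indices -> [1, 4, 2, 5, 0, 3]; nulls are stably moved to the end.
  // The indices can then be fed to Take to materialize the sorted values.

  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> nth,
                        arrow::compute::NthToIndices(values, /*n=*/2));
  // nth[0..1] index values no greater than the 2nd smallest; neither side of
  // the pivot is itself sorted.
  (void)sort_indices;
  (void)nth;
  return arrow::Status::OK();
}
```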
/// \brief Compute unique elements from an array-like object +/// +/// Note if a null occurs in the input it will NOT be included in the output. +/// +/// \param[in] datum array-like input +/// \param[in] ctx the function execution context, optional +/// \return result as Array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<std::shared_ptr<Array>> Unique(const Datum& datum, ExecContext* ctx = NULLPTR); + +// Constants for accessing the output of ValueCounts +ARROW_EXPORT extern const char kValuesFieldName[]; +ARROW_EXPORT extern const char kCountsFieldName[]; +ARROW_EXPORT extern const int32_t kValuesFieldIndex; +ARROW_EXPORT extern const int32_t kCountsFieldIndex; +/// \brief Return counts of unique elements from an array-like object. +/// +/// Note that the counts do not include counts for nulls in the array. These can be +/// obtained separately from metadata. +/// +/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values +/// which can lead to unexpected results if the input Array has these values. +/// +/// \param[in] value array-like input +/// \param[in] ctx the function execution context, optional +/// \return counts An array of structs. +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<std::shared_ptr<Array>> ValueCounts(const Datum& value, + ExecContext* ctx = NULLPTR); + +/// \brief Dictionary-encode values in an array-like object +/// \param[in] data array-like input +/// \param[in] ctx the function execution context, optional +/// \return result with same shape and type as input +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> DictionaryEncode(const Datum& data, ExecContext* ctx = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc new file mode 100644 index 00000000000..63b8e509edc --- /dev/null +++ b/cpp/src/arrow/compute/cast.cc @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/cast.h" + +#include <mutex> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "arrow/compute/cast_internal.h" +#include "arrow/compute/kernel.h" + +namespace arrow { +namespace compute { + +namespace internal { + +std::unordered_map<int, std::shared_ptr<CastFunction>> g_cast_table; +static std::once_flag cast_table_initialized; + +void AddCastFunctions(const std::vector<std::shared_ptr<CastFunction>>& funcs) { + for (const auto& func : funcs) { + g_cast_table[static_cast<int>(func->out_type_id())] = func; + } +} + +void InitCastTable() { + AddCastFunctions(GetBooleanCasts()); + AddCastFunctions(GetBinaryLikeCasts()); + AddCastFunctions(GetNestedCasts()); + AddCastFunctions(GetNumericCasts()); + AddCastFunctions(GetTemporalCasts()); +} + +void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } + +} // namespace internal + +struct CastFunction::CastFunctionImpl { + Type::type out_type; + std::unordered_set<int> in_types; +}; + +CastFunction::CastFunction(std::string name, Type::type out_type) + : ScalarFunction(std::move(name), Arity::Unary()) { + impl_.reset(new CastFunctionImpl()); + impl_->out_type = out_type; +} + +CastFunction::~CastFunction() {} + +Type::type CastFunction::out_type_id() const { return impl_->out_type; } + +std::unique_ptr<KernelState> CastInit(KernelContext* ctx, const KernelInitArgs& args) { + auto cast_options = static_cast<const CastOptions*>(args.options); + // Ensure that the requested type to cast to was attached to the options + DCHECK(cast_options->to_type); + return std::unique_ptr<KernelState>(new internal::CastState(*cast_options)); +} + +Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) { + // We use the same KernelInit for every cast + kernel.init = CastInit; + RETURN_NOT_OK(ScalarFunction::AddKernel(kernel)); + impl_->in_types.insert(static_cast<int>(in_type_id)); + return Status::OK(); +} + +Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_types, + OutputType out_type, ArrayKernelExec exec, + NullHandling::type null_handling, + MemAllocation::type mem_allocation) { + ScalarKernel kernel; + kernel.signature = KernelSignature::Make(std::move(in_types), std::move(out_type)); + kernel.exec = exec; + kernel.null_handling = null_handling; + kernel.mem_allocation = mem_allocation; + return AddKernel(in_type_id, std::move(kernel)); +} + +bool CastFunction::CanCastTo(const DataType& out_type) const { + return impl_->in_types.find(static_cast<int>(out_type.id())) != impl_->in_types.end(); +} + +Result<const Kernel*> CastFunction::DispatchExact( + const std::vector<ValueDescr>& values) const { + const int passed_num_args = static_cast<int>(values.size()); + + // Validate arity + if (passed_num_args != 1) { + return Status::Invalid("Cast functions accept 1 argument but passed ", + passed_num_args); + } + std::vector<const ScalarKernel*> candidate_kernels; + for (const auto& kernel : kernels_) { + if (kernel.signature->MatchesInputs(values)) { + candidate_kernels.push_back(&kernel); + } + } + + if (candidate_kernels.size() == 0) { + return Status::NotImplemented("Function ", this->name(), + " has no kernel matching input type ", + values[0].ToString()); + } else if (candidate_kernels.size() == 1) { + // One match, return it + return candidate_kernels[0]; + } else { + // Now we are in a casting scenario where we may have both an EXACT_TYPE and + // a SAME_TYPE_ID. So we will see if there is an exact match among the + // candidate kernels and if not we will just return the first one + for (auto kernel : candidate_kernels) { + const InputType& arg0 = kernel->signature->in_types()[0]; + if (arg0.kind() == InputType::EXACT_TYPE) { + // Bingo. Return it + return kernel; + } + } + // We didn't find an exact match. So just return some kernel that matches + return candidate_kernels[0]; + } +} + +Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type, + const CastOptions& options, ExecContext* ctx) { + if (value.type()->Equals(*to_type)) { + return value; + } + CastOptions options_with_to_type = options; + options_with_to_type.to_type = to_type; + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<CastFunction> cast_func, + GetCastFunction(to_type)); + return cast_func->Execute({Datum(value)}, &options_with_to_type, ctx); +} + +Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type, + const CastOptions& options, ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx)); + return result.make_array(); +} + +Result<std::shared_ptr<CastFunction>> GetCastFunction( + const std::shared_ptr<DataType>& to_type) { + internal::EnsureInitCastTable(); + auto it = internal::g_cast_table.find(static_cast<int>(to_type->id())); + if (it == internal::g_cast_table.end()) { + return Status::NotImplemented("No cast function available to cast to ", + to_type->ToString()); + } + return it->second; +} + +bool CanCast(const DataType& from_type, const DataType& to_type) { + // TODO + auto it = internal::g_cast_table.find(static_cast<int>(from_type.id())); + if (it == internal::g_cast_table.end()) { + return false; + } + return it->second->CanCastTo(to_type); +} + +} // namespace compute +} // namespace arrow
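A sketch of the two eager Cast conveniences defined above; the input array and the surrounding function are assumed for illustration:

```cpp
#include "arrow/compute/api.h"  // assumed umbrella header

arrow::Status CastExample(const arrow::Array& doubles) {
  // Safe cast: overflow or truncation yields an error Status.
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> as_int32,
                        arrow::compute::Cast(doubles, arrow::int32()));
  // Unsafe cast: truncation is permitted via CastOptions.
  ARROW_ASSIGN_OR_RAISE(
      std::shared_ptr<arrow::Array> truncated,
      arrow::compute::Cast(doubles, arrow::int32(),
                           arrow::compute::CastOptions::Unsafe()));
  (void)as_int32;
  (void)truncated;
  return arrow::Status::OK();
}
```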
diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h new file mode 100644 index 00000000000..93961a0fd3b --- /dev/null +++ b/cpp/src/arrow/compute/cast.h @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> +#include <vector> + +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class ExecContext; + +struct ARROW_EXPORT CastOptions : public FunctionOptions { + CastOptions() + : allow_int_overflow(false), + allow_time_truncate(false), + allow_time_overflow(false), + allow_decimal_truncate(false), + allow_float_truncate(false), + allow_invalid_utf8(false) {} + + explicit CastOptions(bool safe) + : allow_int_overflow(!safe), + allow_time_truncate(!safe), + allow_time_overflow(!safe), + allow_decimal_truncate(!safe), + allow_float_truncate(!safe), + allow_invalid_utf8(!safe) {} + + static CastOptions Safe() { return CastOptions(true); } + + static CastOptions Unsafe() { return CastOptions(false); } + + // Type being casted to. May be passed separate to eager function + // compute::Cast + std::shared_ptr<DataType> to_type; + + bool allow_int_overflow; + bool allow_time_truncate; + bool allow_time_overflow; + bool allow_decimal_truncate; + bool allow_float_truncate; + // Indicate if conversions from Binary/FixedSizeBinary to string must + // validate the utf8 payload.
+ bool allow_invalid_utf8; +}; + +// Cast functions are _not_ registered in the FunctionRegistry, though they use +// the same execution machinery +class CastFunction : public ScalarFunction { + public: + CastFunction(std::string name, Type::type out_type); + ~CastFunction(); + + Type::type out_type_id() const; + + Status AddKernel(Type::type in_type_id, std::vector<InputType> in_types, + OutputType out_type, ArrayKernelExec exec, + NullHandling::type = NullHandling::INTERSECTION, + MemAllocation::type = MemAllocation::PREALLOCATE); + + // Note, this function toggles off memory allocation and sets the init + // function to CastInit + Status AddKernel(Type::type in_type_id, ScalarKernel kernel); + + bool CanCastTo(const DataType& out_type) const; + + Result<const Kernel*> DispatchExact( + const std::vector<ValueDescr>& values) const override; + + private: + struct CastFunctionImpl; + std::unique_ptr<CastFunctionImpl> impl_; +}; + +ARROW_EXPORT +Result<std::shared_ptr<CastFunction>> GetCastFunction( + const std::shared_ptr<DataType>& to_type); + +/// \brief Return true if a cast function is defined +ARROW_EXPORT +bool CanCast(const DataType& from_type, const DataType& to_type); + +// ---------------------------------------------------------------------- +// Convenience invocation APIs for a number of kernels + +/// \brief Cast from one array type to another +/// \param[in] value array to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] context the function execution context, optional +/// \return the resulting array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType> to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* context = NULLPTR); + +/// \brief Cast from one value to another +/// \param[in] value datum to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] context the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> Cast(const Datum& value, std::shared_ptr<DataType> to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* context = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/context.cc b/cpp/src/arrow/compute/cast_internal.h similarity index 57% rename from cpp/src/arrow/compute/context.cc rename to cpp/src/arrow/compute/cast_internal.h index dade2464a3d..be64359e4ab 100644 --- a/cpp/src/arrow/compute/context.cc +++ b/cpp/src/arrow/compute/cast_internal.h @@ -15,35 +15,30 @@ // specific language governing permissions and limitations // under the License.
-#include "arrow/compute/context.h" +#pragma once #include +#include -#include "arrow/buffer.h" -#include "arrow/result.h" -#include "arrow/util/cpu_info.h" +#include "arrow/compute/cast.h" // IWYU pragma: keep +#include "arrow/compute/kernel.h" // IWYU pragma: keep namespace arrow { namespace compute { +namespace internal { -FunctionContext::FunctionContext(MemoryPool* pool) - : pool_(pool), cpu_info_(internal::CpuInfo::GetInstance()) {} +struct CastState : public KernelState { + explicit CastState(const CastOptions& options) : options(options) {} + CastOptions options; +}; -MemoryPool* FunctionContext::memory_pool() const { return pool_; } - -Status FunctionContext::Allocate(const int64_t nbytes, std::shared_ptr* out) { - return AllocateBuffer(nbytes, pool_).Value(out); -} - -void FunctionContext::SetStatus(const Status& status) { - if (ARROW_PREDICT_FALSE(!status_.ok())) { - return; - } - status_ = status; -} - -/// \brief Clear any error status -void FunctionContext::ResetStatus() { status_ = Status::OK(); } +// See kernels/scalar_cast_*.cc for these +std::vector> GetBooleanCasts(); +std::vector> GetNumericCasts(); +std::vector> GetTemporalCasts(); +std::vector> GetBinaryLikeCasts(); +std::vector> GetNestedCasts(); +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/compute_test.cc b/cpp/src/arrow/compute/compute_test.cc deleted file mode 100644 index cd33466a67a..00000000000 --- a/cpp/src/arrow/compute/compute_test.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include -#include -#include - -#include - -#include "arrow/array.h" -#include "arrow/buffer.h" -#include "arrow/memory_pool.h" -#include "arrow/status.h" -#include "arrow/table.h" -#include "arrow/testing/gtest_common.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/decimal.h" - -#include "arrow/compute/context.h" -#include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/util_internal.h" -#include "arrow/compute/test_util.h" - -namespace arrow { -namespace compute { - -// ---------------------------------------------------------------------- -// Datum - -template -void CheckImplicitConstructor(enum Datum::type expected_kind) { - std::shared_ptr value; - Datum datum = value; - ASSERT_EQ(expected_kind, datum.kind()); -} - -TEST(TestDatum, ImplicitConstructors) { - CheckImplicitConstructor(Datum::SCALAR); - - CheckImplicitConstructor(Datum::ARRAY); - - // Instantiate from array subclass - CheckImplicitConstructor(Datum::ARRAY); - - CheckImplicitConstructor(Datum::CHUNKED_ARRAY); - CheckImplicitConstructor(Datum::RECORD_BATCH); - - CheckImplicitConstructor
(Datum::TABLE); -} - -class TestInvokeBinaryKernel : public ComputeFixture, public TestBase {}; - -TEST_F(TestInvokeBinaryKernel, Exceptions) { - MockBinaryKernel kernel; - std::vector outputs; - std::shared_ptr
table; - std::vector values1 = {true, false, true}; - std::vector values2 = {false, true, false}; - - auto type = boolean(); - auto a1 = _MakeArray(type, values1, {}); - auto a2 = _MakeArray(type, values2, {}); - - // Left is not an array-like - ASSERT_RAISES(Invalid, detail::InvokeBinaryArrayKernel(&this->ctx_, &kernel, table, a2, - &outputs)); - // Right is not an array-like - ASSERT_RAISES(Invalid, detail::InvokeBinaryArrayKernel(&this->ctx_, &kernel, a1, table, - &outputs)); - // Different sized inputs - ASSERT_RAISES(Invalid, detail::InvokeBinaryArrayKernel(&this->ctx_, &kernel, a1, - a1->Slice(1), &outputs)); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/context.h b/cpp/src/arrow/compute/context.h deleted file mode 100644 index dde8b686fc3..00000000000 --- a/cpp/src/arrow/compute/context.h +++ /dev/null @@ -1,79 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include - -#include "arrow/memory_pool.h" -#include "arrow/status.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { - -class Buffer; - -namespace internal { -class CpuInfo; -} // namespace internal - -namespace compute { - -#define ARROW_RETURN_IF_ERROR(ctx) \ - if (ARROW_PREDICT_FALSE(ctx->HasError())) { \ - Status s = ctx->status(); \ - ctx->ResetStatus(); \ - return s; \ - } - -/// \brief Container for variables and options used by function evaluation -class ARROW_EXPORT FunctionContext { - public: - explicit FunctionContext(MemoryPool* pool = default_memory_pool()); - MemoryPool* memory_pool() const; - - /// \brief Allocate buffer from the context's memory pool - Status Allocate(const int64_t nbytes, std::shared_ptr* out); - - /// \brief Indicate that an error has occurred, to be checked by a parent caller - /// \param[in] status a Status instance - /// - /// \note Will not overwrite a prior set Status, so we will have the first - /// error that occurred until FunctionContext::ResetStatus is called - void SetStatus(const Status& status); - - /// \brief Clear any error status - void ResetStatus(); - - /// \brief Return true if an error has occurred - bool HasError() const { return !status_.ok(); } - - /// \brief Return the current status of the context - const Status& status() const { return status_; } - - internal::CpuInfo* cpu_info() const { return cpu_info_; } - - private: - Status status_; - MemoryPool* pool_; - internal::CpuInfo* cpu_info_; -}; - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc new file mode 100644 index 00000000000..7c990ea6c9d --- /dev/null +++ b/cpp/src/arrow/compute/exec.cc @@ -0,0 +1,942 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or 
more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec.h" + +#include +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/registry.h" +#include "arrow/datum.h" +#include "arrow/scalar.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/cpu_info.h" +#include "arrow/util/logging.h" + +namespace arrow { + +using internal::BitmapAnd; +using internal::checked_cast; +using internal::CopyBitmap; +using internal::CpuInfo; + +class MemoryPool; + +namespace compute { + +namespace { + +Result> AllocateDataBuffer(KernelContext* ctx, int64_t length, + int bit_width) { + if (bit_width == 1) { + return ctx->AllocateBitmap(length); + } else { + ARROW_CHECK_EQ(bit_width % 8, 0) + << "Only bit widths with multiple of 8 are currently supported"; + int64_t buffer_size = length * bit_width / 8; + return ctx->Allocate(buffer_size); + } + return Status::OK(); +} + +bool CanPreallocate(const DataType& type) { + // There are currently cases where NullType is the output type, so we disable + // any preallocation logic when this occurs + return is_fixed_width(type.id()) && type.id() != Type::NA; +} + +Status GetValueDescriptors(const std::vector& args, + std::vector* descrs) { + for (const auto& arg : args) { + descrs->emplace_back(arg.descr()); + } + return Status::OK(); +} + +} // namespace + +namespace detail { + +ExecBatchIterator::ExecBatchIterator(std::vector args, int64_t length, + int64_t max_chunksize) + : args_(std::move(args)), + position_(0), + length_(length), + max_chunksize_(max_chunksize) { + chunk_indexes_.resize(args_.size(), 0); + chunk_positions_.resize(args_.size(), 0); +} + +Result> ExecBatchIterator::Make( + std::vector args, int64_t max_chunksize) { + for (const auto& arg : args) { + if (!(arg.is_arraylike() || arg.is_scalar())) { + return Status::Invalid( + "ExecBatchIterator only works with Scalar, Array, and " + "ChunkedArray arguments"); + } + } + + // If the arguments are all scalars, then the length is 1 + int64_t length = 1; + + bool length_set = false; + for (size_t i = 0; i < args.size(); ++i) { + if (args[i].is_scalar()) { + continue; + } + if (!length_set) { + length = args[i].length(); + length_set = true; + } else { + if (args[i].length() != length) { + return Status::Invalid("Array arguments must all be the same length"); + } + } + } + + max_chunksize = std::min(length, max_chunksize); + + return std::unique_ptr( + new ExecBatchIterator(std::move(args), length, max_chunksize)); +} + +bool ExecBatchIterator::Next(ExecBatch* batch) { + 
if (position_ == length_) { + return false; + } + + // Determine how large the common contiguous "slice" of all the arguments is + int64_t iteration_size = std::min(length_ - position_, max_chunksize_); + + // If length_ is 0, then this loop will never execute + for (size_t i = 0; i < args_.size() && iteration_size > 0; ++i) { + // If the argument is not a chunked array, it's either a Scalar or Array, + // in which case it doesn't influence the size of this batch. Note that if + // the args are all scalars the batch length is 1 + if (args_[i].kind() != Datum::CHUNKED_ARRAY) { + continue; + } + const ChunkedArray& arg = *args_[i].chunked_array(); + std::shared_ptr<Array> current_chunk; + while (true) { + current_chunk = arg.chunk(chunk_indexes_[i]); + if (chunk_positions_[i] == current_chunk->length()) { + // Chunk is zero-length, or was exhausted in the previous iteration + chunk_positions_[i] = 0; + ++chunk_indexes_[i]; + continue; + } + break; + } + iteration_size = + std::min(current_chunk->length() - chunk_positions_[i], iteration_size); + } + + // Now, fill the batch + batch->values.resize(args_.size()); + batch->length = iteration_size; + for (size_t i = 0; i < args_.size(); ++i) { + if (args_[i].is_scalar()) { + batch->values[i] = args_[i].scalar(); + } else if (args_[i].is_array()) { + batch->values[i] = args_[i].array()->Slice(position_, iteration_size); + } else { + const ChunkedArray& carr = *args_[i].chunked_array(); + const auto& chunk = carr.chunk(chunk_indexes_[i]); + batch->values[i] = chunk->data()->Slice(chunk_positions_[i], iteration_size); + chunk_positions_[i] += iteration_size; + } + } + position_ += iteration_size; + DCHECK_LE(position_, length_); + return true; +} + +bool ArrayHasNulls(const ArrayData& data) { + // As discovered in ARROW-8863 (and not only for that reason) + // ArrayData::null_count can be -1 even when buffers[0] is nullptr. So we check + // for both cases (nullptr means no nulls, or null_count already computed) + if (data.type->id() == Type::NA) { + return true; + } else if (data.buffers[0] == nullptr) { + return false; + } else { + // Do not count the bits if they haven't been counted already + const int64_t known_null_count = data.null_count.load(); + return known_null_count == kUnknownNullCount || known_null_count > 0; + } +}
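For orientation, here is a sketch of how this internal iterator is meant to be driven. ExecBatchIterator lives in namespace detail and is an implementation detail, so this is illustrative only; the chunk size is arbitrary and the surrounding function is hypothetical.

```cpp
#include "arrow/compute/exec_internal.h"  // assumed internal header

arrow::Status IterateExample(std::vector<arrow::Datum> args) {
  ARROW_ASSIGN_OR_RAISE(auto iterator,
                        arrow::compute::detail::ExecBatchIterator::Make(
                            std::move(args), /*max_chunksize=*/1024));
  arrow::compute::ExecBatch batch;
  while (iterator->Next(&batch)) {
    // batch.values holds scalars, array slices, or chunk slices that all
    // cover the same logical row range of length batch.length.
  }
  return arrow::Status::OK();
}
```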
// Null propagation implementation that deals both with preallocated bitmaps +// and maybe-to-be allocated bitmaps +// +// If the bitmap is preallocated, it MUST be populated (since it might be a +// view of a much larger bitmap). If it isn't preallocated, then we have +// more flexibility. +// +// * If the batch has no nulls, then we do nothing +// * If only a single array has nulls, and its offset is a multiple of 8, +// then we can zero-copy the bitmap into the output +// * Otherwise, we allocate the bitmap and populate it +class NullPropagator { + public: + NullPropagator(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) + : ctx_(ctx), batch_(batch), output_(output) { + // At this point, the values in batch_.values must have been validated to + // all be value-like + for (const Datum& val : batch_.values) { + if (val.kind() == Datum::ARRAY) { + if (ArrayHasNulls(*val.array())) { + values_with_nulls_.push_back(&val); + } + } else if (!val.scalar()->is_valid) { + values_with_nulls_.push_back(&val); + } + } + + if (output->buffers[0] != nullptr) { + bitmap_preallocated_ = true; + SetBitmap(output_->buffers[0].get()); + } + } + + void SetBitmap(Buffer* bitmap) { bitmap_ = bitmap->mutable_data(); } + + Status EnsureAllocated() { + if (bitmap_preallocated_) { + return Status::OK(); + } + ARROW_ASSIGN_OR_RAISE(output_->buffers[0], ctx_->AllocateBitmap(output_->length)); + SetBitmap(output_->buffers[0].get()); + return Status::OK(); + } + + Result<bool> ShortCircuitIfAllNull() { + // An all-null value (scalar null or all-null array) gives us a short + // circuit opportunity + bool is_all_null = false; + std::shared_ptr<Buffer> all_null_bitmap; + + // Walk all the values with nulls instead of breaking on the first in case + // we find a bitmap that can be reused in the non-preallocated case + for (const Datum* value : values_with_nulls_) { + if (value->type()->id() == Type::NA) { + // No bitmap + is_all_null = true; + } else if (value->kind() == Datum::ARRAY) { + const ArrayData& arr = *value->array(); + if (arr.null_count.load() == arr.length) { + // Pluck the all null bitmap so we can set it in the output if it was + // not pre-allocated + all_null_bitmap = arr.buffers[0]; + is_all_null = true; + } + } else { + // Scalar + is_all_null = true; + } + } + if (!is_all_null) { + return false; + } + + // OK, the output should be all null + output_->null_count = output_->length; + + if (!bitmap_preallocated_ && all_null_bitmap) { + // If we did not pre-allocate memory, and we observed an all-null bitmap, + // then we can zero-copy it into the output + output_->buffers[0] = std::move(all_null_bitmap); + } else { + RETURN_NOT_OK(EnsureAllocated()); + BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, false); + } + return true; + } + + Status PropagateSingle() { + // One array + const ArrayData& arr = *values_with_nulls_[0]->array(); + const std::shared_ptr<Buffer>& arr_bitmap = arr.buffers[0]; + + // Reuse the null count if it's known + output_->null_count = arr.null_count.load(); + + if (bitmap_preallocated_) { + CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, output_->offset); + } else { + // Two cases when memory was not pre-allocated: + // + // * Offset is zero: we reuse the bitmap as is + // * Offset is nonzero but a multiple of 8: we can slice the bitmap + // * Offset is not a multiple of 8: we must allocate and use CopyBitmap + // + // Keep in mind that output_->offset is not permitted to be nonzero when + // the bitmap is not preallocated, and that precondition is asserted + // higher in the call stack.
+ if (arr.offset == 0) { + output_->buffers[0] = arr_bitmap; + } else if (arr.offset % 8 == 0) { + output_->buffers[0] = + SliceBuffer(arr_bitmap, arr.offset / 8, BitUtil::BytesForBits(arr.length)); + } else { + RETURN_NOT_OK(EnsureAllocated()); + CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, + /*dst_offset=*/0); + } + } + return Status::OK(); + } + + Status PropagateMultiple() { + // More than one array. We use BitmapAnd to intersect their bitmaps + + // Do not compute the intersection null count until it's needed + RETURN_NOT_OK(EnsureAllocated()); + + auto Accumulate = [&](const ArrayData& left, const ArrayData& right) { + // This is a precondition of reaching this code path + DCHECK(left.buffers[0]); + DCHECK(right.buffers[0]); + BitmapAnd(left.buffers[0]->data(), left.offset, right.buffers[0]->data(), + right.offset, output_->length, output_->offset, + output_->buffers[0]->mutable_data()); + }; + + DCHECK_GT(values_with_nulls_.size(), 1); + + // Seed the output bitmap with the & of the first two bitmaps + Accumulate(*values_with_nulls_[0]->array(), *values_with_nulls_[1]->array()); + + // Accumulate the rest + for (size_t i = 2; i < values_with_nulls_.size(); ++i) { + Accumulate(*output_, *values_with_nulls_[i]->array()); + } + return Status::OK(); + } + + Status Execute() { + bool finished = false; + ARROW_ASSIGN_OR_RAISE(finished, ShortCircuitIfAllNull()); + if (finished) { + return Status::OK(); + } + + // At this point, by construction we know that all of the values in + // values_with_nulls_ are arrays that are not all null. So there are a + // few cases: + // + // * No arrays. This is a no-op w/o preallocation but when the bitmap is + // pre-allocated we have to fill it with 1's + // * One array, whose bitmap can be zero-copied (w/o preallocation, and + // when no byte is split) or copied (split byte or w/ preallocation) + // * More than one array, we must compute the intersection of all the + // bitmaps + // + // BUT, if the output offset is nonzero for some reason, we copy into the + // output unconditionally + + output_->null_count = kUnknownNullCount; + + if (values_with_nulls_.size() == 0) { + // No arrays with nulls case + output_->null_count = 0; + if (bitmap_preallocated_) { + BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, true); + } + return Status::OK(); + } else if (values_with_nulls_.size() == 1) { + return PropagateSingle(); + } else { + return PropagateMultiple(); + } + } + + private: + KernelContext* ctx_; + const ExecBatch& batch_; + std::vector values_with_nulls_; + ArrayData* output_; + uint8_t* bitmap_; + bool bitmap_preallocated_ = false; +}; + +Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) { + DCHECK_NE(nullptr, output); + DCHECK_GT(output->buffers.size(), 0); + + if (output->type->id() == Type::NA) { + // Null output type is a no-op (rare when this would happen but we at least + // will test for it) + return Status::OK(); + } + + // This function is ONLY able to write into output with non-zero offset + // when the bitmap is preallocated. 
+
+Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) {
+  DCHECK_NE(nullptr, output);
+  DCHECK_GT(output->buffers.size(), 0);
+
+  if (output->type->id() == Type::NA) {
+    // Null output type is a no-op (rare when this would happen but we at least
+    // will test for it)
+    return Status::OK();
+  }
+
+  // This function is ONLY able to write into output with non-zero offset
+  // when the bitmap is preallocated. This could be a DCHECK but returning
+  // error Status for now for emphasis
+  if (output->offset != 0 && output->buffers[0] == nullptr) {
+    return Status::Invalid(
+        "Can only propagate nulls into pre-allocated memory "
+        "when the output offset is non-zero");
+  }
+  NullPropagator propagator(ctx, batch, output);
+  return propagator.Execute();
+}
+
+std::shared_ptr<ChunkedArray> ToChunkedArray(const std::vector<Datum>& values,
+                                             const std::shared_ptr<DataType>& type) {
+  std::vector<std::shared_ptr<Array>> arrays;
+  for (const auto& val : values) {
+    auto boxed = val.make_array();
+    if (boxed->length() == 0) {
+      // Skip empty chunks
+      continue;
+    }
+    arrays.emplace_back(std::move(boxed));
+  }
+  return std::make_shared<ChunkedArray>(arrays, type);
+}
+
+bool HaveChunkedArray(const std::vector<Datum>& values) {
+  for (const auto& value : values) {
+    if (value.kind() == Datum::CHUNKED_ARRAY) {
+      return true;
+    }
+  }
+  return false;
+}
+
+Status CheckAllValues(const std::vector<Datum>& values) {
+  for (const auto& value : values) {
+    if (!value.is_value()) {
+      return Status::Invalid("Tried executing function with non-value type: ",
+                             value.ToString());
+    }
+  }
+  return Status::OK();
+}
+
+template <typename FunctionType>
+class FunctionExecutorImpl : public FunctionExecutor {
+ public:
+  FunctionExecutorImpl(ExecContext* exec_ctx, const FunctionType* func,
+                       const FunctionOptions* options)
+      : exec_ctx_(exec_ctx), kernel_ctx_(exec_ctx), func_(func), options_(options) {}
+
+ protected:
+  using KernelType = typename FunctionType::KernelType;
+
+  void Reset() {}
+
+  Status InitState() {
+    // Some kernels require initialization of an opaque state object
+    if (kernel_->init) {
+      KernelInitArgs init_args{kernel_, input_descrs_, options_};
+      state_ = kernel_->init(&kernel_ctx_, init_args);
+      ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
+      kernel_ctx_.SetState(state_.get());
+    }
+    return Status::OK();
+  }
+
+  // This is overridden by the VectorExecutor
+  virtual Status SetupArgIteration(const std::vector<Datum>& args) {
+    ARROW_ASSIGN_OR_RAISE(batch_iterator_,
+                          ExecBatchIterator::Make(args, exec_ctx_->exec_chunksize()));
+    return Status::OK();
+  }
+
+  Status BindArgs(const std::vector<Datum>& args) {
+    RETURN_NOT_OK(GetValueDescriptors(args, &input_descrs_));
+    ARROW_ASSIGN_OR_RAISE(kernel_, func_->DispatchExact(input_descrs_));
+
+    // Initialize kernel state, since type resolution may depend on this state
+    RETURN_NOT_OK(this->InitState());
+
+    // Resolve the output descriptor for this kernel
+    ARROW_ASSIGN_OR_RAISE(output_descr_, kernel_->signature->out_type().Resolve(
+                                             &kernel_ctx_, input_descrs_));
+
+    return SetupArgIteration(args);
+  }
+
+  Result<std::shared_ptr<ArrayData>> PrepareOutput(int64_t length) {
+    auto out = std::make_shared<ArrayData>(output_descr_.type, length);
+    out->buffers.resize(output_num_buffers_);
+
+    if (validity_preallocated_) {
+      ARROW_ASSIGN_OR_RAISE(out->buffers[0], kernel_ctx_.AllocateBitmap(length));
+    }
+    if (data_preallocated_) {
+      const auto& fw_type = checked_cast<const FixedWidthType&>(*out->type);
+      ARROW_ASSIGN_OR_RAISE(
+          out->buffers[1], AllocateDataBuffer(&kernel_ctx_, length, fw_type.bit_width()));
+    }
+    return out;
+  }
+
+  ValueDescr output_descr() const override { return output_descr_; }
+
+  // Not all of these members are used for every executor type
+
+  ExecContext* exec_ctx_;
+  KernelContext kernel_ctx_;
+  const FunctionType* func_;
+  const KernelType* kernel_;
+  std::unique_ptr<ExecBatchIterator> batch_iterator_;
+  std::unique_ptr<KernelState> state_;
+  std::vector<ValueDescr> input_descrs_;
+  ValueDescr output_descr_;
+  const FunctionOptions* options_;
+
+  int output_num_buffers_;
+
+  // If true, then the kernel writes into a preallocated data buffer
+  bool data_preallocated_ = false;
+
+  // If true, then memory is preallocated for the validity bitmap with the same
+  // strategy as the data buffer(s).
+  bool validity_preallocated_ = false;
+};
+
+class ScalarExecutor : public FunctionExecutorImpl<ScalarFunction> {
+ public:
+  using FunctionType = ScalarFunction;
+  static constexpr Function::Kind function_kind = Function::SCALAR;
+  using BASE = FunctionExecutorImpl<ScalarFunction>;
+  using BASE::BASE;
+
+  Status Execute(const std::vector<Datum>& args, ExecListener* listener) override {
+    RETURN_NOT_OK(PrepareExecute(args));
+    ExecBatch batch;
+    while (batch_iterator_->Next(&batch)) {
+      RETURN_NOT_OK(ExecuteBatch(batch, listener));
+    }
+    if (preallocate_contiguous_) {
+      // If we preallocated one big chunk, since the kernel execution is
+      // completed, we can now emit it
+      RETURN_NOT_OK(listener->OnResult(std::move(preallocated_)));
+    }
+    return Status::OK();
+  }
+
+  Datum WrapResults(const std::vector<Datum>& inputs,
+                    const std::vector<Datum>& outputs) override {
+    if (output_descr_.shape == ValueDescr::SCALAR) {
+      DCHECK_GT(outputs.size(), 0);
+      if (outputs.size() == 1) {
+        // Return as SCALAR
+        return outputs[0];
+      } else {
+        // Return as COLLECTION
+        return outputs;
+      }
+    } else {
+      // If execution yielded multiple chunks (because large arrays were split
+      // based on the ExecContext parameters), then the result is a ChunkedArray
+      if (HaveChunkedArray(inputs) || outputs.size() > 1) {
+        return ToChunkedArray(outputs, output_descr_.type);
+      } else if (outputs.size() == 1) {
+        // Outputs have just one element
+        return outputs[0];
+      } else {
+        // XXX: In the case where no outputs are emitted, is returning a
+        // 0-length array always the correct move?
+        return MakeArrayOfNull(output_descr_.type, /*length=*/0).ValueOrDie();
+      }
+    }
+  }
+
+ protected:
+  Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) {
+    Datum out;
+    RETURN_NOT_OK(PrepareNextOutput(batch, &out));
+
+    if (kernel_->null_handling == NullHandling::INTERSECTION &&
+        output_descr_.shape == ValueDescr::ARRAY) {
+      RETURN_NOT_OK(PropagateNulls(&kernel_ctx_, batch, out.mutable_array()));
+    }
+
+    kernel_->exec(&kernel_ctx_, batch, &out);
+    ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
+    if (!preallocate_contiguous_) {
+      // If we are producing chunked output rather than one big array, then
+      // emit each chunk as soon as it's available
+      RETURN_NOT_OK(listener->OnResult(std::move(out)));
+    }
+    return Status::OK();
+  }
+
+  Status PrepareExecute(const std::vector<Datum>& args) {
+    this->Reset();
+    RETURN_NOT_OK(this->BindArgs(args));
+
+    if (output_descr_.shape == ValueDescr::ARRAY) {
+      // If the executor is configured to produce a single large Array output for
+      // kernels supporting preallocation, then we do so up front and then
+      // iterate over slices of that large array. Otherwise, we preallocate prior
+      // to processing each batch emitted from the ExecBatchIterator
+      RETURN_NOT_OK(SetupPreallocation(batch_iterator_->length()));
+    }
+    return Status::OK();
+  }
+
+  // We must accommodate two different modes of execution for preallocated
+  // execution
+  //
+  // * A single large ("contiguous") allocation that we populate with results
+  //   on a chunkwise basis according to the ExecBatchIterator. This permits
+  //   parallelization even if the objective is to obtain a single Array or
+  //   ChunkedArray at the end
+  // * A standalone buffer preallocation for each chunk emitted from the
+  //   ExecBatchIterator
+  //
+  // When data buffer preallocation is not possible (e.g. with BINARY / STRING
+  // outputs), then contiguous results are only possible if the input is
+  // contiguous.
+
+  Status PrepareNextOutput(const ExecBatch& batch, Datum* out) {
+    if (output_descr_.shape == ValueDescr::ARRAY) {
+      if (preallocate_contiguous_) {
+        // The output is already fully preallocated
+        const int64_t batch_start_position = batch_iterator_->position() - batch.length;
+
+        if (batch.length < batch_iterator_->length()) {
+          // If this is a partial execution, then we write into a slice of
+          // preallocated_
+          //
+          // XXX: ArrayData::Slice not returning std::shared_ptr<ArrayData> is
+          // a nuisance
+          out->value = std::make_shared<ArrayData>(
+              preallocated_->Slice(batch_start_position, batch.length));
+        } else {
+          // Otherwise write directly into preallocated_. The main difference
+          // computationally (versus the Slice approach) is that the null_count
+          // may not need to be recomputed in the result
+          out->value = preallocated_;
+        }
+      } else {
+        // We preallocate (maybe) only for the output of processing the current
+        // batch
+        ARROW_ASSIGN_OR_RAISE(out->value, PrepareOutput(batch.length));
+      }
+    } else {
+      // For scalar outputs, we set a null scalar of the correct type to
+      // communicate the output type to the kernel if needed
+      //
+      // XXX: Is there some way to avoid this step?
+      out->value = MakeNullScalar(output_descr_.type);
+    }
+    return Status::OK();
+  }
+
+  Status SetupPreallocation(int64_t total_length) {
+    output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size());
+
+    // Decide if we need to preallocate memory for this kernel
+    data_preallocated_ = ((kernel_->mem_allocation == MemAllocation::PREALLOCATE) &&
+                          CanPreallocate(*output_descr_.type));
+    validity_preallocated_ =
+        (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE &&
+         kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL);
+
+    // Contiguous preallocation only possible if both the VALIDITY and DATA can
+    // be preallocated. Otherwise, we must go chunk-by-chunk. Note that when
+    // the DATA cannot be preallocated, the VALIDITY may still be preallocated
+    // depending on the NullHandling of the kernel
+    //
+    // Some kernels are unable to write into sliced outputs, so we respect the
+    // kernel's attributes
+    preallocate_contiguous_ =
+        (exec_ctx_->preallocate_contiguous() && kernel_->can_write_into_slices &&
+         data_preallocated_ && validity_preallocated_);
+    if (preallocate_contiguous_) {
+      DCHECK_EQ(2, output_num_buffers_);
+      ARROW_ASSIGN_OR_RAISE(preallocated_, PrepareOutput(total_length));
+    }
+    return Status::OK();
+  }
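+
+  // Worked example (editor's note, numbers hypothetical): for an int32 output
+  // of total_length 1000 with exec_chunksize 400, contiguous preallocation
+  // makes one 1000-element allocation and the kernel writes into slices
+  // [0, 400), [400, 800), [800, 1000); with preallocate_contiguous() false,
+  // three independent 400/400/200-element buffers are allocated instead.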
+
+  // If true, and the kernel and output type support preallocation (for both
+  // the validity and data buffers), then we allocate one big array and then
+  // iterate through it while executing the kernel in chunks
+  bool preallocate_contiguous_ = false;
+
+  // For storing a contiguous preallocation per above. Unused otherwise
+  std::shared_ptr<ArrayData> preallocated_;
+};
+
+Status PackBatchNoChunks(const std::vector<Datum>& args, ExecBatch* out) {
+  int64_t length = 0;
+  for (size_t i = 0; i < args.size(); ++i) {
+    switch (args[i].kind()) {
+      case Datum::SCALAR:
+      case Datum::ARRAY:
+        length = std::max(args[i].length(), length);
+        break;
+      case Datum::CHUNKED_ARRAY:
+        return Status::Invalid("Kernel does not support chunked array arguments");
+      default:
+        DCHECK(false);
+        break;
+    }
+  }
+  out->length = length;
+  out->values = args;
+  return Status::OK();
}
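+
+// For example (editor's note, hypothetical values): packing {Int32Scalar(5),
+// an int32 array of length 100} yields a single ExecBatch of length 100;
+// scalar Datums report length 1 and are broadcast by the kernel itself.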
+
+class VectorExecutor : public FunctionExecutorImpl<VectorFunction> {
+ public:
+  using FunctionType = VectorFunction;
+  static constexpr Function::Kind function_kind = Function::VECTOR;
+  using BASE = FunctionExecutorImpl<VectorFunction>;
+  using BASE::BASE;
+
+  Status Execute(const std::vector<Datum>& args, ExecListener* listener) override {
+    RETURN_NOT_OK(PrepareExecute(args));
+    ExecBatch batch;
+    if (kernel_->can_execute_chunkwise) {
+      while (batch_iterator_->Next(&batch)) {
+        RETURN_NOT_OK(ExecuteBatch(batch, listener));
+      }
+    } else {
+      RETURN_NOT_OK(PackBatchNoChunks(args, &batch));
+      RETURN_NOT_OK(ExecuteBatch(batch, listener));
+    }
+    return Finalize(listener);
+  }
+
+  Datum WrapResults(const std::vector<Datum>& inputs,
+                    const std::vector<Datum>& outputs) override {
+    // If execution yielded multiple chunks (because large arrays were split
+    // based on the ExecContext parameters), then the result is a ChunkedArray
+    if (kernel_->output_chunked) {
+      if (HaveChunkedArray(inputs) || outputs.size() > 1) {
+        return ToChunkedArray(outputs, output_descr_.type);
+      } else if (outputs.size() == 1) {
+        // Outputs have just one element
+        return outputs[0];
+      } else {
+        // XXX: In the case where no outputs are emitted, is returning a
+        // 0-length array always the correct move?
+        return MakeArrayOfNull(output_descr_.type, /*length=*/0).ValueOrDie();
+      }
+    } else {
+      return outputs[0];
+    }
+  }
+
+ protected:
+  Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) {
+    if (batch.length == 0) {
+      // Skip empty batches. This may only happen when not using
+      // ExecBatchIterator
+      return Status::OK();
+    }
+    Datum out;
+    if (output_descr_.shape == ValueDescr::ARRAY) {
+      // We preallocate (maybe) only for the output of processing the current
+      // batch
+      ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(batch.length));
+    }
+
+    if (kernel_->null_handling == NullHandling::INTERSECTION &&
+        output_descr_.shape == ValueDescr::ARRAY) {
+      RETURN_NOT_OK(PropagateNulls(&kernel_ctx_, batch, out.mutable_array()));
+    }
+    kernel_->exec(&kernel_ctx_, batch, &out);
+    ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
+    if (!kernel_->finalize) {
+      // If there is no result finalizer (e.g. for hash-based functions), we
+      // can emit the processed batch right away rather than waiting
+      RETURN_NOT_OK(listener->OnResult(std::move(out)));
+    } else {
+      results_.emplace_back(std::move(out));
+    }
+    return Status::OK();
+  }
+
+  Status Finalize(ExecListener* listener) {
+    if (kernel_->finalize) {
+      // Intermediate results require post-processing after the execution is
+      // completed (possibly involving some accumulated state)
+      kernel_->finalize(&kernel_ctx_, &results_);
+      ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
+      for (const auto& result : results_) {
+        RETURN_NOT_OK(listener->OnResult(result));
+      }
+    }
+    return Status::OK();
+  }
+
+  Status SetupArgIteration(const std::vector<Datum>& args) override {
+    if (kernel_->can_execute_chunkwise) {
+      ARROW_ASSIGN_OR_RAISE(batch_iterator_,
+                            ExecBatchIterator::Make(args, exec_ctx_->exec_chunksize()));
+    }
+    return Status::OK();
+  }
+
+  Status PrepareExecute(const std::vector<Datum>& args) {
+    this->Reset();
+    RETURN_NOT_OK(this->BindArgs(args));
+    output_num_buffers_ = static_cast<int>(output_descr_.type->layout().buffers.size());
+
+    // Decide if we need to preallocate memory for this kernel
+    data_preallocated_ = ((kernel_->mem_allocation == MemAllocation::PREALLOCATE) &&
+                          CanPreallocate(*output_descr_.type));
+    validity_preallocated_ =
+        (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE &&
+         kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL);
+    return Status::OK();
+  }
+
+  std::vector<Datum> results_;
+};
+
+class ScalarAggExecutor : public FunctionExecutorImpl<ScalarAggregateFunction> {
+ public:
+  using FunctionType = ScalarAggregateFunction;
+  static constexpr Function::Kind function_kind = Function::SCALAR_AGGREGATE;
+  using BASE = FunctionExecutorImpl<ScalarAggregateFunction>;
+  using BASE::BASE;
+
+  Status Execute(const std::vector<Datum>& args, ExecListener* listener) override {
+    RETURN_NOT_OK(BindArgs(args));
+
+    ExecBatch batch;
+    while (batch_iterator_->Next(&batch)) {
+      // TODO: implement parallelism
+      if (batch.length > 0) {
+        RETURN_NOT_OK(Consume(batch));
+      }
+    }
+
+    Datum out;
+    kernel_->finalize(&kernel_ctx_, &out);
+    ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
+    RETURN_NOT_OK(listener->OnResult(std::move(out)));
+    return Status::OK();
+  }
+
+  Datum WrapResults(const std::vector<Datum>&,
+                    const std::vector<Datum>& outputs) override {
+    DCHECK_EQ(1, outputs.size());
+    return outputs[0];
+  }
+
+ private:
+  Status Consume(const ExecBatch& batch) {
+    KernelInitArgs init_args{kernel_, input_descrs_, options_};
+    auto batch_state = kernel_->init(&kernel_ctx_, init_args);
+    ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
+
+    KernelContext batch_ctx(exec_ctx_);
+    batch_ctx.SetState(batch_state.get());
+
+    kernel_->consume(&batch_ctx, batch);
+    ARROW_CTX_RETURN_IF_ERROR(&batch_ctx);
+
+    kernel_->merge(&kernel_ctx_, *batch_state, state_.get());
+    ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
+    return Status::OK();
+  }
+};
+
+template <typename ExecutorType>
+Result<std::unique_ptr<FunctionExecutor>> MakeExecutor(ExecContext* ctx,
+                                                       const Function* func,
+                                                       const FunctionOptions* options) {
+  DCHECK_EQ(ExecutorType::function_kind, func->kind());
+  auto typed_func = checked_cast<const typename ExecutorType::FunctionType*>(func);
+  return std::unique_ptr<FunctionExecutor>(new ExecutorType(ctx, typed_func, options));
+}
+
+Result<std::unique_ptr<FunctionExecutor>> FunctionExecutor::Make(
+    ExecContext* ctx, const Function* func, const FunctionOptions* options) {
+  switch (func->kind()) {
+    case Function::SCALAR:
+      return MakeExecutor<ScalarExecutor>(ctx, func, options);
+    case Function::VECTOR:
+      return MakeExecutor<VectorExecutor>(ctx, func, options);
+    case Function::SCALAR_AGGREGATE:
+      return MakeExecutor<ScalarAggExecutor>(ctx, func, options);
+    default:
+      DCHECK(false);
+      return nullptr;
+  }
+}
+
+} // namespace detail
+
+ExecContext::ExecContext(MemoryPool* pool, FunctionRegistry* func_registry)
+    : pool_(pool) {
+  this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry;
+}
+
+CpuInfo* ExecContext::cpu_info() const { return CpuInfo::GetInstance(); }
+
+// ----------------------------------------------------------------------
+// SelectionVector
+
+SelectionVector::SelectionVector(std::shared_ptr<ArrayData> data)
+    : data_(std::move(data)) {
+  DCHECK_EQ(Type::INT32, data_->type->id());
+  DCHECK_EQ(0, data_->GetNullCount());
+  indices_ = data_->GetValues<int32_t>(1);
+}
+
+SelectionVector::SelectionVector(const Array& arr) : SelectionVector(arr.data()) {}
+
+int32_t SelectionVector::length() const { return static_cast<int32_t>(data_->length); }
+
+Result<std::shared_ptr<SelectionVector>> SelectionVector::FromMask(
+    const BooleanArray& arr) {
+  return Status::NotImplemented("FromMask");
+}
+
+Result<Datum> CallFunction(ExecContext* ctx, const std::string& func_name,
+                           const std::vector<Datum>& args,
+                           const FunctionOptions* options) {
+  if (ctx == nullptr) {
+    ExecContext default_ctx;
+    return CallFunction(&default_ctx, func_name, args, options);
+  }
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Function> func,
+                        ctx->func_registry()->GetFunction(func_name));
+  return func->Execute(args, options, ctx);
+}
+
+} // namespace compute
+} // namespace arrow
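
A minimal caller-side sketch of the one-shot API defined above (not part of the patch). The function name "my_function" is hypothetical; any name registered in the default registry works the same way, and error handling is reduced to Status propagation.

// Editor's sketch: invoking a compute function through CallFunction.
#include "arrow/compute/exec.h"

arrow::Status RunExample(const std::shared_ptr<arrow::Array>& input) {
  arrow::compute::ExecContext ctx;
  ctx.set_exec_chunksize(1 << 16);  // split large inputs into 64K-row batches
  ARROW_ASSIGN_OR_RAISE(
      arrow::Datum result,
      arrow::compute::CallFunction(&ctx, "my_function", {arrow::Datum(input)}));
  // result.kind() is ARRAY or CHUNKED_ARRAY depending on how execution split
  return arrow::Status::OK();
}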
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
new file mode 100644
index 00000000000..d6ba48db366
--- /dev/null
+++ b/cpp/src/arrow/compute/exec.h
@@ -0,0 +1,183 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// NOTE: API is EXPERIMENTAL and will change without going through a
+// deprecation cycle
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/datum.h"
+#include "arrow/result.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+class Array;
+struct ArrayData;
+class MemoryPool;
+
+namespace internal {
+
+class CpuInfo;
+
+} // namespace internal
+
+namespace compute {
+
+struct FunctionOptions;
+class FunctionRegistry;
+
+// It seems like 64K might be a good default chunksize to use for execution
+// based on the experience of other query processing systems. The current
+// default is not to chunk contiguous arrays, though, but this may change in
+// the future once parallel execution is implemented
+static constexpr int64_t kDefaultExecChunksize = UINT16_MAX;
+
+/// \brief Context for expression-global variables and options used by
+/// function evaluation
+class ARROW_EXPORT ExecContext {
+ public:
+  // If no function registry passed, the default is used
+  explicit ExecContext(MemoryPool* pool = default_memory_pool(),
+                       FunctionRegistry* func_registry = NULLPTR);
+
+  MemoryPool* memory_pool() const { return pool_; }
+
+  ::arrow::internal::CpuInfo* cpu_info() const;
+
+  FunctionRegistry* func_registry() const { return func_registry_; }
+
+  /// \brief Set the maximum length unit of work for kernel execution. Larger
+  /// inputs will be split into smaller chunks, and, if desired, processed in
+  /// parallel. Set to -1 for no limit
+  void set_exec_chunksize(int64_t chunksize) { exec_chunksize_ = chunksize; }
+
+  /// \brief Maximum length unit of work for kernel execution.
+  int64_t exec_chunksize() const { return exec_chunksize_; }
+
+  /// \brief Set whether to use multiple threads for function execution
+  void set_use_threads(bool use_threads = true) { use_threads_ = use_threads; }
+
+  /// \brief If true, then utilize multiple threads where relevant for function
+  /// execution
+  bool use_threads() const { return use_threads_; }
+
+  // Set the preallocation strategy for kernel execution as it relates to
+  // chunked execution. For chunked execution, whether via ChunkedArray inputs
+  // or splitting larger Array arguments into smaller pieces, contiguous
+  // allocation (if permitted by the kernel) will allocate one large array to
+  // write output into, yielding it to the caller at the end. If this option is
+  // set to false, then preallocations will be performed independently for each
+  // chunk of execution
+  //
+  // TODO: At some point we might want to limit the size of contiguous
+  // preallocations (for example, merging small ChunkedArray chunks until
+  // reaching some desired size)
+  void set_preallocate_contiguous(bool preallocate) {
+    preallocate_contiguous_ = preallocate;
+  }
+
+  bool preallocate_contiguous() const { return preallocate_contiguous_; }
+
+ private:
+  MemoryPool* pool_;
+  FunctionRegistry* func_registry_;
+  int64_t exec_chunksize_ = std::numeric_limits<int64_t>::max();
+  bool preallocate_contiguous_ = true;
+  bool use_threads_ = true;
+};
+
+// TODO: Consider standardizing on uint16 selection vectors and only use them
+// when we can ensure that each value is 64K length or smaller
+
+/// \brief Container for an array of value selection indices that were
+/// materialized from a filter.
+///
+/// Columnar query engines (see e.g. [1]) have found that rather than
+/// materializing filtered data, the filter can instead be converted to an
+/// array of the "on" indices and then "fusing" these indices in operator
+/// implementations. This is especially relevant for aggregations but also
+/// applies to scalar operations.
+///
+/// We are not yet using this so this is mostly a placeholder for now.
+///
+/// [1]: http://cidrdb.org/cidr2005/papers/P19.pdf
+class ARROW_EXPORT SelectionVector {
+ public:
+  explicit SelectionVector(std::shared_ptr<ArrayData> data);
+
+  explicit SelectionVector(const Array& arr);
+
+  /// \brief Create SelectionVector from boolean mask
+  static Result<std::shared_ptr<SelectionVector>> FromMask(const BooleanArray& arr);
+
+  const int32_t* indices() const { return indices_; }
+  int32_t length() const;
+
+ private:
+  std::shared_ptr<ArrayData> data_;
+  const int32_t* indices_;
+};
+
+/// \brief A unit of work for kernel execution. It contains a collection of
+/// Array and Scalar values and an optional SelectionVector indicating that
+/// there is an unmaterialized filter that either must be materialized, or (if
+/// the kernel supports it) pushed down into the kernel implementation.
+struct ExecBatch {
+  ExecBatch() {}
+  ExecBatch(std::vector<Datum> values, int64_t length)
+      : values(std::move(values)), length(length) {}
+
+  std::vector<Datum> values;
+  std::shared_ptr<SelectionVector> selection_vector;
+  int64_t length;
+
+  template <typename index_type>
+  inline const Datum& operator[](index_type i) const {
+    return values[i];
+  }
+
+  int num_values() const { return static_cast<int>(values.size()); }
+
+  std::vector<ValueDescr> GetDescriptors() const {
+    std::vector<ValueDescr> result;
+    for (const auto& value : this->values) {
+      result.emplace_back(value.descr());
+    }
+    return result;
+  }
+};
+
+/// \brief One-shot invoker for all types of functions. Does kernel dispatch,
+/// argument checking, iteration of ChunkedArray inputs, and wrapping of
+/// outputs
+ARROW_EXPORT
+Result<Datum> CallFunction(ExecContext* ctx, const std::string& func_name,
+                           const std::vector<Datum>& args,
+                           const FunctionOptions* options = NULLPTR);
+
+} // namespace compute
+} // namespace arrow
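
A small sketch of assembling an ExecBatch by hand, as the executors above do internally (not part of the patch; the values and names are hypothetical). The batch length matches the array argument, and scalar values are broadcast by kernels.

// Editor's sketch: building an ExecBatch directly from the exec.h API.
#include "arrow/compute/exec.h"
#include "arrow/scalar.h"

arrow::compute::ExecBatch MakeExampleBatch(
    const std::shared_ptr<arrow::Array>& column) {
  std::vector<arrow::Datum> values = {
      arrow::Datum(column),
      arrow::Datum(std::make_shared<arrow::Int64Scalar>(42))};
  return arrow::compute::ExecBatch(std::move(values), column->length());
}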
diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h
new file mode 100644
index 00000000000..507cd1703a8
--- /dev/null
+++ b/cpp/src/arrow/compute/exec_internal.h
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/status.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+class Function;
+
+static constexpr int64_t kDefaultMaxChunksize = std::numeric_limits<int64_t>::max();
+
+namespace detail {
+
+/// \brief Break std::vector<Datum> into a sequence of ExecBatch for kernel
+/// execution
+class ARROW_EXPORT ExecBatchIterator {
+ public:
+  /// \brief Construct iterator and do basic argument validation
+  ///
+  /// \param[in] args the Datum arguments; must all be array-like or scalar
+  /// \param[in] max_chunksize the maximum length of each ExecBatch. Depending
+  /// on the chunk layout of ChunkedArray arguments, the batches produced may
+  /// be smaller
+  static Result<std::unique_ptr<ExecBatchIterator>> Make(
+      std::vector<Datum> args, int64_t max_chunksize = kDefaultMaxChunksize);
+
+  /// \brief Compute the next batch. Always returns at least one batch. Return
+  /// false if the iterator is exhausted
+  bool Next(ExecBatch* batch);
+
+  int64_t length() const { return length_; }
+
+  int64_t position() const { return position_; }
+
+  int64_t max_chunksize() const { return max_chunksize_; }
+
+ private:
+  ExecBatchIterator(std::vector<Datum> args, int64_t length, int64_t max_chunksize);
+
+  std::vector<Datum> args_;
+  std::vector<int> chunk_indexes_;
+  std::vector<int64_t> chunk_positions_;
+  int64_t position_;
+  int64_t length_;
+  int64_t max_chunksize_;
+};
+
+// "Push" / listener API like IPC reader so that consumers can receive
+// processed chunks as soon as they're available.
+
+class ARROW_EXPORT ExecListener {
+ public:
+  virtual ~ExecListener() = default;
+
+  virtual Status OnResult(Datum) { return Status::NotImplemented("OnResult"); }
+};
+
+class DatumAccumulator : public ExecListener {
+ public:
+  DatumAccumulator() {}
+
+  Status OnResult(Datum value) override {
+    values_.emplace_back(value);
+    return Status::OK();
+  }
+
+  std::vector<Datum> values() const { return values_; }
+
+ private:
+  std::vector<Datum> values_;
+};
+
+/// \brief Check that each Datum is of a "value" type, which means either
+/// SCALAR, ARRAY, or CHUNKED_ARRAY. If there are chunked inputs, then these
+/// inputs will be split into non-chunked ExecBatch values for execution
+Status CheckAllValues(const std::vector<Datum>& values);
+
+class ARROW_EXPORT FunctionExecutor {
+ public:
+  virtual ~FunctionExecutor() = default;
+
+  /// XXX: Better configurability for listener
+  /// Not thread-safe
+  virtual Status Execute(const std::vector<Datum>& args, ExecListener* listener) = 0;
+
+  virtual ValueDescr output_descr() const = 0;
+
+  virtual Datum WrapResults(const std::vector<Datum>& args,
+                            const std::vector<Datum>& outputs) = 0;
+
+  static Result<std::unique_ptr<FunctionExecutor>> Make(ExecContext* ctx,
+                                                        const Function* func,
+                                                        const FunctionOptions* options);
+};
+
+/// \brief Populate validity bitmap with the intersection of the nullity of the
+/// arguments. If a preallocated bitmap is not provided, then one will be
+/// allocated if needed (in some cases a bitmap can be zero-copied from the
+/// arguments). If any Scalar value is null, then the entire validity bitmap
+/// will be set to null.
+///
+/// \param[in] ctx kernel execution context, for memory allocation etc.
+/// \param[in] batch the data batch
+/// \param[in] out the output ArrayData, must not be null
+ARROW_EXPORT
+Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out);
+
+} // namespace detail
+} // namespace compute
+} // namespace arrow
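
A minimal sketch of driving ExecBatchIterator by hand under the API declared above (not part of the patch; the 1024-row chunksize is an arbitrary example value).

// Editor's sketch: iterating kernel-sized batches over mixed arguments.
#include "arrow/compute/exec_internal.h"

arrow::Status VisitBatches(std::vector<arrow::Datum> args) {
  ARROW_ASSIGN_OR_RAISE(auto it, arrow::compute::detail::ExecBatchIterator::Make(
                                     std::move(args), /*max_chunksize=*/1024));
  arrow::compute::ExecBatch batch;
  while (it->Next(&batch)) {
    // Each batch holds at most 1024 rows; scalar args repeat in every batch
  }
  return arrow::Status::OK();
}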
diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc
new file mode 100644
index 00000000000..933c260344e
--- /dev/null
+++ b/cpp/src/arrow/compute/exec_test.cc
@@ -0,0 +1,841 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec_internal.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/registry.h"
+#include "arrow/memory_pool.h"
+#include "arrow/scalar.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+namespace detail {
+
+TEST(ExecContext, BasicWorkings) {
+  {
+    ExecContext ctx;
+    ASSERT_EQ(GetFunctionRegistry(), ctx.func_registry());
+    ASSERT_EQ(default_memory_pool(), ctx.memory_pool());
+    ASSERT_EQ(std::numeric_limits<int64_t>::max(), ctx.exec_chunksize());
+
+    ASSERT_TRUE(ctx.use_threads());
+    ASSERT_EQ(internal::CpuInfo::GetInstance(), ctx.cpu_info());
+  }
+
+  // Now, let's customize all the things
+  LoggingMemoryPool my_pool(default_memory_pool());
+  std::unique_ptr<FunctionRegistry> custom_reg = FunctionRegistry::Make();
+  ExecContext ctx(&my_pool, custom_reg.get());
+
+  ASSERT_EQ(custom_reg.get(), ctx.func_registry());
+  ASSERT_EQ(&my_pool, ctx.memory_pool());
+
+  ctx.set_exec_chunksize(1 << 20);
+  ASSERT_EQ(1 << 20, ctx.exec_chunksize());
+
+  ctx.set_use_threads(false);
+  ASSERT_FALSE(ctx.use_threads());
+}
+
+TEST(SelectionVector, Basics) {
+  auto indices = ArrayFromJSON(int32(), "[0, 3]");
+  auto sel_vector = std::make_shared<SelectionVector>(*indices);
+
+  ASSERT_EQ(indices->length(), sel_vector->length());
+  ASSERT_EQ(3, sel_vector->indices()[1]);
+}
+
+void AssertValidityZeroExtraBits(const ArrayData& arr) {
+  const Buffer& buf = *arr.buffers[0];
+
+  const int64_t bit_extent = ((arr.offset + arr.length + 7) / 8) * 8;
+  for (int64_t i = arr.offset + arr.length; i < bit_extent; ++i) {
+    EXPECT_FALSE(BitUtil::GetBit(buf.data(), i)) << i;
+  }
+}
+
+class TestComputeInternals : public ::testing::Test {
+ public:
+  void SetUp() {
+    registry_ = FunctionRegistry::Make();
+    rng_.reset(new random::RandomArrayGenerator(/*seed=*/0));
+    ResetContexts();
+  }
+
+  void ResetContexts() {
+    exec_ctx_.reset(new ExecContext(default_memory_pool(), registry_.get()));
+    ctx_.reset(new KernelContext(exec_ctx_.get()));
+  }
+
+  std::shared_ptr<Array> GetUInt8Array(int64_t size, double null_probability = 0.1) {
+    return rng_->UInt8(size, /*min=*/0, /*max=*/100, null_probability);
+  }
+
+  std::shared_ptr<Array> GetInt32Array(int64_t size, double null_probability = 0.1) {
+    return rng_->Int32(size, /*min=*/0, /*max=*/1000, null_probability);
+  }
+
+  std::shared_ptr<Array> GetFloat64Array(int64_t size, double null_probability = 0.1) {
+    return rng_->Float64(size, /*min=*/0, /*max=*/1000, null_probability);
+  }
+
+  std::shared_ptr<ChunkedArray> GetInt32Chunked(const std::vector<int>& sizes) {
+    std::vector<std::shared_ptr<Array>> chunks;
+    for (auto size : sizes) {
+      chunks.push_back(GetInt32Array(size));
+    }
+    return std::make_shared<ChunkedArray>(std::move(chunks));
+  }
+
+ protected:
+  std::unique_ptr<ExecContext> exec_ctx_;
+  std::unique_ptr<KernelContext> ctx_;
+  std::unique_ptr<FunctionRegistry> registry_;
+  std::unique_ptr<random::RandomArrayGenerator> rng_;
+};
+
+class TestPropagateNulls : public TestComputeInternals {};
+
+TEST_F(TestPropagateNulls, UnknownNullCountWithNullsZeroCopies) {
+  const int64_t length = 16;
+
+  constexpr uint8_t validity_bitmap[8] = {254, 0, 0, 0, 0, 0, 0, 0};
+  auto nulls = std::make_shared<Buffer>(validity_bitmap, 8);
+
+  ArrayData output(boolean(), length, {nullptr, nullptr});
+  ArrayData input(boolean(), length, {nulls, nullptr}, kUnknownNullCount);
+
+  ExecBatch batch({input}, length);
+  ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+  ASSERT_EQ(nulls.get(), output.buffers[0].get());
+  ASSERT_EQ(kUnknownNullCount, output.null_count);
+  ASSERT_EQ(9, output.GetNullCount());
+}
+
+TEST_F(TestPropagateNulls, UnknownNullCountWithoutNulls) {
+  const int64_t length = 16;
+  constexpr uint8_t validity_bitmap[8] = {255, 255, 0, 0, 0, 0, 0, 0};
+  auto nulls = std::make_shared<Buffer>(validity_bitmap, 8);
+
+  ArrayData output(boolean(), length, {nullptr, nullptr});
+  ArrayData input(boolean(), length, {nulls, nullptr}, kUnknownNullCount);
+
+  ExecBatch batch({input}, length);
+  ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+  EXPECT_EQ(-1, output.null_count);
+  EXPECT_EQ(nulls.get(), output.buffers[0].get());
+}
+
+TEST_F(TestPropagateNulls, SetAllNulls) {
+  const int64_t length = 16;
+
+  auto CheckSetAllNull = [&](std::vector<Datum> values, bool preallocate) {
+    // Make fresh bitmap with all 1's
+    uint8_t bitmap_data[2] = {255, 255};
+    auto preallocated_mem = std::make_shared<MutableBuffer>(bitmap_data, 2);
+
+    std::vector<std::shared_ptr<Buffer>> buffers(2);
+    if (preallocate) {
+      buffers[0] = preallocated_mem;
+    }
+
+    ArrayData output(boolean(), length, buffers);
+
+    ExecBatch batch(values, length);
+    ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+    if (preallocate) {
+      // Ensure that the buffer object is the same when we pass in preallocated
+      // memory
+      ASSERT_EQ(preallocated_mem.get(), output.buffers[0].get());
+    }
+    ASSERT_NE(nullptr, output.buffers[0]);
+    uint8_t expected[2] = {0, 0};
+    const Buffer& out_buf = *output.buffers[0];
+    ASSERT_EQ(0, std::memcmp(out_buf.data(), expected, out_buf.size()));
+  };
+
+  // There is a null scalar
+  std::shared_ptr<Scalar> i32_val = std::make_shared<Int32Scalar>(3);
+  std::vector<Datum> vals = {i32_val, MakeNullScalar(boolean())};
+  CheckSetAllNull(vals, true);
+  CheckSetAllNull(vals, false);
+
+  const double true_prob = 0.5;
+
+  vals[0] = rng_->Boolean(length, true_prob);
+  CheckSetAllNull(vals, true);
+  CheckSetAllNull(vals, false);
+
+  auto arr_all_nulls = rng_->Boolean(length, true_prob, /*null_probability=*/1);
+
+  // One value is all null
+  vals = {rng_->Boolean(length, true_prob, /*null_probability=*/0.5), arr_all_nulls};
+  CheckSetAllNull(vals, true);
+  CheckSetAllNull(vals, false);
+
+  // A value is NullType
+  std::shared_ptr<Array> null_arr = std::make_shared<NullArray>(length);
+  vals = {rng_->Boolean(length, true_prob), null_arr};
+  CheckSetAllNull(vals, true);
+  CheckSetAllNull(vals, false);
+
+  // Other nitty-gritty scenarios
+  {
+    // An all-null bitmap is zero-copied over, even though there is a
+    // null-scalar earlier in the batch
+    ArrayData output(boolean(), length, {nullptr, nullptr});
+    ExecBatch batch({MakeNullScalar(boolean()), arr_all_nulls}, length);
+    ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+    ASSERT_EQ(arr_all_nulls->data()->buffers[0].get(), output.buffers[0].get());
+  }
+}
+
+TEST_F(TestPropagateNulls, SingleValueWithNulls) {
+  // Input offset is non-zero (0 mod 8 and nonzero mod 8 cases)
+  const int64_t length = 100;
+  auto arr = rng_->Boolean(length, 0.5, /*null_probability=*/0.5);
+
+  auto CheckSliced = [&](int64_t offset, bool preallocate = false,
+                         int64_t out_offset = 0) {
+    // Unaligned bitmap, zero copy not possible
+    auto sliced = arr->Slice(offset);
+    std::vector<Datum> vals = {sliced};
+
+    ArrayData output(boolean(), vals[0].length(), {nullptr, nullptr});
+    output.offset = out_offset;
+
+    ExecBatch batch(vals, vals[0].length());
+
+    std::shared_ptr<Buffer> preallocated_bitmap;
+    if (preallocate) {
+      ASSERT_OK_AND_ASSIGN(
+          preallocated_bitmap,
+          AllocateBuffer(BitUtil::BytesForBits(sliced->length() + out_offset)));
+      std::memset(preallocated_bitmap->mutable_data(), 0, preallocated_bitmap->size());
+      output.buffers[0] = preallocated_bitmap;
+    } else {
+      ASSERT_EQ(0, output.offset);
+    }
+
+    ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+    if (!preallocate) {
+      const Buffer* parent_buf = arr->data()->buffers[0].get();
+      if (offset == 0) {
+        // Validity bitmap same, no slice
+        ASSERT_EQ(parent_buf, output.buffers[0].get());
+      } else if (offset % 8 == 0) {
+        // Validity bitmap sliced
+        ASSERT_NE(parent_buf, output.buffers[0].get());
+        ASSERT_EQ(parent_buf, output.buffers[0]->parent().get());
+      } else {
+        // New memory for offset not 0 mod 8
+        ASSERT_NE(parent_buf, output.buffers[0].get());
+        ASSERT_EQ(nullptr, output.buffers[0]->parent());
+      }
+    } else {
+      // preallocated, so check that the validity bitmap is unbothered
+      ASSERT_EQ(preallocated_bitmap.get(), output.buffers[0].get());
+    }
+
+    ASSERT_EQ(arr->Slice(offset)->null_count(), output.GetNullCount());
+
+    ASSERT_TRUE(internal::BitmapEquals(output.buffers[0]->data(), output.offset,
+                                       sliced->null_bitmap_data(), sliced->offset(),
+                                       output.length));
+    AssertValidityZeroExtraBits(output);
+  };
+
+  CheckSliced(8);
+  CheckSliced(7);
+  CheckSliced(8, /*preallocate=*/true);
+  CheckSliced(7, true);
+  CheckSliced(8, true, /*out_offset=*/4);
+  CheckSliced(7, true, 4);
+}
+
+TEST_F(TestPropagateNulls, ZeroCopyWhenZeroNullsOnOneInput) {
+  const int64_t length = 16;
+
+  constexpr uint8_t validity_bitmap[8] = {254, 0, 0, 0, 0, 0, 0, 0};
+  auto nulls = std::make_shared<Buffer>(validity_bitmap, 8);
+
+  ArrayData some_nulls(boolean(), 16, {nulls, nullptr}, /*null_count=*/9);
+  ArrayData no_nulls(boolean(), length, {nullptr, nullptr}, /*null_count=*/0);
+
+  ArrayData output(boolean(), length, {nullptr, nullptr});
+  ExecBatch batch({some_nulls, no_nulls}, length);
+  ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+  ASSERT_EQ(nulls.get(), output.buffers[0].get());
+  ASSERT_EQ(9, output.null_count);
+
+  // Flip order of args
+  output = ArrayData(boolean(), length, {nullptr, nullptr});
+  batch.values = {no_nulls, no_nulls, some_nulls};
+  ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+  ASSERT_EQ(nulls.get(), output.buffers[0].get());
+  ASSERT_EQ(9, output.null_count);
+
+  // Check that preallocated memory is not clobbered
+  uint8_t bitmap_data[2] = {0, 0};
+  auto preallocated_mem = std::make_shared<MutableBuffer>(bitmap_data, 2);
+  output.null_count = kUnknownNullCount;
+  output.buffers[0] = preallocated_mem;
+  ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+  ASSERT_EQ(preallocated_mem.get(), output.buffers[0].get());
+  ASSERT_EQ(9, output.null_count);
+  ASSERT_EQ(254, bitmap_data[0]);
+  ASSERT_EQ(0, bitmap_data[1]);
+}
+
+TEST_F(TestPropagateNulls, IntersectsNulls) {
+  const int64_t length = 16;
+
+  // 0b01111111 0b11001111
+  constexpr uint8_t bitmap1[8] = {127, 207, 0, 0, 0, 0, 0, 0};
+
+  // 0b11111110 0b01111111
+  constexpr uint8_t bitmap2[8] = {254, 127, 0, 0, 0, 0, 0, 0};
+
+  // 0b11101111 0b11111110
+  constexpr uint8_t bitmap3[8] = {239, 254, 0, 0, 0, 0, 0, 0};
+
+  ArrayData arr1(boolean(), length, {std::make_shared<Buffer>(bitmap1, 8), nullptr});
+  ArrayData arr2(boolean(), length, {std::make_shared<Buffer>(bitmap2, 8), nullptr});
+  ArrayData arr3(boolean(), length, {std::make_shared<Buffer>(bitmap3, 8), nullptr});
+
+  auto CheckCase = [&](std::vector<Datum> values, int64_t ex_null_count,
+                       const uint8_t* ex_bitmap, bool preallocate = false,
+                       int64_t output_offset = 0) {
+    ExecBatch batch(values, length);
+
+    std::shared_ptr<Buffer> nulls;
+    if (preallocate) {
+      // Make the buffer one byte bigger so we can have non-zero offsets
+      ASSERT_OK_AND_ASSIGN(nulls, AllocateBuffer(3));
+      std::memset(nulls->mutable_data(), 0, nulls->size());
+    } else {
+      // non-zero output offset not permitted unless the output memory is
+      // preallocated
+      ASSERT_EQ(0, output_offset);
+    }
+    ArrayData output(boolean(), length, {nulls, nullptr});
+    output.offset = output_offset;
+
+    ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+
+    // Preallocated memory used
+    if (preallocate) {
+      ASSERT_EQ(nulls.get(), output.buffers[0].get());
+    }
+
+    EXPECT_EQ(kUnknownNullCount, output.null_count);
+    EXPECT_EQ(ex_null_count, output.GetNullCount());
+
+    const auto& out_buffer = *output.buffers[0];
+
+    ASSERT_TRUE(internal::BitmapEquals(out_buffer.data(), output_offset, ex_bitmap,
+                                       /*ex_offset=*/0, length));
+
+    // Now check that the rest of the bits in out_buffer are still 0
+    AssertValidityZeroExtraBits(output);
+  };
+
+  // 0b01101110 0b01001110
+  uint8_t expected1[2] = {110, 78};
+  CheckCase({arr1, arr2, arr3}, 7, expected1);
+  CheckCase({arr1, arr2, arr3}, 7, expected1, /*preallocate=*/true);
+  CheckCase({arr1, arr2, arr3}, 7, expected1, /*preallocate=*/true,
+            /*output_offset=*/4);
+
+  // 0b01111110 0b01001111
+  uint8_t expected2[2] = {126, 79};
+  CheckCase({arr1, arr2}, 5, expected2);
+  CheckCase({arr1, arr2}, 5, expected2, /*preallocate=*/true,
+            /*output_offset=*/4);
+}
+
+TEST_F(TestPropagateNulls, NullOutputTypeNoop) {
+  // Ensure we leave the buffers alone when the output type is null()
+  const int64_t length = 100;
+  ExecBatch batch({rng_->Boolean(100, 0.5, 0.5)}, length);
+
+  ArrayData output(null(), length, {nullptr});
+  ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output));
+  ASSERT_EQ(nullptr, output.buffers[0]);
+}
+
+// ----------------------------------------------------------------------
+// ExecBatchIterator
+
+class TestExecBatchIterator : public TestComputeInternals {
+ public:
+  void SetupIterator(std::vector<Datum> args,
+                     int64_t max_chunksize = kDefaultMaxChunksize) {
+    ASSERT_OK_AND_ASSIGN(iterator_,
+                         ExecBatchIterator::Make(std::move(args), max_chunksize));
+  }
+  void CheckIteration(const std::vector<Datum>& args, int chunksize,
+                      const std::vector<int>& ex_batch_sizes) {
+    SetupIterator(args, chunksize);
+    ExecBatch batch;
+    int64_t position = 0;
+    for (size_t i = 0; i < ex_batch_sizes.size(); ++i) {
+      ASSERT_EQ(position, iterator_->position());
+      ASSERT_TRUE(iterator_->Next(&batch));
+      ASSERT_EQ(ex_batch_sizes[i], batch.length);
+
+      for (size_t j = 0; j < args.size(); ++j) {
+        switch (args[j].kind()) {
+          case Datum::SCALAR:
+            ASSERT_TRUE(args[j].scalar()->Equals(batch[j].scalar()));
+            break;
+          case Datum::ARRAY:
+            AssertArraysEqual(*args[j].make_array()->Slice(position, batch.length),
+                              *batch[j].make_array());
+            break;
+          case Datum::CHUNKED_ARRAY: {
+            const ChunkedArray& carr = *args[j].chunked_array();
+            if (batch.length == 0) {
+              ASSERT_EQ(0, carr.length());
+            } else {
+              auto arg_slice = carr.Slice(position, batch.length);
+              // The sliced ChunkedArrays should only ever be 1 chunk
+              ASSERT_EQ(1, arg_slice->num_chunks());
+              AssertArraysEqual(*arg_slice->chunk(0), *batch[j].make_array());
+            }
+          } break;
+          default:
+            break;
+        }
+      }
+      position += ex_batch_sizes[i];
+    }
+    // Ensure that the iterator is exhausted
+    ASSERT_FALSE(iterator_->Next(&batch));
+
+    ASSERT_EQ(iterator_->length(), iterator_->position());
+  }
+
+ protected:
+  std::unique_ptr<ExecBatchIterator> iterator_;
+};
+
+TEST_F(TestExecBatchIterator, Basics) {
+  const int64_t length = 100;
+
+  // Simple case with a single chunk
+  std::vector<Datum> args = {Datum(GetInt32Array(length)), Datum(GetFloat64Array(length)),
+                             Datum(std::make_shared<Int32Scalar>(3))};
+  SetupIterator(args);
+
+  ExecBatch batch;
+  ASSERT_TRUE(iterator_->Next(&batch));
+  ASSERT_EQ(3, batch.values.size());
+  ASSERT_EQ(3, batch.num_values());
+  ASSERT_EQ(length, batch.length);
+
+  std::vector<ValueDescr> descrs = batch.GetDescriptors();
+  ASSERT_EQ(ValueDescr::Array(int32()), descrs[0]);
+  ASSERT_EQ(ValueDescr::Array(float64()), descrs[1]);
+  ASSERT_EQ(ValueDescr::Scalar(int32()), descrs[2]);
+
+  AssertArraysEqual(*args[0].make_array(), *batch[0].make_array());
+  AssertArraysEqual(*args[1].make_array(), *batch[1].make_array());
+  ASSERT_TRUE(args[2].scalar()->Equals(batch[2].scalar()));
+
+  ASSERT_EQ(length, iterator_->position());
+  ASSERT_FALSE(iterator_->Next(&batch));
+
+  // Split into chunks of size 16
+  CheckIteration(args, /*chunksize=*/16, {16, 16, 16, 16, 16, 16, 4});
+}
+
+TEST_F(TestExecBatchIterator, InputValidation) {
+  std::vector<Datum> args = {Datum(GetInt32Array(10)), Datum(GetInt32Array(9))};
+  ASSERT_RAISES(Invalid, ExecBatchIterator::Make(args));
+
+  args = {Datum(GetInt32Array(9)), Datum(GetInt32Array(10))};
+  ASSERT_RAISES(Invalid, ExecBatchIterator::Make(args));
+
+  args = {Datum(GetInt32Array(10))};
+  ASSERT_OK_AND_ASSIGN(auto iterator, ExecBatchIterator::Make(args));
+  ASSERT_EQ(10, iterator->max_chunksize());
+}
+
+TEST_F(TestExecBatchIterator, ChunkedArrays) {
+  std::vector<Datum> args = {Datum(GetInt32Chunked({0, 20, 10})),
+                             Datum(GetInt32Chunked({15, 15})), Datum(GetInt32Array(30)),
+                             Datum(std::make_shared<Int32Scalar>(5)),
+                             Datum(MakeNullScalar(boolean()))};
+
+  CheckIteration(args, /*chunksize=*/10, {10, 5, 5, 10});
+  CheckIteration(args, /*chunksize=*/20, {15, 5, 10});
+  CheckIteration(args, /*chunksize=*/30, {15, 5, 10});
+}
+
+TEST_F(TestExecBatchIterator, ZeroLengthInputs) {
+  auto carr = std::shared_ptr<ChunkedArray>(new ChunkedArray({}, int32()));
+
+  auto CheckArgs = [&](const std::vector<Datum>& args) {
+    auto iterator = ExecBatchIterator::Make(args).ValueOrDie();
+    ExecBatch batch;
+    ASSERT_FALSE(iterator->Next(&batch));
+  };
+
+  // Zero-length ChunkedArray with zero chunks
+  std::vector<Datum> args = {Datum(carr)};
+  CheckArgs(args);
+
+  // Zero-length array
+  args = {Datum(GetInt32Array(0))};
+  CheckArgs(args);
+
+  // ChunkedArray with single empty chunk
+  args = {Datum(GetInt32Chunked({0}))};
+  CheckArgs(args);
+}
+
+// ----------------------------------------------------------------------
+// Scalar function execution
+
+void ExecCopy(KernelContext*, const ExecBatch& batch, Datum* out) {
+  DCHECK_EQ(1, batch.num_values());
+  const auto& type = checked_cast<const FixedWidthType&>(*batch[0].type());
+  int value_size = type.bit_width() / 8;
+
+  const ArrayData& arg0 = *batch[0].array();
+  ArrayData* out_arr = out->mutable_array();
+  uint8_t* dst = out_arr->buffers[1]->mutable_data() + out_arr->offset * value_size;
+  const uint8_t* src = arg0.buffers[1]->data() + arg0.offset * value_size;
+  std::memcpy(dst, src, batch.length * value_size);
+}
+
+void ExecComputedBitmap(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // Propagate nulls not used. Check that the out bitmap isn't the same already
+  // as the input bitmap
+  const ArrayData& arg0 = *batch[0].array();
+  ArrayData* out_arr = out->mutable_array();
+
+  if (internal::CountSetBits(arg0.buffers[0]->data(), arg0.offset, batch.length) > 0) {
+    // Check that the bitmap has not been already copied over
+    DCHECK(!internal::BitmapEquals(arg0.buffers[0]->data(), arg0.offset,
+                                   out_arr->buffers[0]->data(), out_arr->offset,
+                                   batch.length));
+  }
+  internal::CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length,
+                       out_arr->buffers[0]->mutable_data(), out_arr->offset);
+  ExecCopy(ctx, batch, out);
+}
+
+void ExecNoPreallocatedData(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // Validity preallocated, but not the data
+  ArrayData* out_arr = out->mutable_array();
+  DCHECK_EQ(0, out_arr->offset);
+  const auto& type = checked_cast<const FixedWidthType&>(*batch[0].type());
+  int value_size = type.bit_width() / 8;
+  Status s = (ctx->Allocate(out_arr->length * value_size).Value(&out_arr->buffers[1]));
+  DCHECK_OK(s);
+  ExecCopy(ctx, batch, out);
+}
+
+void ExecNoPreallocatedAnything(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // Neither validity nor data preallocated
+  ArrayData* out_arr = out->mutable_array();
+  DCHECK_EQ(0, out_arr->offset);
+  Status s = (ctx->AllocateBitmap(out_arr->length).Value(&out_arr->buffers[0]));
+  DCHECK_OK(s);
+  const ArrayData& arg0 = *batch[0].array();
+  internal::CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length,
+                       out_arr->buffers[0]->mutable_data(), /*offset=*/0);
+
+  // Reuse the kernel that allocates the data
+  ExecNoPreallocatedData(ctx, batch, out);
+}
+
+struct ExampleOptions : public FunctionOptions {
+  std::shared_ptr<Scalar> value;
+  explicit ExampleOptions(std::shared_ptr<Scalar> value) : value(std::move(value)) {}
+};
+
+struct ExampleState : public KernelState {
+  std::shared_ptr<Scalar> value;
+  explicit ExampleState(std::shared_ptr<Scalar> value) : value(std::move(value)) {}
+};
+
+std::unique_ptr<KernelState> InitStateful(KernelContext*, const KernelInitArgs& args) {
+  auto func_options = static_cast<const ExampleOptions*>(args.options);
+  return std::unique_ptr<KernelState>(new ExampleState{func_options->value});
+}
+
+void ExecStateful(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  // We take the value from the state and multiply the data in batch[0] with it
+  ExampleState* state = static_cast<ExampleState*>(ctx->state());
+  int32_t multiplier = checked_cast<const Int32Scalar&>(*state->value).value;
+
+  const ArrayData& arg0 = *batch[0].array();
+  ArrayData* out_arr = out->mutable_array();
+  const int32_t* arg0_data = arg0.GetValues<int32_t>(1);
+  int32_t* dst = out_arr->GetMutableValues<int32_t>(1);
+  for (int64_t i = 0; i < arg0.length; ++i) {
+    dst[i] = arg0_data[i] * multiplier;
+  }
+}
+
+void ExecAddInt32(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const Int32Scalar& arg0 = batch[0].scalar_as<Int32Scalar>();
+  const Int32Scalar& arg1 = batch[1].scalar_as<Int32Scalar>();
+  out->value = std::make_shared<Int32Scalar>(arg0.value + arg1.value);
+}
+
+class TestCallScalarFunction : public TestComputeInternals {
+ public:
+  void SetUp() {
+    TestComputeInternals::SetUp();
+
+    AddCopyFunctions();
+    AddNoPreallocateFunctions();
+    AddStatefulFunction();
+    AddScalarFunction();
+  }
+
+  void AddCopyFunctions() {
+    // This function simply copies memory from the input argument into the
+    // (preallocated) output
+    auto func = std::make_shared<ScalarFunction>("copy", 1);
+
+    // Add a few kernels. Our implementation only accepts arrays
+    ASSERT_OK(func->AddKernel({InputType::Array(uint8())}, uint8(), ExecCopy));
+    ASSERT_OK(func->AddKernel({InputType::Array(int32())}, int32(), ExecCopy));
+    ASSERT_OK(func->AddKernel({InputType::Array(float64())}, float64(), ExecCopy));
+    ASSERT_OK(registry_->AddFunction(func));
+
+    // A version which doesn't want the executor to call PropagateNulls
+    auto func2 = std::make_shared<ScalarFunction>("copy_computed_bitmap", 1);
+    ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecComputedBitmap);
+    kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+    ASSERT_OK(func2->AddKernel(kernel));
+    ASSERT_OK(registry_->AddFunction(func2));
+  }
+
+  void AddNoPreallocateFunctions() {
+    // A function that allocates its own output memory. We have cases for both
+    // non-preallocated data and non-preallocated validity bitmap
+    auto f1 = std::make_shared<ScalarFunction>("nopre_data", 1);
+    auto f2 = std::make_shared<ScalarFunction>("nopre_validity_or_data", 1);
+
+    ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecNoPreallocatedData);
+    kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+    ASSERT_OK(f1->AddKernel(kernel));
+
+    kernel.exec = ExecNoPreallocatedAnything;
+    kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+    ASSERT_OK(f2->AddKernel(kernel));
+
+    ASSERT_OK(registry_->AddFunction(f1));
+    ASSERT_OK(registry_->AddFunction(f2));
+  }
+
+  void AddStatefulFunction() {
+    // This function's behavior depends on a static parameter that is made
+    // available to the kernel's execution function through its Options object
+    auto func = std::make_shared<ScalarFunction>("stateful", 1);
+
+    ScalarKernel kernel({InputType::Array(int32())}, int32(), ExecStateful, InitStateful);
+    ASSERT_OK(func->AddKernel(kernel));
+    ASSERT_OK(registry_->AddFunction(func));
+  }
+
+  void AddScalarFunction() {
+    auto func = std::make_shared<ScalarFunction>("scalar_add_int32", 2);
+    ASSERT_OK(func->AddKernel({InputType::Scalar(int32()), InputType::Scalar(int32())},
+                              int32(), ExecAddInt32));
+    ASSERT_OK(registry_->AddFunction(func));
+  }
+};
+
+TEST_F(TestCallScalarFunction, ArgumentValidation) {
+  // Copy accepts only a single array argument
+  Datum d1(GetInt32Array(10));
+
+  // Too many args
+  std::vector<Datum> args = {d1, d1};
+  ASSERT_RAISES(Invalid, CallFunction(exec_ctx_.get(), "copy", args));
+
+  // Too few
+  args = {};
+  ASSERT_RAISES(Invalid, CallFunction(exec_ctx_.get(), "copy", args));
+
+  // Cannot do scalar
+  args = {Datum(std::make_shared<Int32Scalar>(5))};
+  ASSERT_RAISES(NotImplemented, CallFunction(exec_ctx_.get(), "copy", args));
+}
+
+TEST_F(TestCallScalarFunction, PreallocationCases) {
+  double null_prob = 0.2;
+
+  auto arr = GetUInt8Array(1000, null_prob);
+
+  auto CheckFunction = [&](std::string func_name) {
+    ResetContexts();
+
+    // The default should be a single array output
+    {
+      std::vector<Datum> args = {Datum(arr)};
+      ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(exec_ctx_.get(), func_name, args));
+      ASSERT_EQ(Datum::ARRAY, result.kind());
+      AssertArraysEqual(*arr, *result.make_array());
+    }
+
+    // Set the exec_chunksize to be smaller, so now we have several invocations
+    // of the kernel, but the output is still one array
+    {
+      std::vector<Datum> args = {Datum(arr)};
+      exec_ctx_->set_exec_chunksize(80);
+      ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(exec_ctx_.get(), func_name, args));
+      AssertArraysEqual(*arr, *result.make_array());
+    }
+
+    exec_ctx_->set_exec_chunksize(12);
+
+    // Chunksize not multiple of 8
+    {
+      std::vector<Datum> args = {Datum(arr)};
+      exec_ctx_->set_exec_chunksize(111);
+      ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(exec_ctx_.get(), func_name, args));
+      AssertArraysEqual(*arr, *result.make_array());
+    }
+
+    // Input is chunked, output has one big chunk
+    {
+      auto carr = std::shared_ptr<ChunkedArray>(
+          new ChunkedArray({arr->Slice(0, 100), arr->Slice(100)}));
+      std::vector<Datum> args = {Datum(carr)};
+      ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(exec_ctx_.get(), func_name, args));
+      std::shared_ptr<ChunkedArray> actual = result.chunked_array();
+      ASSERT_EQ(1, actual->num_chunks());
+      AssertChunkedEquivalent(*carr, *actual);
+    }
+
+    // Preallocate independently for each batch
+    {
+      std::vector<Datum> args = {Datum(arr)};
+      exec_ctx_->set_preallocate_contiguous(false);
+      exec_ctx_->set_exec_chunksize(400);
+      ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(exec_ctx_.get(), func_name, args));
+      ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind());
+      const ChunkedArray& carr = *result.chunked_array();
+      ASSERT_EQ(3, carr.num_chunks());
+      AssertArraysEqual(*arr->Slice(0, 400), *carr.chunk(0));
+      AssertArraysEqual(*arr->Slice(400, 400), *carr.chunk(1));
+      AssertArraysEqual(*arr->Slice(800), *carr.chunk(2));
+    }
+  };
+
+  CheckFunction("copy");
+  CheckFunction("copy_computed_bitmap");
+}
+
+TEST_F(TestCallScalarFunction, BasicNonStandardCases) {
+  // Test a handful of cases
+  //
+  // * Validity bitmap computed by kernel rather than using PropagateNulls
+  // * Data not pre-allocated
+  // * Validity bitmap not pre-allocated
+
+  double null_prob = 0.2;
+
+  auto arr = GetUInt8Array(1000, null_prob);
+  std::vector<Datum> args = {Datum(arr)};
+
+  auto CheckFunction = [&](std::string func_name) {
+    ResetContexts();
+
+    // The default should be a single array output
+    {
+      exec_ctx_->set_exec_chunksize(kDefaultMaxChunksize);
+      ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(exec_ctx_.get(), func_name, args));
+      AssertArraysEqual(*arr, *result.make_array(), true);
+    }
+
+    // Split execution into 3 chunks
+    {
+      exec_ctx_->set_exec_chunksize(400);
+      ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(exec_ctx_.get(), func_name, args));
+      ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind());
+      const ChunkedArray& carr = *result.chunked_array();
+      ASSERT_EQ(3, carr.num_chunks());
+      AssertArraysEqual(*arr->Slice(0, 400), *carr.chunk(0));
+      AssertArraysEqual(*arr->Slice(400, 400), *carr.chunk(1));
+      AssertArraysEqual(*arr->Slice(800), *carr.chunk(2));
+    }
+  };
+
+  CheckFunction("nopre_data");
+  CheckFunction("nopre_validity_or_data");
+}
+
+TEST_F(TestCallScalarFunction, StatefulKernel) {
+  auto input = ArrayFromJSON(int32(), "[1, 2, 3, null, 5]");
+  auto multiplier = std::make_shared<Int32Scalar>(2);
+  auto expected = ArrayFromJSON(int32(), "[2, 4, 6, null, 10]");
+
+  ExampleOptions options(multiplier);
+  std::vector<Datum> args = {Datum(input)};
+  ASSERT_OK_AND_ASSIGN(Datum result,
+                       CallFunction(exec_ctx_.get(), "stateful", args, &options));
+  AssertArraysEqual(*expected, *result.make_array());
+}
+
+TEST_F(TestCallScalarFunction, ScalarFunction) {
+  std::vector<Datum> args = {Datum(std::make_shared<Int32Scalar>(5)),
+                             Datum(std::make_shared<Int32Scalar>(7))};
+  ASSERT_OK_AND_ASSIGN(Datum result,
+                       CallFunction(exec_ctx_.get(), "scalar_add_int32", args));
+  ASSERT_EQ(Datum::SCALAR, result.kind());
+
+  auto expected = std::make_shared<Int32Scalar>(12);
+  ASSERT_TRUE(expected->Equals(*result.scalar()));
+}
+
+} // namespace detail
+} // namespace compute
+} // namespace arrow
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/function.h" + +#include +#include +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/datum.h" + +namespace arrow { +namespace compute { + +static Status CheckArity(const std::vector& args, const Arity& arity) { + const int passed_num_args = static_cast(args.size()); + if (arity.is_varargs && passed_num_args < arity.num_args) { + return Status::Invalid("VarArgs function needs at least ", arity.num_args, + " arguments but kernel accepts only ", passed_num_args); + } else if (!arity.is_varargs && passed_num_args != arity.num_args) { + return Status::Invalid("Function accepts ", arity.num_args, + " arguments but kernel accepts ", passed_num_args); + } + return Status::OK(); +} + +template +std::string FormatArgTypes(const std::vector& descrs) { + std::stringstream ss; + ss << "("; + for (size_t i = 0; i < descrs.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << descrs[i].ToString(); + } + ss << ")"; + return ss.str(); +} + +template +Result DispatchExactImpl(const Function& func, + const std::vector& kernels, + const std::vector& values) { + const int passed_num_args = static_cast(values.size()); + + // Validate arity + const Arity arity = func.arity(); + if (arity.is_varargs && passed_num_args < arity.num_args) { + return Status::Invalid("VarArgs function needs at least ", arity.num_args, + " arguments but passed only ", passed_num_args); + } else if (!arity.is_varargs && passed_num_args != arity.num_args) { + return Status::Invalid("Function accepts ", arity.num_args, " arguments but passed ", + passed_num_args); + } + for (const auto& kernel : kernels) { + if (kernel.signature->MatchesInputs(values)) { + return &kernel; + } + } + return Status::NotImplemented("Function ", func.name(), + " has no kernel matching input types ", + FormatArgTypes(values)); +} + +Result Function::Execute(const std::vector& args, + const FunctionOptions* options, ExecContext* ctx) const { + if (ctx == nullptr) { + ExecContext default_ctx; + return Execute(args, options, &default_ctx); + } + // type-check Datum arguments here. 
Really we'd like to avoid this as much as + // possible + RETURN_NOT_OK(detail::CheckAllValues(args)); + ARROW_ASSIGN_OR_RAISE(auto executor, + detail::FunctionExecutor::Make(ctx, this, options)); + auto listener = std::make_shared(); + RETURN_NOT_OK(executor->Execute(args, listener.get())); + return executor->WrapResults(args, listener->values()); +} + +Status ScalarFunction::AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init) { + RETURN_NOT_OK(CheckArity(in_types, arity_)); + + if (arity_.is_varargs && in_types.size() != 1) { + return Status::Invalid("VarArgs signatures must have exactly one input type"); + } + auto sig = + KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); + kernels_.emplace_back(std::move(sig), exec, init); + return Status::OK(); +} + +Status ScalarFunction::AddKernel(ScalarKernel kernel) { + RETURN_NOT_OK(CheckArity(kernel.signature->in_types(), arity_)); + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Result ScalarFunction::DispatchExact( + const std::vector& values) const { + return DispatchExactImpl(*this, kernels_, values); +} + +Status VectorFunction::AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init) { + RETURN_NOT_OK(CheckArity(in_types, arity_)); + + if (arity_.is_varargs && in_types.size() != 1) { + return Status::Invalid("VarArgs signatures must have exactly one input type"); + } + auto sig = + KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); + kernels_.emplace_back(std::move(sig), exec, init); + return Status::OK(); +} + +Status VectorFunction::AddKernel(VectorKernel kernel) { + RETURN_NOT_OK(CheckArity(kernel.signature->in_types(), arity_)); + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Result VectorFunction::DispatchExact( + const std::vector& values) const { + return DispatchExactImpl(*this, kernels_, values); +} + +Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { + RETURN_NOT_OK(CheckArity(kernel.signature->in_types(), arity_)); + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Result ScalarAggregateFunction::DispatchExact( + const std::vector& values) const { + return DispatchExactImpl(*this, kernels_, values); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h new file mode 100644 index 00000000000..4280235d678 --- /dev/null +++ b/cpp/src/arrow/compute/function.h @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include +#include +#include + +#include "arrow/compute/kernel.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +struct Datum; +struct ValueDescr; + +namespace compute { + +class ExecContext; + +struct ARROW_EXPORT FunctionOptions {}; + +/// \brief Contains the number of required arguments for the function +struct ARROW_EXPORT Arity { + static Arity Nullary() { return Arity(0, false); } + static Arity Unary() { return Arity(1, false); } + static Arity Binary() { return Arity(2, false); } + static Arity Ternary() { return Arity(3, false); } + static Arity VarArgs(int min_args = 1) { return Arity(min_args, true); } + + Arity(int num_args, bool is_varargs = false) // NOLINT implicit conversion + : num_args(num_args), is_varargs(is_varargs) {} + + /// The number of required arguments (or the minimum number for varargs + /// functions) + int num_args; + + /// If true, then the num_args is the minimum number of required arguments + bool is_varargs = false; +}; + +/// \brief Base class for function containers that are capable of dispatch to +/// kernel implementations +class ARROW_EXPORT Function { + public: + /// \brief The kind of function, which indicates in what contexts it is + /// valid for use + enum Kind { + /// A function that performs scalar data operations on whole arrays of + /// data. Can generally process Array or Scalar values. The size of the + /// output will be the same as the size (or broadcasted size, in the case + /// of mixing Array and Scalar inputs) of the input. + SCALAR, + + /// A function with array input and output whose behavior depends on the + /// values of the entire arrays passed, rather than the value of each scalar + /// value. + VECTOR, + + /// A function that computes scalar summary statistics from array input. + SCALAR_AGGREGATE + }; + + virtual ~Function() = default; + + /// \brief The name of the kernel. 
The registry enforces uniqueness of names + const std::string& name() const { return name_; } + + /// \brief The kind of kernel, which indicates in what contexts it is valid + /// for use + Function::Kind kind() const { return kind_; } + + /// \brief Contains the number of arguments the function requires + const Arity& arity() const { return arity_; } + + /// \brief Returns the number of registered kernels for this function + virtual int num_kernels() const = 0; + + /// \brief Convenience for invoking a function with kernel dispatch and + /// memory allocation details taken care of + Result Execute(const std::vector& args, const FunctionOptions* options, + ExecContext* ctx = NULLPTR) const; + + protected: + Function(std::string name, Function::Kind kind, const Arity& arity) + : name_(std::move(name)), kind_(kind), arity_(arity) {} + std::string name_; + Function::Kind kind_; + Arity arity_; +}; + +namespace detail { + +template +class FunctionImpl : public Function { + public: + /// \brief Return pointers to current-available kernels for inspection + std::vector kernels() const { + std::vector result; + for (const auto& kernel : kernels_) { + result.push_back(&kernel); + } + return result; + } + + int num_kernels() const override { return static_cast(kernels_.size()); } + + protected: + FunctionImpl(std::string name, Function::Kind kind, const Arity& arity) + : Function(std::move(name), kind, arity) {} + + std::vector kernels_; +}; + +} // namespace detail + +/// \brief A function that executes elementwise operations on arrays or +/// scalars, and therefore whose results generally do not depend on the order +/// of the values in the arguments. Accepts and returns arrays that are all of +/// the same size. These functions roughly correspond to the functions used in +/// SQL expressions. +class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl { + public: + using KernelType = ScalarKernel; + + ScalarFunction(std::string name, const Arity& arity) + : detail::FunctionImpl(std::move(name), Function::SCALAR, arity) {} + + /// \brief Add a simple kernel (function implementation) with given + /// input/output types, no required state initialization, preallocation for + /// fixed-width types, and default null handling (intersect validity bitmaps + /// of inputs) + Status AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init = NULLPTR); + + /// \brief Add a kernel (function implementation). Returns error if fails + /// to match the other parameters of the function + Status AddKernel(ScalarKernel kernel); + + /// \brief Return the first kernel that can execute the function given the + /// exact argument types (without implicit type casts or scalar->array + /// promotions) + /// + /// This function is overridden in CastFunction + virtual Result DispatchExact( + const std::vector& values) const; +}; + +/// \brief A function that executes general array operations that may yield +/// outputs of different sizes or have results that depend on the whole array +/// contents. 
These functions roughly correspond to the functions found in +/// non-SQL array languages like APL and its derivatives +class ARROW_EXPORT VectorFunction : public detail::FunctionImpl { + public: + using KernelType = VectorKernel; + + VectorFunction(std::string name, const Arity& arity) + : detail::FunctionImpl(std::move(name), Function::VECTOR, arity) {} + + /// \brief Add a simple kernel (function implementation) with given + /// input/output types, no required state initialization, preallocation for + /// fixed-width types, and default null handling (intersect validity bitmaps + /// of inputs) + Status AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init = NULLPTR); + + /// \brief Add a kernel (function implementation). Returns error if fails + /// to match the other parameters of the function + Status AddKernel(VectorKernel kernel); + + /// \brief Return the first kernel that can execute the function given the + /// exact argument types (without implicit type casts or scalar->array + /// promotions) + Result DispatchExact(const std::vector& values) const; +}; + +class ARROW_EXPORT ScalarAggregateFunction + : public detail::FunctionImpl { + public: + using KernelType = ScalarAggregateKernel; + + ScalarAggregateFunction(std::string name, const Arity& arity) + : detail::FunctionImpl(std::move(name), + Function::SCALAR_AGGREGATE, arity) {} + + /// \brief Add a kernel (function implementation). Returns error if fails + /// to match the other parameters of the function + Status AddKernel(ScalarAggregateKernel kernel); + + Result DispatchExact( + const std::vector& values) const; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc new file mode 100644 index 00000000000..0c1d6241ef4 --- /dev/null +++ b/cpp/src/arrow/compute/function_test.cc @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
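
Before the tests, a brief editorial illustration of how the pieces declared in function.h fit together. The sketch below is not part of the patch: the function name "copy_example", the no-op exec body, and the helper CopyExampleSketch are assumptions for illustration, and template parameters (which the surrounding diff text drops) are written out in full.

#include <memory>

#include "arrow/compute/function.h"
#include "arrow/compute/kernel.h"
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/status.h"

namespace cp = arrow::compute;

// Editorial sketch: register a unary scalar function and dispatch on it.
arrow::Status CopyExampleSketch() {
  // A unary scalar function; Arity::Unary() is Arity(1, /*is_varargs=*/false).
  auto func = std::make_shared<cp::ScalarFunction>("copy_example", cp::Arity::Unary());

  // The ScalarKernel defaults apply: output memory is preallocated
  // (MemAllocation::PREALLOCATE) and output validity is computed by
  // intersecting the input bitmaps (NullHandling::INTERSECTION).
  cp::ArrayKernelExec exec = [](cp::KernelContext*, const cp::ExecBatch&,
                                arrow::Datum*) {
    // ... copy the first argument into the preallocated output ...
  };
  ARROW_RETURN_NOT_OK(
      func->AddKernel({cp::InputType::Array(arrow::int32())}, arrow::int32(), exec));

  // DispatchExact selects the first kernel whose signature matches the
  // argument descriptors exactly, with no implicit casts or promotions.
  ARROW_ASSIGN_OR_RAISE(const cp::ScalarKernel* kernel,
                        func->DispatchExact({arrow::ValueDescr::Array(arrow::int32())}));
  return kernel->signature->MatchesInputs({arrow::ValueDescr::Array(arrow::int32())})
             ? arrow::Status::OK()
             : arrow::Status::Invalid("unexpected kernel");
}

The tests that follow exercise exactly these entry points, adding a registry fixture on top.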
+ +#include +#include +#include + +#include + +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/datum.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" + +namespace arrow { +namespace compute { + +struct ExecBatch; + +TEST(Arity, Basics) { + auto nullary = Arity::Nullary(); + ASSERT_EQ(0, nullary.num_args); + ASSERT_FALSE(nullary.is_varargs); + + auto unary = Arity::Unary(); + ASSERT_EQ(1, unary.num_args); + + auto binary = Arity::Binary(); + ASSERT_EQ(2, binary.num_args); + + auto ternary = Arity::Ternary(); + ASSERT_EQ(3, ternary.num_args); + + auto varargs = Arity::VarArgs(); + ASSERT_EQ(1, varargs.num_args); + ASSERT_TRUE(varargs.is_varargs); + + auto varargs2 = Arity::VarArgs(2); + ASSERT_EQ(2, varargs2.num_args); + ASSERT_TRUE(varargs2.is_varargs); +} + +TEST(ScalarFunction, Basics) { + ScalarFunction func("scalar_test", 2); + ScalarFunction varargs_func("varargs_test", Arity::VarArgs(1)); + + ASSERT_EQ("scalar_test", func.name()); + ASSERT_EQ(2, func.arity().num_args); + ASSERT_FALSE(func.arity().is_varargs); + ASSERT_EQ(Function::SCALAR, func.kind()); + + ASSERT_EQ("varargs_test", varargs_func.name()); + ASSERT_EQ(1, varargs_func.arity().num_args); + ASSERT_TRUE(varargs_func.arity().is_varargs); + ASSERT_EQ(Function::SCALAR, varargs_func.kind()); +} + +TEST(VectorFunction, Basics) { + VectorFunction func("vector_test", 2); + VectorFunction varargs_func("varargs_test", Arity::VarArgs(1)); + + ASSERT_EQ("vector_test", func.name()); + ASSERT_EQ(2, func.arity().num_args); + ASSERT_FALSE(func.arity().is_varargs); + ASSERT_EQ(Function::VECTOR, func.kind()); + + ASSERT_EQ("varargs_test", varargs_func.name()); + ASSERT_EQ(1, varargs_func.arity().num_args); + ASSERT_TRUE(varargs_func.arity().is_varargs); + ASSERT_EQ(Function::VECTOR, varargs_func.kind()); +} + +auto ExecNYI = [](KernelContext* ctx, const ExecBatch& args, Datum* out) { + ctx->SetStatus(Status::NotImplemented("NYI")); + return; +}; + +template +void CheckAddDispatch(FunctionType* func) { + using KernelType = typename FunctionType::KernelType; + + ASSERT_EQ(0, func->num_kernels()); + ASSERT_EQ(0, func->kernels().size()); + + std::vector in_types1 = {int32(), int32()}; + OutputType out_type1 = int32(); + + ASSERT_OK(func->AddKernel(in_types1, out_type1, ExecNYI)); + ASSERT_OK(func->AddKernel({int32(), int8()}, int32(), ExecNYI)); + + // Duplicate sig is okay + ASSERT_OK(func->AddKernel(in_types1, out_type1, ExecNYI)); + + // Add given a descr + KernelType descr({float64(), float64()}, float64(), ExecNYI); + ASSERT_OK(func->AddKernel(descr)); + + ASSERT_EQ(4, func->num_kernels()); + ASSERT_EQ(4, func->kernels().size()); + + // Try adding some invalid kernels + ASSERT_RAISES(Invalid, func->AddKernel({}, int32(), ExecNYI)); + ASSERT_RAISES(Invalid, func->AddKernel({int32()}, int32(), ExecNYI)); + ASSERT_RAISES(Invalid, func->AddKernel({int8(), int8(), int8()}, int32(), ExecNYI)); + + // Add valid and invalid kernel using kernel struct directly + KernelType valid_kernel({boolean(), boolean()}, boolean(), ExecNYI); + ASSERT_OK(func->AddKernel(valid_kernel)); + + KernelType invalid_kernel({boolean()}, boolean(), ExecNYI); + ASSERT_RAISES(Invalid, func->AddKernel(invalid_kernel)); + + ASSERT_OK_AND_ASSIGN(const KernelType* kernel, func->DispatchExact({int32(), int32()})); + KernelSignature expected_sig(in_types1, out_type1); + ASSERT_TRUE(kernel->signature->Equals(expected_sig)); + + // No kernel available + ASSERT_RAISES(NotImplemented, 
func->DispatchExact({utf8(), utf8()})); + + // Wrong arity + ASSERT_RAISES(Invalid, func->DispatchExact({})); + ASSERT_RAISES(Invalid, func->DispatchExact({int32(), int32(), int32()})); +} + +TEST(ScalarVectorFunction, DispatchExact) { + ScalarFunction func1("scalar_test", 2); + VectorFunction func2("vector_test", 2); + + CheckAddDispatch(&func1); + CheckAddDispatch(&func2); +} + +TEST(ArrayFunction, VarArgs) { + ScalarFunction va_func("va_test", Arity::VarArgs(1)); + + std::vector va_args = {int8()}; + + ASSERT_OK(va_func.AddKernel(va_args, int8(), ExecNYI)); + + // No input type passed + ASSERT_RAISES(Invalid, va_func.AddKernel({}, int8(), ExecNYI)); + + // VarArgs function expect a single input type + ASSERT_RAISES(Invalid, va_func.AddKernel({int8(), int8()}, int8(), ExecNYI)); + + // Invalid sig + ScalarKernel non_va_kernel(std::make_shared(va_args, int8()), ExecNYI); + ASSERT_RAISES(Invalid, va_func.AddKernel(non_va_kernel)); + + std::vector args = {ValueDescr::Scalar(int8()), int8(), int8()}; + ASSERT_OK_AND_ASSIGN(const ScalarKernel* kernel, va_func.DispatchExact(args)); + ASSERT_TRUE(kernel->signature->MatchesInputs(args)); + + // No dispatch possible because args incompatible + args[2] = int32(); + ASSERT_RAISES(NotImplemented, va_func.DispatchExact(args)); +} + +TEST(ScalarAggregateFunction, Basics) { + ScalarAggregateFunction func("agg_test", 1); + + ASSERT_EQ("agg_test", func.name()); + ASSERT_EQ(1, func.arity().num_args); + ASSERT_FALSE(func.arity().is_varargs); + ASSERT_EQ(Function::SCALAR_AGGREGATE, func.kind()); +} + +std::unique_ptr NoopInit(KernelContext*, const KernelInitArgs&) { + return nullptr; +} + +void NoopConsume(KernelContext*, const ExecBatch&) {} +void NoopMerge(KernelContext*, const KernelState&, KernelState*) {} +void NoopFinalize(KernelContext*, Datum*) {} + +TEST(ScalarAggregateFunction, DispatchExact) { + ScalarAggregateFunction func("agg_test", 1); + + std::vector in_args = {ValueDescr::Array(int8())}; + ScalarAggregateKernel kernel(std::move(in_args), int64(), NoopInit, NoopConsume, + NoopMerge, NoopFinalize); + ASSERT_OK(func.AddKernel(kernel)); + + in_args = {float64()}; + kernel.signature = std::make_shared(in_args, float64()); + ASSERT_OK(func.AddKernel(kernel)); + + ASSERT_EQ(2, func.num_kernels()); + ASSERT_EQ(2, func.kernels().size()); + ASSERT_TRUE(func.kernels()[1]->signature->Equals(*kernel.signature)); + + // Invalid arity + in_args = {}; + kernel.signature = std::make_shared(in_args, float64()); + ASSERT_RAISES(Invalid, func.AddKernel(kernel)); + + in_args = {float32(), float64()}; + kernel.signature = std::make_shared(in_args, float64()); + ASSERT_RAISES(Invalid, func.AddKernel(kernel)); + + std::vector dispatch_args = {ValueDescr::Array(int8())}; + ASSERT_OK_AND_ASSIGN(const ScalarAggregateKernel* selected_kernel, + func.DispatchExact(dispatch_args)); + ASSERT_EQ(func.kernels()[0], selected_kernel); + ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); + + // We declared that only arrays are accepted + dispatch_args[0] = {ValueDescr::Scalar(int8())}; + ASSERT_RAISES(NotImplemented, func.DispatchExact(dispatch_args)); + + // Didn't qualify the float64() kernel so this actually dispatches (even + // though that may not be what you want) + dispatch_args[0] = {ValueDescr::Scalar(float64())}; + ASSERT_OK_AND_ASSIGN(selected_kernel, func.DispatchExact(dispatch_args)); + ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); +} + +} // namespace compute +} // namespace arrow diff --git 
a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc
new file mode 100644
index 00000000000..d3652131b32
--- /dev/null
+++ b/cpp/src/arrow/compute/kernel.cc
@@ -0,0 +1,376 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernel.h"
+
+#include
+#include
+#include
+#include
+
+#include "arrow/buffer.h"
+#include "arrow/compute/exec.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::hash_combine;
+
+static constexpr size_t kHashSeed = 0;
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// KernelContext
+
+inline void ZeroLastByte(Buffer* buffer) {
+  *(buffer->mutable_data() + (buffer->size() - 1)) = 0;
+}
+
+Result> KernelContext::Allocate(int64_t nbytes) {
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr result,
+                        AllocateBuffer(nbytes, exec_ctx_->memory_pool()));
+  result->ZeroPadding();
+  return result;
+}
+
+Result> KernelContext::AllocateBitmap(int64_t num_bits) {
+  const int64_t nbytes = BitUtil::BytesForBits(num_bits);
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr result,
+                        AllocateBuffer(nbytes, exec_ctx_->memory_pool()));
+  // Some utility methods access the last byte before it might be
+  // initialized; this makes valgrind/asan unhappy, so we proactively
+  // zero it.
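+  // (ZeroPadding() below does the analogous cleanup for the allocation's
+  // padding bytes beyond the buffer's logical size.)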
+ if (nbytes > 0) { + ZeroLastByte(result.get()); + result->ZeroPadding(); + } + return result; +} + +void KernelContext::SetStatus(const Status& status) { + if (ARROW_PREDICT_FALSE(!status_.ok())) { + return; + } + status_ = status; +} + +/// \brief Clear any error status +void KernelContext::ResetStatus() { status_ = Status::OK(); } + +// ---------------------------------------------------------------------- +// Some basic TypeMatcher implementations + +namespace match { + +class SameTypeIdMatcher : public TypeMatcher { + public: + explicit SameTypeIdMatcher(Type::type accepted_id) : accepted_id_(accepted_id) {} + + bool Matches(const DataType& type) const override { return type.id() == accepted_id_; } + + std::string ToString() const override { + std::stringstream ss; + ss << "Type::" << ::arrow::internal::ToString(accepted_id_); + return ss.str(); + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + + auto casted = dynamic_cast(&other); + if (casted == nullptr) { + return false; + } + return this->accepted_id_ == casted->accepted_id_; + } + + private: + Type::type accepted_id_; +}; + +std::shared_ptr SameTypeId(Type::type type_id) { + return std::make_shared(type_id); +} + +class TimestampUnitMatcher : public TypeMatcher { + public: + explicit TimestampUnitMatcher(TimeUnit::type accepted_unit) + : accepted_unit_(accepted_unit) {} + + bool Matches(const DataType& type) const override { + if (type.id() != Type::TIMESTAMP) { + return false; + } + const auto& ts_type = checked_cast(type); + return ts_type.unit() == accepted_unit_; + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast(&other); + if (casted == nullptr) { + return false; + } + return this->accepted_unit_ == casted->accepted_unit_; + } + + std::string ToString() const override { + std::stringstream ss; + ss << "timestamp(" << ::arrow::internal::ToString(accepted_unit_) << ")"; + return ss.str(); + } + + private: + TimeUnit::type accepted_unit_; +}; + +std::shared_ptr TimestampUnit(TimeUnit::type unit) { + return std::make_shared(unit); +} + +} // namespace match + +// ---------------------------------------------------------------------- +// InputType + +size_t InputType::Hash() const { + size_t result = kHashSeed; + hash_combine(result, static_cast(shape_)); + hash_combine(result, static_cast(kind_)); + switch (kind_) { + case InputType::EXACT_TYPE: + hash_combine(result, type_->Hash()); + break; + default: + break; + } + return result; +} + +std::string InputType::ToString() const { + std::stringstream ss; + switch (shape_) { + case ValueDescr::ANY: + ss << "any"; + break; + case ValueDescr::ARRAY: + ss << "array"; + break; + case ValueDescr::SCALAR: + ss << "scalar"; + break; + default: + DCHECK(false); + break; + } + ss << "["; + switch (kind_) { + case InputType::EXACT_TYPE: + ss << type_->ToString(); + break; + case InputType::USE_TYPE_MATCHER: { + ss << type_matcher_->ToString(); + } break; + default: + DCHECK(false); + break; + } + ss << "]"; + return ss.str(); +} + +bool InputType::Equals(const InputType& other) const { + if (this == &other) { + return true; + } + if (kind_ != other.kind_ || shape_ != other.shape_) { + return false; + } + switch (kind_) { + case InputType::EXACT_TYPE: + return type_->Equals(*other.type_); + case InputType::USE_TYPE_MATCHER: + return type_matcher_->Equals(*other.type_matcher_); + default: + return false; + } +} + +bool InputType::Matches(const 
ValueDescr& descr) const { + if (shape_ != ValueDescr::ANY && descr.shape != shape_) { + return false; + } + switch (kind_) { + case InputType::EXACT_TYPE: + return type_->Equals(*descr.type); + case InputType::USE_TYPE_MATCHER: + return type_matcher_->Matches(*descr.type); + default: + // ANY_TYPE + return true; + } +} + +bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); } + +const std::shared_ptr& InputType::type() const { + DCHECK_EQ(InputType::EXACT_TYPE, kind_); + return type_; +} + +const TypeMatcher& InputType::type_matcher() const { + DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_); + return *type_matcher_; +} + +// ---------------------------------------------------------------------- +// OutputType + +OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) { + shape_ = descr.shape; +} + +Result OutputType::Resolve(KernelContext* ctx, + const std::vector& args) const { + if (kind_ == OutputType::FIXED) { + ValueDescr::Shape out_shape = shape_; + if (out_shape == ValueDescr::ANY) { + out_shape = GetBroadcastShape(args); + } + return ValueDescr(type_, out_shape); + } else { + return resolver_(ctx, args); + } +} + +const std::shared_ptr& OutputType::type() const { + DCHECK_EQ(FIXED, kind_); + return type_; +} + +const OutputType::Resolver& OutputType::resolver() const { + DCHECK_EQ(COMPUTED, kind_); + return resolver_; +} + +std::string OutputType::ToString() const { + if (kind_ == OutputType::FIXED) { + return type_->ToString(); + } else { + return "computed"; + } +} + +// ---------------------------------------------------------------------- +// KernelSignature + +KernelSignature::KernelSignature(std::vector in_types, OutputType out_type, + bool is_varargs) + : in_types_(std::move(in_types)), + out_type_(std::move(out_type)), + is_varargs_(is_varargs), + hash_code_(0) { + // VarArgs sigs must have only a single input type to use for argument validation + DCHECK(!is_varargs || (is_varargs && (in_types_.size() == 1))); +} + +std::shared_ptr KernelSignature::Make(std::vector in_types, + OutputType out_type, + bool is_varargs) { + return std::make_shared(std::move(in_types), std::move(out_type), + is_varargs); +} + +bool KernelSignature::Equals(const KernelSignature& other) const { + if (is_varargs_ != other.is_varargs_) { + return false; + } + if (in_types_.size() != other.in_types_.size()) { + return false; + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (!in_types_[i].Equals(other.in_types_[i])) { + return false; + } + } + return true; +} + +bool KernelSignature::MatchesInputs(const std::vector& args) const { + if (is_varargs_) { + for (const auto& arg : args) { + if (!in_types_[0].Matches(arg)) { + return false; + } + } + } else { + if (args.size() != in_types_.size()) { + return false; + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (!in_types_[i].Matches(args[i])) { + return false; + } + } + } + return true; +} + +size_t KernelSignature::Hash() const { + if (hash_code_ != 0) { + return hash_code_; + } + size_t result = kHashSeed; + for (const auto& in_type : in_types_) { + hash_combine(result, in_type.Hash()); + } + hash_code_ = result; + return result; +} + +std::string KernelSignature::ToString() const { + std::stringstream ss; + + if (is_varargs_) { + ss << "varargs[" << in_types_[0].ToString() << "]"; + } else { + ss << "("; + for (size_t i = 0; i < in_types_.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << in_types_[i].ToString(); + } + ss << ")"; + } + ss << " -> " << out_type_.ToString(); + return 
ss.str(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 16dca696567..b76c27a7ab8 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -15,295 +15,554 @@ // specific language governing permissions and limitations // under the License. +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + #pragma once +#include +#include #include +#include #include #include -#include "arrow/array.h" -#include "arrow/record_batch.h" -#include "arrow/scalar.h" -#include "arrow/table.h" +#include "arrow/compute/exec.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" #include "arrow/util/macros.h" -#include "arrow/util/memory.h" -#include "arrow/util/variant.h" // IWYU pragma: export #include "arrow/util/visibility.h" namespace arrow { + +class Buffer; +class MemoryPool; + namespace compute { -class FunctionContext; +struct FunctionOptions; -/// \class OpKernel -/// \brief Base class for operator kernels -/// -/// Note to implementors: -/// Operator kernels are intended to be the lowest level of an analytics/compute -/// engine. They will generally not be exposed directly to end-users. Instead -/// they will be wrapped by higher level constructs (e.g. top-level functions -/// or physical execution plan nodes). These higher level constructs are -/// responsible for user input validation and returning the appropriate -/// error Status. -/// -/// Due to this design, implementations of Call (the execution -/// method on subclasses) should use assertions (i.e. DCHECK) to double-check -/// parameter arguments when in higher level components returning an -/// InvalidArgument error might be more appropriate. -/// -class ARROW_EXPORT OpKernel { +/// \brief Base class for opaque kernel-specific state. 
For example, if there
+/// is some kind of initialization required
+struct KernelState {
+  virtual ~KernelState() = default;
+};
+
+/// \brief Context/state for the execution of a particular kernel
+class ARROW_EXPORT KernelContext {
 public:
-  virtual ~OpKernel() = default;
-  /// \brief EXPERIMENTAL The output data type of the kernel
-  /// \return the output type
-  virtual std::shared_ptr out_type() const = 0;
+  explicit KernelContext(ExecContext* exec_ctx) : exec_ctx_(exec_ctx) {}
+
+  /// \brief Allocate buffer from the context's memory pool
+  Result> Allocate(int64_t nbytes);
+
+  /// \brief Allocate buffer for bitmap from the context's memory pool
+  Result> AllocateBitmap(int64_t num_bits);
+
+  /// \brief Indicate that an error has occurred, to be checked by an exec caller
+  /// \param[in] status a Status instance
+  ///
+  /// \note Will not overwrite a prior set Status, so we will have the first
+  /// error that occurred until KernelContext::ResetStatus is called
+  void SetStatus(const Status& status);
+
+  /// \brief Clear any error status
+  void ResetStatus();
+
+  /// \brief Return true if an error has occurred
+  bool HasError() const { return !status_.ok(); }
+
+  /// \brief Return the current status of the context
+  const Status& status() const { return status_; }
+
+  // For passing kernel state to the kernel's exec function
+  void SetState(KernelState* state) { state_ = state; }
+
+  KernelState* state() { return state_; }
+
+  /// \brief Common state related to function execution
+  ExecContext* exec_context() { return exec_ctx_; }
+
+  MemoryPool* memory_pool() { return exec_ctx_->memory_pool(); }
+
+ private:
+  ExecContext* exec_ctx_;
+  Status status_;
+  KernelState* state_;
 };
 
-struct Datum;
-static inline bool CollectionEquals(const std::vector& left,
-                                    const std::vector& right);
-
-// Datums variants may have a length. This special value indicate that the
-// current variant does not have a length.
-constexpr int64_t kUnknownLength = -1;
-
-/// \class Datum
-/// \brief Variant type for various Arrow C++ data structures
-struct ARROW_EXPORT Datum {
-  enum type { NONE, SCALAR, ARRAY, CHUNKED_ARRAY, RECORD_BATCH, TABLE, COLLECTION };
-
-  util::variant, std::shared_ptr,
-                 std::shared_ptr, std::shared_ptr,
-                 std::shared_ptr
, std::vector> - value; - - /// \brief Empty datum, to be populated elsewhere - Datum() : value(NULLPTR) {} - - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : Datum(value ? value->data() : NULLPTR) {} - - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::shared_ptr
& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::vector& value) // NOLINT implicit conversion - : value(value) {} - - // Cast from subtypes of Array to Datum - template ::value>> - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : Datum(std::shared_ptr(value)) {} - - // Convenience constructors - explicit Datum(bool value) : value(std::make_shared(value)) {} - explicit Datum(int8_t value) : value(std::make_shared(value)) {} - explicit Datum(uint8_t value) : value(std::make_shared(value)) {} - explicit Datum(int16_t value) : value(std::make_shared(value)) {} - explicit Datum(uint16_t value) : value(std::make_shared(value)) {} - explicit Datum(int32_t value) : value(std::make_shared(value)) {} - explicit Datum(uint32_t value) : value(std::make_shared(value)) {} - explicit Datum(int64_t value) : value(std::make_shared(value)) {} - explicit Datum(uint64_t value) : value(std::make_shared(value)) {} - explicit Datum(float value) : value(std::make_shared(value)) {} - explicit Datum(double value) : value(std::make_shared(value)) {} - - ~Datum() {} - - Datum(const Datum& other) noexcept { this->value = other.value; } - - Datum& operator=(const Datum& other) noexcept { - value = other.value; - return *this; - } +#define ARROW_CTX_RETURN_IF_ERROR(CTX) \ + do { \ + if (ARROW_PREDICT_FALSE((CTX)->HasError())) { \ + Status s = (CTX)->status(); \ + (CTX)->ResetStatus(); \ + return s; \ + } \ + } while (0) + +/// A standard function taking zero or more Array/Scalar values and returning +/// Array/Scalar output. May be used for SCALAR and VECTOR kernel kinds. Should +/// write into pre-allocated memory except in cases when a builder +/// (e.g. StringBuilder) must be employed +using ArrayKernelExec = std::function; + +/// \brief An abstract type-checking interface to permit customizable +/// validation rules. This is for scenarios where the acceptance is not an +/// exact type instance along with its unit. +struct TypeMatcher { + virtual ~TypeMatcher() = default; + + /// \brief Return true if this matcher accepts the data type + virtual bool Matches(const DataType& type) const = 0; + + /// \brief A human-interpretable string representation of what the type + /// matcher checks for, usable when printing KernelSignature or formatting + /// error messages. 
+  virtual std::string ToString() const = 0;
+
+  virtual bool Equals(const TypeMatcher& other) const = 0;
+};
-  // Define move constructor and move assignment, for better performance
-  Datum(Datum&& other) noexcept : value(std::move(other.value)) {}
+namespace match {
-  Datum& operator=(Datum&& other) noexcept {
-    value = std::move(other.value);
-    return *this;
-  }
+/// \brief Match any DataType instance having the same DataType::id
+ARROW_EXPORT std::shared_ptr SameTypeId(Type::type type_id);
-  Datum::type kind() const {
-    switch (this->value.index()) {
-      case 0:
-        return Datum::NONE;
-      case 1:
-        return Datum::SCALAR;
-      case 2:
-        return Datum::ARRAY;
-      case 3:
-        return Datum::CHUNKED_ARRAY;
-      case 4:
-        return Datum::RECORD_BATCH;
-      case 5:
-        return Datum::TABLE;
-      case 6:
-        return Datum::COLLECTION;
-      default:
-        return Datum::NONE;
-    }
-  }
+/// \brief Match any TimestampType instance having the same unit, but the time
+/// zones can be different
+ARROW_EXPORT std::shared_ptr TimestampUnit(TimeUnit::type unit);
-  std::shared_ptr array() const {
-    return util::get>(this->value);
-  }
+}  // namespace match
-  std::shared_ptr make_array() const {
-    return MakeArray(util::get>(this->value));
-  }
+/// \brief A container to express what kernel argument input types are accepted
+class ARROW_EXPORT InputType {
+ public:
+  enum Kind {
+    /// Accept any value type
+    ANY_TYPE,
-  std::shared_ptr chunked_array() const {
-    return util::get>(this->value);
-  }
+    /// A fixed arrow::DataType and will only exact match having this exact
+    /// type (e.g. same TimestampType unit, same decimal scale and precision,
+    /// or same nested child types)
+    EXACT_TYPE,
-  std::shared_ptr record_batch() const {
-    return util::get>(this->value);
-  }
+    /// Uses a TypeMatcher implementation to check the type
+    USE_TYPE_MATCHER
+  };
-  std::shared_ptr
table() const { - return util::get>(this->value); - } + InputType(ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction + : kind_(ANY_TYPE), shape_(shape) {} + + InputType(std::shared_ptr type, + ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction + : kind_(EXACT_TYPE), shape_(shape), type_(std::move(type)) {} + + InputType(const ValueDescr& descr) // NOLINT implicit construction + : InputType(descr.type, descr.shape) {} - const std::vector collection() const { - return util::get>(this->value); + InputType(std::shared_ptr type_matcher, + ValueDescr::Shape shape = ValueDescr::ANY) + : kind_(USE_TYPE_MATCHER), shape_(shape), type_matcher_(std::move(type_matcher)) {} + + explicit InputType(Type::type type_id, ValueDescr::Shape shape = ValueDescr::ANY) + : InputType(match::SameTypeId(type_id), shape) {} + + InputType(const InputType& other) { CopyInto(other); } + + // Convenience ctors + static InputType Array(std::shared_ptr type) { + return InputType(std::move(type), ValueDescr::ARRAY); } - std::shared_ptr scalar() const { - return util::get>(this->value); + static InputType Scalar(std::shared_ptr type) { + return InputType(std::move(type), ValueDescr::SCALAR); } - bool is_array() const { return this->kind() == Datum::ARRAY; } + static InputType Array(Type::type id) { return InputType(id, ValueDescr::ARRAY); } - bool is_arraylike() const { - return this->kind() == Datum::ARRAY || this->kind() == Datum::CHUNKED_ARRAY; - } + static InputType Scalar(Type::type id) { return InputType(id, ValueDescr::SCALAR); } - bool is_scalar() const { return this->kind() == Datum::SCALAR; } + void operator=(const InputType& other) { CopyInto(other); } - bool is_collection() const { return this->kind() == Datum::COLLECTION; } + InputType(InputType&& other) { MoveInto(std::forward(other)); } - /// \brief The value type of the variant, if any - /// - /// \return nullptr if no type - std::shared_ptr type() const { - if (this->kind() == Datum::ARRAY) { - return util::get>(this->value)->type; - } else if (this->kind() == Datum::CHUNKED_ARRAY) { - return util::get>(this->value)->type(); - } else if (this->kind() == Datum::SCALAR) { - return util::get>(this->value)->type; - } - return NULLPTR; + void operator=(InputType&& other) { MoveInto(std::forward(other)); } + + /// \brief Return true if this type exactly matches another + bool Equals(const InputType& other) const; + + bool operator==(const InputType& other) const { return this->Equals(other); } + + bool operator!=(const InputType& other) const { return !(*this == other); } + + /// \brief Return hash code + size_t Hash() const; + + /// \brief Render a human-readable string representation + std::string ToString() const; + + /// \brief Return true if the value matches this argument kind in type + /// and shape + bool Matches(const Datum& value) const; + + /// \brief Return true if the value descriptor matches this argument kind in + /// type and shape + bool Matches(const ValueDescr& value) const; + + /// \brief The type matching rule that this InputType uses + Kind kind() const { return kind_; } + + ValueDescr::Shape shape() const { return shape_; } + + /// \brief For InputType::EXACT_TYPE, the exact type that this InputType must + /// match. 
Otherwise this function should not be used
+  const std::shared_ptr& type() const;
+
+  /// \brief For InputType::USE_TYPE_MATCHER, the TypeMatcher that this
+  /// InputType must match. Otherwise this function should not be used
+  const TypeMatcher& type_matcher() const;
+
+ private:
+  void CopyInto(const InputType& other) {
+    this->kind_ = other.kind_;
+    this->shape_ = other.shape_;
+    this->type_ = other.type_;
+    this->type_matcher_ = other.type_matcher_;
+  }
-  /// \brief The value length of the variant, if any
-  ///
-  /// \return kUnknownLength if no type
-  int64_t length() const {
-    if (this->kind() == Datum::ARRAY) {
-      return util::get>(this->value)->length;
-    } else if (this->kind() == Datum::CHUNKED_ARRAY) {
-      return util::get>(this->value)->length();
-    } else if (this->kind() == Datum::SCALAR) {
-      return 1;
-    }
-    return kUnknownLength;
-  }
+  void MoveInto(InputType&& other) {
+    this->kind_ = other.kind_;
+    this->shape_ = other.shape_;
+    this->type_ = std::move(other.type_);
+    this->type_matcher_ = std::move(other.type_matcher_);
+  }
-  /// \brief The array chunks of the variant, if any
-  ///
-  /// \return empty if not arraylike
-  ArrayVector chunks() const {
-    if (!this->is_arraylike()) {
-      return {};
-    }
-    if (this->is_array()) {
-      return {this->make_array()};
-    }
-    return this->chunked_array()->chunks();
-  }
+  Kind kind_;
+
+  ValueDescr::Shape shape_ = ValueDescr::ANY;
+
+  // For EXACT_TYPE Kind
+  std::shared_ptr type_;
+
+  // For USE_TYPE_MATCHER Kind
+  std::shared_ptr type_matcher_;
+};
+
+/// \brief Container to capture both exact and input-dependent output types
+///
+/// The value shape returned by Resolve will be determined by broadcasting the
+/// shapes of the input arguments; otherwise this is handled by the
+/// user-defined resolver function
+///
+/// * Any ARRAY shape -> output shape is ARRAY
+/// * All SCALAR shapes -> output shape is SCALAR
+class ARROW_EXPORT OutputType {
+ public:
+  /// \brief An enum indicating whether the value type is an invariant fixed
+  /// value or one that's computed by a kernel-defined resolver function
+  enum ResolveKind { FIXED, COMPUTED };
+
+  /// Type resolution function. Given input types and shapes, return output
+  /// type and shape. This function SHOULD _not_ be used to check for arity;
+  /// that SHOULD be performed one or more layers above. May make use of kernel
May make use of kernel + /// state to know what type to output + using Resolver = + std::function(KernelContext*, const std::vector&)>; + + OutputType(std::shared_ptr type) // NOLINT implicit construction + : kind_(FIXED), type_(std::move(type)) {} + + /// For outputting a particular type and shape + OutputType(ValueDescr descr); // NOLINT implicit construction + + explicit OutputType(Resolver resolver) : kind_(COMPUTED), resolver_(resolver) {} + + OutputType(const OutputType& other) { + this->kind_ = other.kind_; + this->shape_ = other.shape_; + this->type_ = other.type_; + this->resolver_ = other.resolver_; } - bool Equals(const Datum& other) const { - if (this->kind() != other.kind()) return false; - - switch (this->kind()) { - case Datum::NONE: - return true; - case Datum::SCALAR: - return internal::SharedPtrEquals(this->scalar(), other.scalar()); - case Datum::ARRAY: - return internal::SharedPtrEquals(this->make_array(), other.make_array()); - case Datum::CHUNKED_ARRAY: - return internal::SharedPtrEquals(this->chunked_array(), other.chunked_array()); - case Datum::RECORD_BATCH: - return internal::SharedPtrEquals(this->record_batch(), other.record_batch()); - case Datum::TABLE: - return internal::SharedPtrEquals(this->table(), other.table()); - case Datum::COLLECTION: - return CollectionEquals(this->collection(), other.collection()); - default: - return false; - } + OutputType(OutputType&& other) { + this->kind_ = other.kind_; + this->type_ = std::move(other.type_); + this->shape_ = other.shape_; + this->resolver_ = other.resolver_; } + + /// \brief Return the shape and type of the expected output value of the + /// kernel given the value descriptors (shapes and types). The resolver may + /// make use of state information kept in the KernelContext + Result Resolve(KernelContext* ctx, + const std::vector& args) const; + + /// \brief The value type for the FIXED kind rule + const std::shared_ptr& type() const; + + /// \brief For use with COMPUTED resolution strategy, the output type depends + /// on the input type. It may be more convenient to invoke this with + /// OutputType::Resolve returned from this method + const Resolver& resolver() const; + + /// \brief Render a human-readable string representation + std::string ToString() const; + + /// \brief Return the kind of type resolution of this output type, whether + /// fixed/invariant or computed by a "user"-defined resolver + ResolveKind kind() const { return kind_; } + + /// \brief If the shape is ANY, then Resolve will compute the shape based on + /// the input arguments + ValueDescr::Shape shape() const { return shape_; } + + private: + ResolveKind kind_; + + // For FIXED resolution + std::shared_ptr type_; + + ValueDescr::Shape shape_ = ValueDescr::ANY; + + // For COMPUTED resolution + Resolver resolver_; }; -/// \class UnaryKernel -/// \brief An array-valued function of a single input argument. +/// \brief Holds the input types and output type of the kernel /// -/// Note to implementors: Try to avoid making kernels that allocate memory if -/// the output size is a deterministic function of the Input Datum's metadata. -/// Instead separate the logic of the kernel and allocations necessary into -/// two different kernels. Some reusable kernels that allocate buffers -/// and delegate computation to another kernel are available in util-internal.h. 
-class ARROW_EXPORT UnaryKernel : public OpKernel {
+/// Varargs functions should pass a single input type to be used to validate
+/// the input types of a function invocation
+class ARROW_EXPORT KernelSignature {
 public:
-  /// \brief Executes the kernel.
-  ///
-  /// \param[in] ctx The function context for the kernel
-  /// \param[in] input The kernel input data
-  /// \param[out] out The output of the function. Each implementation of this
-  /// function might assume different things about the existing contents of out
-  /// (e.g. which buffers are preallocated). In the future it is expected that
-  /// there will be a more generic mechanism for understanding the necessary
-  /// contracts.
-  virtual Status Call(FunctionContext* ctx, const Datum& input, Datum* out) = 0;
+  KernelSignature(std::vector in_types, OutputType out_type,
+                  bool is_varargs = false);
+
+  /// \brief Convenience ctor since make_shared can be awkward
+  static std::shared_ptr Make(std::vector in_types,
+                              OutputType out_type,
+                              bool is_varargs = false);
+
+  /// \brief Return true if the signature is compatible with the list of input
+  /// value descriptors
+  bool MatchesInputs(const std::vector& descriptors) const;
+
+  /// \brief Returns true if the input types of each signature are
+  /// equal. Well-formed functions should have a deterministic output type
+  /// given input types, but currently it is the responsibility of the
+  /// developer to ensure this
+  bool Equals(const KernelSignature& other) const;
+
+  bool operator==(const KernelSignature& other) const { return this->Equals(other); }
+
+  bool operator!=(const KernelSignature& other) const { return !(*this == other); }
+
+  /// \brief Compute a hash code for the signature
+  size_t Hash() const;
+
+  const std::vector& in_types() const { return in_types_; }
+
+  const OutputType& out_type() const { return out_type_; }
+
+  /// \brief Render a human-readable string representation
+  std::string ToString() const;
+
+  bool is_varargs() const { return is_varargs_; }
+
+ private:
+  std::vector in_types_;
+  OutputType out_type_;
+  bool is_varargs_;
+
+  // For caching the hash code after it's computed the first time
+  mutable uint64_t hash_code_;
 };
 
-/// \class BinaryKernel
-/// \brief An array-valued function of a two input arguments
-class ARROW_EXPORT BinaryKernel : public OpKernel {
- public:
-  virtual Status Call(FunctionContext* ctx, const Datum& left, const Datum& right,
-                      Datum* out) = 0;
+/// \brief A function may contain multiple variants of a kernel for a given
+/// type combination for different SIMD levels. Based on the active system's
+/// CPU info or the user's preferences, we can elect to use one over the other.
+struct SimdLevel {
+  enum type { NONE, SSE4_2, AVX, AVX2, AVX512, NEON };
+};
-// TODO doxygen 1.8.16 does not like the following code
-///@cond INTERNAL
+struct NullHandling {
+  enum type {
+    /// Compute the output validity bitmap by intersecting the validity bitmaps
+    /// of the arguments.
Kernel does not do anything with the bitmap
+    INTERSECTION,
-static inline bool CollectionEquals(const std::vector& left,
-                                    const std::vector& right) {
-  if (left.size() != right.size()) {
-    return false;
-  }
+    /// Kernel expects a pre-allocated buffer to write the result bitmap into
+    COMPUTED_PREALLOCATE,
-  for (size_t i = 0; i < left.size(); i++) {
-    if (!left[i].Equals(right[i])) {
-      return false;
-    }
-  }
-  return true;
-}
+    /// Kernel allocates and populates the validity bitmap of the output
+    COMPUTED_NO_PREALLOCATE,
+
+    /// Output is never null
+    OUTPUT_NOT_NULL
+  };
+};
+
+struct MemAllocation {
+  enum type {
+    // For data types that support pre-allocation (fixed-width), the kernel
+    // expects to be provided pre-allocated memory to write
+    // into. Non-fixed-width types must always allocate their own memory but
+    // perhaps not their validity bitmaps. The allocation is made for the same
+    // length as the execution batch, so vector kernels yielding differently
+    // sized output should not use this
+    PREALLOCATE,
+
+    // The kernel does its own memory allocation
+    NO_PREALLOCATE
+  };
+};
+
+struct Kernel;
+
+struct KernelInitArgs {
+  const Kernel* kernel;
+  const std::vector& inputs;
+  const FunctionOptions* options;
+};
+
+// Kernel initializer (context, argument descriptors, options)
+using KernelInit =
+    std::function(KernelContext*, const KernelInitArgs&)>;
+
+/// \brief Base type for kernels. Contains the function signature and
+/// optionally the state initialization function, along with some common
+/// attributes
+struct Kernel {
+  Kernel() {}
-///@endcond
+  Kernel(std::shared_ptr sig, KernelInit init)
+      : signature(std::move(sig)), init(init) {}
+
+  Kernel(std::vector in_types, OutputType out_type, KernelInit init)
+      : Kernel(KernelSignature::Make(std::move(in_types), out_type), init) {}
+
+  std::shared_ptr signature;
+
+  /// \brief Create a new KernelState for invocations of this kernel, e.g. to
+  /// set up any options or state relevant for execution. May be nullptr
+  KernelInit init;
+
+  // Does execution benefit from parallelization (splitting large chunks into
+  // smaller chunks and using multiple threads). Some vector kernels may
+  // require single-threaded execution.
+  bool parallelizable = true;
+
+  /// \brief What level of SIMD instruction support in the host CPU is required
+  /// to use the function
+  SimdLevel::type simd_level = SimdLevel::NONE;
+};
+
+/// \brief Descriptor to hold signature and execution function implementations
+/// for a particular kernel
+struct ArrayKernel : public Kernel {
+  ArrayKernel() {}
+
+  ArrayKernel(std::shared_ptr sig, ArrayKernelExec exec,
+              KernelInit init = NULLPTR)
+      : Kernel(std::move(sig), init), exec(exec) {}
+
+  ArrayKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec,
+              KernelInit init = NULLPTR)
+      : Kernel(std::move(in_types), std::move(out_type), init), exec(exec) {}
+
+  /// \brief Perform a single invocation of this kernel. Depending on the
+  /// implementation, it may only write into preallocated memory, while in some
+  /// cases it will allocate its own memory.
+  ArrayKernelExec exec;
+
+  /// \brief Writing execution results into larger contiguous allocations
+  /// requires that the kernel be able to write into sliced output
+  ///
Some kernel implementations may not be able to do this, so + /// setting this to false disables this functionality + bool can_write_into_slices = true; +}; + +struct ScalarKernel : public ArrayKernel { + using ArrayKernel::ArrayKernel; + + // For scalar functions preallocated data and intersecting arg validity + // bitmaps is a reasonable default + NullHandling::type null_handling = NullHandling::INTERSECTION; + MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE; +}; + +// Convert intermediate results into finalized results. Mutates input argument +using VectorFinalize = std::function*)>; + +struct VectorKernel : public ArrayKernel { + VectorKernel() {} + + VectorKernel(std::shared_ptr sig, ArrayKernelExec exec) + : ArrayKernel(std::move(sig), exec) {} + + VectorKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, + KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR) + : ArrayKernel(std::move(in_types), out_type, exec, init), finalize(finalize) {} + + VectorKernel(std::shared_ptr sig, ArrayKernelExec exec, + KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR) + : ArrayKernel(std::move(sig), exec, init), finalize(finalize) {} + + VectorFinalize finalize; + + /// Since vector kernels generally are implemented rather differently from + /// scalar/elementwise kernels (and they may not even yield arrays of the same + /// size), so we make the developer opt-in to any memory preallocation rather + /// than having to turn it off. + NullHandling::type null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + MemAllocation::type mem_allocation = MemAllocation::NO_PREALLOCATE; + + /// Some vector kernels can do chunkwise execution using ExecBatchIterator, + /// in some cases accumulating some state. Other kernels (like Take) need to + /// be passed whole arrays and don't work on ChunkedArray inputs + bool can_execute_chunkwise = true; + + /// Some kernels (like unique and value_counts) yield non-chunked output from + /// chunked-array inputs. This option controls how the results are boxed when + /// returned from ExecVectorFunction + /// + /// true -> ChunkedArray + /// false -> Array + /// + /// TODO: Where is a better place to deal with this issue? 
+
+// Convert intermediate results into finalized results. Mutates input argument
+using VectorFinalize = std::function<void(KernelContext*, std::vector<Datum>*)>;
+
+struct VectorKernel : public ArrayKernel {
+  VectorKernel() {}
+
+  VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec)
+      : ArrayKernel(std::move(sig), exec) {}
+
+  VectorKernel(std::vector<InputType> in_types, OutputType out_type,
+               ArrayKernelExec exec, KernelInit init = NULLPTR,
+               VectorFinalize finalize = NULLPTR)
+      : ArrayKernel(std::move(in_types), out_type, exec, init),
+        finalize(finalize) {}
+
+  VectorKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec,
+               KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR)
+      : ArrayKernel(std::move(sig), exec, init), finalize(finalize) {}
+
+  VectorFinalize finalize;
+
+  /// Since vector kernels are generally implemented rather differently from
+  /// scalar/elementwise kernels (and they may not even yield arrays of the
+  /// same size), we make the developer opt in to any memory preallocation
+  /// rather than having to turn it off.
+  NullHandling::type null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+  MemAllocation::type mem_allocation = MemAllocation::NO_PREALLOCATE;
+
+  /// Some vector kernels can do chunkwise execution using ExecBatchIterator,
+  /// in some cases accumulating some state. Other kernels (like Take) need to
+  /// be passed whole arrays and don't work on ChunkedArray inputs
+  bool can_execute_chunkwise = true;
+
+  /// Some kernels (like unique and value_counts) yield non-chunked output
+  /// from chunked-array inputs. This option controls how the results are
+  /// boxed when returned from ExecVectorFunction
+  ///
+  /// true -> ChunkedArray
+  /// false -> Array
+  ///
+  /// TODO: Where is a better place to deal with this issue?
+  bool output_chunked = true;
+};
+
+using ScalarAggregateConsume = std::function<void(KernelContext*, const ExecBatch&)>;
+
+using ScalarAggregateMerge =
+    std::function<void(KernelContext*, const KernelState&, KernelState*)>;
+
+// Finalize returns Datum to permit multiple return values
+using ScalarAggregateFinalize = std::function<void(KernelContext*, Datum*)>;
+
+struct ScalarAggregateKernel : public Kernel {
+  ScalarAggregateKernel() {}
+
+  ScalarAggregateKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
+                        ScalarAggregateConsume consume, ScalarAggregateMerge merge,
+                        ScalarAggregateFinalize finalize)
+      : Kernel(std::move(sig), init),
+        consume(consume),
+        merge(merge),
+        finalize(finalize) {}
+
+  ScalarAggregateKernel(std::vector<InputType> in_types, OutputType out_type,
+                        KernelInit init, ScalarAggregateConsume consume,
+                        ScalarAggregateMerge merge, ScalarAggregateFinalize finalize)
+      : ScalarAggregateKernel(KernelSignature::Make(std::move(in_types), out_type),
+                              init, consume, merge, finalize) {}
+
+  ScalarAggregateConsume consume;
+  ScalarAggregateMerge merge;
+  ScalarAggregateFinalize finalize;
+};
 
 }  // namespace compute
 }  // namespace arrow
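The consume/merge/finalize split is what lets aggregates run over chunked inputs and combine per-thread partial states. A trivial sketch written against this interface; all My* names are hypothetical, and aggregate_basic.cc later in this patch shows the real pattern:

// Sketch: count the total number of input rows.
struct MyCountState : public KernelState {
  int64_t count = 0;
};

std::unique_ptr<KernelState> MyCountInit(KernelContext*, const KernelInitArgs&) {
  return std::unique_ptr<KernelState>(new MyCountState());
}

void MyCountConsume(KernelContext* ctx, const ExecBatch& batch) {
  static_cast<MyCountState*>(ctx->state())->count += batch.length;
}

void MyCountMerge(KernelContext*, const KernelState& src, KernelState* dst) {
  static_cast<MyCountState*>(dst)->count +=
      static_cast<const MyCountState&>(src).count;
}

void MyCountFinalize(KernelContext* ctx, Datum* out) {
  *out = Datum(static_cast<MyCountState*>(ctx->state())->count);
}

// Assembled into a kernel with a one-argument signature:
//   ScalarAggregateKernel kernel({InputType(ValueDescr::ARRAY)}, int64(),
//                                MyCountInit, MyCountConsume, MyCountMerge,
//                                MyCountFinalize);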
diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc
new file mode 100644
index 00000000000..bd5571b2fb5
--- /dev/null
+++ b/cpp/src/arrow/compute/kernel_test.cc
@@ -0,0 +1,479 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/compute/kernel.h"
+#include "arrow/status.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+namespace compute {
+
+// ----------------------------------------------------------------------
+// TypeMatcher
+
+TEST(TypeMatcher, SameTypeId) {
+  std::shared_ptr<TypeMatcher> matcher = match::SameTypeId(Type::DECIMAL);
+  ASSERT_TRUE(matcher->Matches(*decimal(12, 2)));
+  ASSERT_FALSE(matcher->Matches(*int8()));
+
+  ASSERT_EQ("Type::DECIMAL", matcher->ToString());
+
+  ASSERT_TRUE(matcher->Equals(*matcher));
+  ASSERT_TRUE(matcher->Equals(*match::SameTypeId(Type::DECIMAL)));
+  ASSERT_FALSE(matcher->Equals(*match::SameTypeId(Type::TIMESTAMP)));
+}
+
+TEST(TypeMatcher, TimestampUnit) {
+  std::shared_ptr<TypeMatcher> matcher = match::TimestampUnit(TimeUnit::MILLI);
+
+  ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI)));
+  ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI, "utc")));
+  ASSERT_FALSE(matcher->Matches(*timestamp(TimeUnit::SECOND)));
+
+  // Check ToString representation
+  ASSERT_EQ("timestamp(s)", match::TimestampUnit(TimeUnit::SECOND)->ToString());
+  ASSERT_EQ("timestamp(ms)", match::TimestampUnit(TimeUnit::MILLI)->ToString());
+  ASSERT_EQ("timestamp(us)", match::TimestampUnit(TimeUnit::MICRO)->ToString());
+  ASSERT_EQ("timestamp(ns)", match::TimestampUnit(TimeUnit::NANO)->ToString());
+
+  // Equals implementation
+  ASSERT_TRUE(matcher->Equals(*matcher));
+  ASSERT_TRUE(matcher->Equals(*match::TimestampUnit(TimeUnit::MILLI)));
+  ASSERT_FALSE(matcher->Equals(*match::TimestampUnit(TimeUnit::MICRO)));
+}
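These tests cover the built-in matchers; a user-defined matcher only needs to implement Matches/Equals/ToString. A hypothetical sketch (not part of this patch):

// Sketch: match any list type whose value type is an integer.
class ListOfIntegerMatcher : public TypeMatcher {
 public:
  bool Matches(const DataType& type) const override {
    if (type.id() != Type::LIST) return false;
    const auto& list_type = static_cast<const ListType&>(type);
    return is_integer(list_type.value_type()->id());
  }
  bool Equals(const TypeMatcher& other) const override {
    // Stateless matcher: any other instance of this class is equivalent
    return dynamic_cast<const ListOfIntegerMatcher*>(&other) != nullptr;
  }
  std::string ToString() const override { return "list<integer>"; }
};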
+
+// ----------------------------------------------------------------------
+// InputType
+
+TEST(InputType, AnyTypeConstructor) {
+  // Check the ANY_TYPE ctors
+  InputType ty;
+  ASSERT_EQ(InputType::ANY_TYPE, ty.kind());
+  ASSERT_EQ(ValueDescr::ANY, ty.shape());
+
+  ty = InputType(ValueDescr::SCALAR);
+  ASSERT_EQ(ValueDescr::SCALAR, ty.shape());
+
+  ty = InputType(ValueDescr::ARRAY);
+  ASSERT_EQ(ValueDescr::ARRAY, ty.shape());
+}
+
+TEST(InputType, Constructors) {
+  // Exact type constructor
+  InputType ty1(int8());
+  ASSERT_EQ(InputType::EXACT_TYPE, ty1.kind());
+  ASSERT_EQ(ValueDescr::ANY, ty1.shape());
+  AssertTypeEqual(*int8(), *ty1.type());
+
+  InputType ty1_implicit = int8();
+  ASSERT_TRUE(ty1.Equals(ty1_implicit));
+
+  InputType ty1_array(int8(), ValueDescr::ARRAY);
+  ASSERT_EQ(ValueDescr::ARRAY, ty1_array.shape());
+
+  InputType ty1_scalar(int8(), ValueDescr::SCALAR);
+  ASSERT_EQ(ValueDescr::SCALAR, ty1_scalar.shape());
+
+  // Same type id constructor
+  InputType ty2(Type::DECIMAL);
+  ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind());
+  ASSERT_EQ("any[Type::DECIMAL]", ty2.ToString());
+  ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2)));
+  ASSERT_FALSE(ty2.type_matcher().Matches(*int16()));
+
+  InputType ty2_array(Type::DECIMAL, ValueDescr::ARRAY);
+  ASSERT_EQ(ValueDescr::ARRAY, ty2_array.shape());
+
+  InputType ty2_scalar(Type::DECIMAL, ValueDescr::SCALAR);
+  ASSERT_EQ(ValueDescr::SCALAR, ty2_scalar.shape());
+
+  // Implicit construction in a vector
+  std::vector<InputType> types = {int8(), InputType(Type::DECIMAL)};
+  ASSERT_TRUE(types[0].Equals(ty1));
+  ASSERT_TRUE(types[1].Equals(ty2));
+
+  // Copy constructor
+  InputType ty3 = ty1;
+  InputType ty4 = ty2;
+  ASSERT_TRUE(ty3.Equals(ty1));
+  ASSERT_TRUE(ty4.Equals(ty2));
+
+  // Move constructor
+  InputType ty5 = std::move(ty3);
+  InputType ty6 = std::move(ty4);
+  ASSERT_TRUE(ty5.Equals(ty1));
+  ASSERT_TRUE(ty6.Equals(ty2));
+
+  // ToString
+  ASSERT_EQ("any[int8]", ty1.ToString());
+  ASSERT_EQ("array[int8]", ty1_array.ToString());
+  ASSERT_EQ("scalar[int8]", ty1_scalar.ToString());
+
+  ASSERT_EQ("any[Type::DECIMAL]", ty2.ToString());
+  ASSERT_EQ("array[Type::DECIMAL]", ty2_array.ToString());
+  ASSERT_EQ("scalar[Type::DECIMAL]", ty2_scalar.ToString());
+
+  InputType ty7(match::TimestampUnit(TimeUnit::MICRO));
+  ASSERT_EQ("any[timestamp(us)]", ty7.ToString());
+}
+
+TEST(InputType, Equals) {
+  InputType t1 = int8();
+  InputType t2 = int8();
+  InputType t3(int8(), ValueDescr::ARRAY);
+  InputType t3_i32(int32(), ValueDescr::ARRAY);
+  InputType t3_scalar(int8(), ValueDescr::SCALAR);
+  InputType t4(int8(), ValueDescr::ARRAY);
+  InputType t4_i32(int32(), ValueDescr::ARRAY);
+
+  InputType t5(Type::DECIMAL);
+  InputType t6(Type::DECIMAL);
+  InputType t7(Type::DECIMAL, ValueDescr::SCALAR);
+  InputType t7_i32(Type::INT32, ValueDescr::SCALAR);
+  InputType t8(Type::DECIMAL, ValueDescr::SCALAR);
+  InputType t8_i32(Type::INT32, ValueDescr::SCALAR);
+
+  ASSERT_TRUE(t1.Equals(t2));
+  ASSERT_EQ(t1, t2);
+
+  // ANY vs ARRAY
+  ASSERT_NE(t1, t3);
+
+  ASSERT_EQ(t3, t4);
+
+  // both ARRAY, but different type
+  ASSERT_NE(t3, t3_i32);
+
+  // ARRAY vs SCALAR
+  ASSERT_NE(t3, t3_scalar);
+
+  ASSERT_EQ(t3_i32, t4_i32);
+
+  ASSERT_FALSE(t1.Equals(t5));
+  ASSERT_NE(t1, t5);
+
+  ASSERT_EQ(t5, t5);
+  ASSERT_EQ(t5, t6);
+  ASSERT_NE(t5, t7);
+  ASSERT_EQ(t7, t8);
+  ASSERT_NE(t7, t7_i32);
+  ASSERT_EQ(t7_i32, t8_i32);
+
+  // NOTE: For the time being, we treat int32() and Type::INT32 as being
+  // different. This could obviously be fixed later to make these equivalent
+  ASSERT_NE(InputType(int8()), InputType(Type::INT32));
+
+  // Check that field metadata is excluded from equality checks
+  InputType t9 = list(
+      field("item", utf8(), /*nullable=*/true, key_value_metadata({"foo"}, {"bar"})));
+  InputType t10 = list(field("item", utf8()));
+  ASSERT_TRUE(t9.Equals(t10));
+}
+
+TEST(InputType, Hash) {
+  InputType t0;
+  InputType t0_scalar(ValueDescr::SCALAR);
+  InputType t0_array(ValueDescr::ARRAY);
+
+  InputType t1 = int8();
+  InputType t2(Type::DECIMAL);
+
+  // These checks try to determine, first, whether Hash always returns the
+  // same value, and second, whether all elements of the type are
+  // incorporated into the Hash
+  ASSERT_EQ(t0.Hash(), t0.Hash());
+  ASSERT_NE(t0.Hash(), t0_scalar.Hash());
+  ASSERT_NE(t0.Hash(), t0_array.Hash());
+  ASSERT_NE(t0_scalar.Hash(), t0_array.Hash());
+
+  ASSERT_EQ(t1.Hash(), t1.Hash());
+  ASSERT_EQ(t2.Hash(), t2.Hash());
+
+  ASSERT_NE(t0.Hash(), t1.Hash());
+  ASSERT_NE(t0.Hash(), t2.Hash());
+  ASSERT_NE(t1.Hash(), t2.Hash());
+}
+
+TEST(InputType, Matches) {
+  InputType ty1 = int8();
+
+  ASSERT_TRUE(ty1.Matches(ValueDescr::Scalar(int8())));
+  ASSERT_TRUE(ty1.Matches(ValueDescr::Array(int8())));
+  ASSERT_TRUE(ty1.Matches(ValueDescr::Any(int8())));
+  ASSERT_FALSE(ty1.Matches(ValueDescr::Any(int16())));
+
+  InputType ty2(Type::DECIMAL);
+  ASSERT_TRUE(ty2.Matches(ValueDescr::Scalar(decimal(12, 2))));
+  ASSERT_TRUE(ty2.Matches(ValueDescr::Array(decimal(12, 2))));
+  ASSERT_FALSE(ty2.Matches(ValueDescr::Any(float64())));
+
+  InputType ty3(int64(), ValueDescr::SCALAR);
+  ASSERT_FALSE(ty3.Matches(ValueDescr::Array(int64())));
+  ASSERT_TRUE(ty3.Matches(ValueDescr::Scalar(int64())));
+  ASSERT_FALSE(ty3.Matches(ValueDescr::Scalar(int32())));
+  ASSERT_FALSE(ty3.Matches(ValueDescr::Any(int64())));
+}
+
+// ----------------------------------------------------------------------
+// OutputType
+
+TEST(OutputType, Constructors) {
+  OutputType ty1 = int8();
+  ASSERT_EQ(OutputType::FIXED, ty1.kind());
+  AssertTypeEqual(*int8(), *ty1.type());
+
+  auto DummyResolver =
+      [](KernelContext*,
+         const std::vector<ValueDescr>& args) -> Result<ValueDescr> {
+    return ValueDescr(int32(), GetBroadcastShape(args));
+  };
+  OutputType ty2(DummyResolver);
+  ASSERT_EQ(OutputType::COMPUTED, ty2.kind());
+
+  ASSERT_OK_AND_ASSIGN(ValueDescr out_descr2, ty2.Resolve(nullptr, {}));
+  ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr2);
+
+  // Copy constructor
+  OutputType ty3 = ty1;
+  ASSERT_EQ(OutputType::FIXED, ty3.kind());
+  AssertTypeEqual(*ty1.type(), *ty3.type());
+
+  OutputType ty4 = ty2;
+  ASSERT_EQ(OutputType::COMPUTED, ty4.kind());
+  ASSERT_OK_AND_ASSIGN(ValueDescr out_descr4, ty4.Resolve(nullptr, {}));
+  ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr4);
+
+  // Move constructor
+  OutputType ty5 = std::move(ty1);
+  ASSERT_EQ(OutputType::FIXED, ty5.kind());
+  AssertTypeEqual(*int8(), *ty5.type());
+
+  OutputType ty6 = std::move(ty4);
+  ASSERT_EQ(OutputType::COMPUTED, ty6.kind());
+  ASSERT_OK_AND_ASSIGN(ValueDescr out_descr6, ty6.Resolve(nullptr, {}));
+  ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr6);
+
+  // ToString
+
+  // ty1 was copied to ty3
+  ASSERT_EQ("int8", ty3.ToString());
+  ASSERT_EQ("computed", ty2.ToString());
+}
+
+TEST(OutputType, Resolve) {
+  // Check shape promotion rules for FIXED kind
+  OutputType ty1(int32());
+
+  ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {}));
+  ASSERT_EQ(ValueDescr::Scalar(int32()), descr);
+
+  ASSERT_OK_AND_ASSIGN(descr,
+                       ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR)}));
+  ASSERT_EQ(ValueDescr::Scalar(int32()), descr);
+
+  ASSERT_OK_AND_ASSIGN(descr,
+                       ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR),
+                                             ValueDescr(int8(), ValueDescr::ARRAY)}));
+  ASSERT_EQ(ValueDescr::Array(int32()), descr);
+
+  OutputType ty2([](KernelContext*, const std::vector<ValueDescr>& args) {
+    return ValueDescr(args[0].type, GetBroadcastShape(args));
+  });
+
+  ASSERT_OK_AND_ASSIGN(descr, ty2.Resolve(nullptr, {ValueDescr::Array(utf8())}));
+  ASSERT_EQ(ValueDescr::Array(utf8()), descr);
+
+  // Type resolver that returns an error
+  OutputType ty3(
+      [](KernelContext* ctx,
+         const std::vector<ValueDescr>& args) -> Result<ValueDescr> {
+        // NB: checking the value types versus the function arity should be
+        // validated elsewhere, so this is just for illustration purposes
+        if (args.size() == 0) {
+          return Status::Invalid("Need at least one argument");
+        }
+        return ValueDescr(args[0]);
+      });
+  ASSERT_RAISES(Invalid, ty3.Resolve(nullptr, {}));
+}
+
+TEST(OutputType, ResolveDescr) {
+  ValueDescr d1 = ValueDescr::Scalar(int32());
+  ValueDescr d2 = ValueDescr::Array(int32());
+
+  OutputType ty1(d1);
+  OutputType ty2(d2);
+
+  ASSERT_EQ(ValueDescr::SCALAR, ty1.shape());
+  ASSERT_EQ(ValueDescr::ARRAY, ty2.shape());
+
+  {
+    ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {}));
+    ASSERT_EQ(d1, descr);
+  }
+
+  {
+    ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty2.Resolve(nullptr, {}));
+    ASSERT_EQ(d2, descr);
+  }
+}
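The GetBroadcastShape helper used by these resolvers encodes the shape promotion rule the tests exercise: any array argument makes the output an array, otherwise the output is a scalar. Equivalent logic as a sketch (the actual helper ships with this patch; this is just an illustration):

// Sketch of the shape-broadcasting rule relied on above.
ValueDescr::Shape BroadcastShapeSketch(const std::vector<ValueDescr>& args) {
  for (const auto& descr : args) {
    if (descr.shape == ValueDescr::ARRAY) return ValueDescr::ARRAY;
  }
  // No arguments, or all-scalar arguments, yield a scalar
  return ValueDescr::SCALAR;
}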
+
+// ----------------------------------------------------------------------
+// KernelSignature
+
+TEST(KernelSignature, Basics) {
+  // (any[int8], scalar[decimal]) -> utf8
+  std::vector<InputType> in_types(
+      {int8(), InputType(Type::DECIMAL, ValueDescr::SCALAR)});
+  OutputType out_type(utf8());
+
+  KernelSignature sig(in_types, out_type);
+  ASSERT_EQ(2, sig.in_types().size());
+  ASSERT_TRUE(sig.in_types()[0].type()->Equals(*int8()));
+  ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Scalar(int8())));
+  ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Array(int8())));
+
+  ASSERT_TRUE(sig.in_types()[1].Matches(ValueDescr::Scalar(decimal(12, 2))));
+  ASSERT_FALSE(sig.in_types()[1].Matches(ValueDescr::Array(decimal(12, 2))));
+}
+
+TEST(KernelSignature, Equals) {
+  KernelSignature sig1({}, utf8());
+  KernelSignature sig1_copy({}, utf8());
+  KernelSignature sig2({int8()}, utf8());
+
+  // Output type doesn't matter (for now)
+  KernelSignature sig3({int8()}, int32());
+
+  KernelSignature sig4({int8(), int16()}, utf8());
+  KernelSignature sig4_copy({int8(), int16()}, utf8());
+  KernelSignature sig5({int8(), int16(), int32()}, utf8());
+
+  // Differ in shape
+  KernelSignature sig6({ValueDescr::Scalar(int8())}, utf8());
+  KernelSignature sig7({ValueDescr::Array(int8())}, utf8());
+
+  ASSERT_EQ(sig1, sig1);
+
+  ASSERT_EQ(sig2, sig3);
+  ASSERT_NE(sig3, sig4);
+
+  // Different sig objects, but same sig
+  ASSERT_EQ(sig1, sig1_copy);
+  ASSERT_EQ(sig4, sig4_copy);
+
+  // Match first 2 args, but not third
+  ASSERT_NE(sig4, sig5);
+
+  ASSERT_NE(sig6, sig7);
+}
+
+TEST(KernelSignature, VarArgsEquals) {
+  KernelSignature sig1({int8()}, utf8(), /*is_varargs=*/true);
+  KernelSignature sig2({int8()}, utf8(), /*is_varargs=*/true);
+  KernelSignature sig3({int8()}, utf8());
+
+  ASSERT_EQ(sig1, sig2);
+  ASSERT_NE(sig2, sig3);
+}
+
+TEST(KernelSignature, Hash) {
+  // Some basic tests to ensure that the hashes are deterministic and that all
+  // input arguments are incorporated
+  KernelSignature sig1({}, utf8());
+  KernelSignature sig2({int8()}, utf8());
+  KernelSignature sig3({int8(), int32()}, utf8());
+
+  ASSERT_EQ(sig1.Hash(), sig1.Hash());
+  ASSERT_EQ(sig2.Hash(), sig2.Hash());
+  ASSERT_NE(sig1.Hash(), sig2.Hash());
+  ASSERT_NE(sig2.Hash(), sig3.Hash());
+}
+
+TEST(KernelSignature, MatchesInputs) {
+  // () -> boolean
+  KernelSignature sig1({}, boolean());
+
+  ASSERT_TRUE(sig1.MatchesInputs({}));
+  ASSERT_FALSE(sig1.MatchesInputs({int8()}));
+
+  // (any[int8], any[decimal]) -> boolean
+  KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, boolean());
+
+  ASSERT_FALSE(sig2.MatchesInputs({}));
+  ASSERT_FALSE(sig2.MatchesInputs({int8()}));
+  ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)}));
+  ASSERT_TRUE(sig2.MatchesInputs(
+      {ValueDescr::Scalar(int8()), ValueDescr::Scalar(decimal(12, 2))}));
+  ASSERT_TRUE(sig2.MatchesInputs(
+      {ValueDescr::Array(int8()), ValueDescr::Array(decimal(12, 2))}));
+
+  // (scalar[int8], array[int32]) -> boolean
+  KernelSignature sig3({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())},
+                       boolean());
+
+  ASSERT_FALSE(sig3.MatchesInputs({}));
+
+  // Unqualified, these have ANY shape and do not match because the kernel
+  // requires a scalar and an array
+  ASSERT_FALSE(sig3.MatchesInputs({int8(), int32()}));
+  ASSERT_TRUE(
+      sig3.MatchesInputs({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())}));
+  ASSERT_FALSE(
+      sig3.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(int32())}));
+}
+
+TEST(KernelSignature, VarArgsMatchesInputs) {
+  KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);
+
+  std::vector<ValueDescr> args = {int8()};
+  ASSERT_TRUE(sig.MatchesInputs(args));
+  args.push_back(ValueDescr::Scalar(int8()));
+  args.push_back(ValueDescr::Array(int8()));
+  ASSERT_TRUE(sig.MatchesInputs(args));
+  args.push_back(int32());
+  ASSERT_FALSE(sig.MatchesInputs(args));
+}
+
+TEST(KernelSignature, ToString) {
+  std::vector<InputType> in_types = {InputType(int8(), ValueDescr::SCALAR),
+                                     InputType(Type::DECIMAL, ValueDescr::ARRAY),
+                                     InputType(utf8())};
+  KernelSignature sig(in_types, utf8());
+  ASSERT_EQ("(scalar[int8], array[Type::DECIMAL], any[string]) -> string",
+            sig.ToString());
+
+  OutputType out_type([](KernelContext*, const std::vector<ValueDescr>& args) {
+    return Status::Invalid("NYI");
+  });
+  KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, out_type);
+  ASSERT_EQ("(any[int8], any[Type::DECIMAL]) -> computed", sig2.ToString());
+}
+
+TEST(KernelSignature, VarArgsToString) {
+  KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);
+  ASSERT_EQ("varargs[any[int8]] -> string", sig.ToString());
+}
+
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 12ad4d3a958..361e24b7523 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -15,37 +15,39 @@
 # specific language governing permissions and limitations
 # under the License.
 
-arrow_install_all_headers("arrow/compute/kernels")
-
-add_arrow_compute_test(boolean_test)
-add_arrow_compute_test(cast_test)
-add_arrow_compute_test(hash_test)
-add_arrow_compute_test(isin_test)
-add_arrow_compute_test(match_test)
-add_arrow_compute_test(sort_to_indices_test)
-add_arrow_compute_test(nth_to_indices_test)
-add_arrow_compute_test(util_internal_test)
-add_arrow_compute_test(add_test)
+# ----------------------------------------------------------------------
+# Scalar kernels
+
+add_arrow_compute_test(scalar_test
+                       SOURCES
+                       scalar_arithmetic_test.cc
+                       scalar_boolean_test.cc
+                       scalar_cast_test.cc
+                       scalar_compare_test.cc
+                       scalar_set_lookup_test.cc)
+
+add_arrow_benchmark(scalar_compare_benchmark PREFIX "arrow-compute")
+
+# ----------------------------------------------------------------------
+# Vector kernels
+
+add_arrow_compute_test(vector_test
+                       SOURCES
+                       vector_filter_test.cc
+                       vector_hash_test.cc
+                       vector_take_test.cc
+                       vector_sort_test.cc)
+
+add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_filter_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_take_benchmark PREFIX "arrow-compute")
+
+# ----------------------------------------------------------------------
+# Aggregate kernels
 
 # Aggregates
-add_arrow_compute_test(aggregate_test)
-
-# Comparison
-add_arrow_compute_test(compare_test)
-
-# Selection
-add_arrow_compute_test(take_test)
-add_arrow_compute_test(filter_test)
-
-add_arrow_benchmark(sort_to_indices_benchmark PREFIX "arrow-compute")
-add_arrow_benchmark(nth_to_indices_benchmark PREFIX "arrow-compute")
-# Aggregates
+add_arrow_compute_test(aggregate_test)
 add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute")
-
-# Comparison
-add_arrow_benchmark(compare_benchmark PREFIX "arrow-compute")
-
-# Selection
-add_arrow_benchmark(filter_benchmark PREFIX "arrow-compute")
-add_arrow_benchmark(take_benchmark PREFIX "arrow-compute")
diff --git a/cpp/src/arrow/compute/kernels/add.cc b/cpp/src/arrow/compute/kernels/add.cc
deleted file mode 100644
index 19eb153b5cd..00000000000
--- a/cpp/src/arrow/compute/kernels/add.cc
+++ /dev/null
@@ -1,131 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include "arrow/compute/kernels/add.h"
-#include "arrow/builder.h"
-#include "arrow/compute/context.h"
-#include "arrow/type_traits.h"
-
-namespace arrow {
-namespace compute {
-
-template <typename ArrowType>
-class AddKernelImpl : public AddKernel {
- private:
-  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
-  std::shared_ptr<DataType> result_type_;
-
-  Status Add(FunctionContext* ctx, const std::shared_ptr<ArrayType>& lhs,
-             const std::shared_ptr<ArrayType>& rhs, std::shared_ptr<Array>* result) {
-    NumericBuilder<ArrowType> builder;
-    RETURN_NOT_OK(builder.Reserve(lhs->length()));
-    for (int i = 0; i < lhs->length(); i++) {
-      if (lhs->IsNull(i) || rhs->IsNull(i)) {
-        builder.UnsafeAppendNull();
-      } else {
-        builder.UnsafeAppend(lhs->Value(i) + rhs->Value(i));
-      }
-    }
-    return builder.Finish(result);
-  }
-
- public:
-  explicit AddKernelImpl(std::shared_ptr<DataType> result_type)
-      : result_type_(result_type) {}
-
-  Status Call(FunctionContext* ctx, const Datum& lhs, const Datum& rhs,
-              Datum* out) override {
-    if (!lhs.is_array() || !rhs.is_array()) {
-      return Status::Invalid("AddKernel expects array values");
-    }
-    if (lhs.length() != rhs.length()) {
-      return Status::Invalid("AddKernel expects arrays with the same length");
-    }
-    auto lhs_array = lhs.make_array();
-    auto rhs_array = rhs.make_array();
-    std::shared_ptr<Array> result;
-    RETURN_NOT_OK(this->Add(ctx, lhs_array, rhs_array, &result));
-    *out = result;
-    return Status::OK();
-  }
-
-  std::shared_ptr<DataType> out_type() const override { return result_type_; }
-
-  Status Add(FunctionContext* ctx, const std::shared_ptr<Array>& lhs,
-             const std::shared_ptr<Array>& rhs,
-             std::shared_ptr<Array>* result) override {
-    auto lhs_array = std::static_pointer_cast<ArrayType>(lhs);
-    auto rhs_array = std::static_pointer_cast<ArrayType>(rhs);
-    return Add(ctx, lhs_array, rhs_array, result);
-  }
-};
-
-Status AddKernel::Make(const std::shared_ptr<DataType>& value_type,
-                       std::unique_ptr<AddKernel>* out) {
-  AddKernel* kernel;
-  switch (value_type->id()) {
-    case Type::UINT8:
-      kernel = new AddKernelImpl<UInt8Type>(value_type);
-      break;
-    case Type::INT8:
-      kernel = new AddKernelImpl<Int8Type>(value_type);
-      break;
-    case Type::UINT16:
-      kernel = new AddKernelImpl<UInt16Type>(value_type);
-      break;
-    case Type::INT16:
-      kernel = new AddKernelImpl<Int16Type>(value_type);
-      break;
-    case Type::UINT32:
-      kernel = new AddKernelImpl<UInt32Type>(value_type);
-      break;
-    case Type::INT32:
-      kernel = new AddKernelImpl<Int32Type>(value_type);
-      break;
-    case Type::UINT64:
-      kernel = new AddKernelImpl<UInt64Type>(value_type);
-      break;
-    case Type::INT64:
-      kernel = new AddKernelImpl<Int64Type>(value_type);
-      break;
-    case Type::FLOAT:
-      kernel = new AddKernelImpl<FloatType>(value_type);
-      break;
-    case Type::DOUBLE:
-      kernel = new AddKernelImpl<DoubleType>(value_type);
-      break;
-    default:
-      return Status::NotImplemented("Arithmetic operations on ", *value_type,
-                                    " arrays");
-  }
-  out->reset(kernel);
-  return Status::OK();
-}
-
-Status Add(FunctionContext* ctx, const Array& lhs, const Array& rhs,
-           std::shared_ptr<Array>* result) {
-  Datum result_datum;
-  std::unique_ptr<AddKernel> kernel;
-  ARROW_RETURN_IF(
-      !lhs.type()->Equals(rhs.type()),
-      Status::Invalid("Array types should be equal to use arithmetic kernels"));
-  RETURN_NOT_OK(AddKernel::Make(lhs.type(), &kernel));
-  RETURN_NOT_OK(kernel->Call(ctx, Datum(lhs.data()), Datum(rhs.data()), &result_datum));
-  *result = result_datum.make_array();
-  return Status::OK();
-}
-
-}  // namespace compute
-}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/add.h b/cpp/src/arrow/compute/kernels/add.h
deleted file mode 100644
index 19991aa4473..00000000000
--- a/cpp/src/arrow/compute/kernels/add.h
+++ /dev/null
@@ -1,77 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#pragma once
-
-#include <memory>
-
-#include "arrow/compute/kernel.h"
-#include "arrow/status.h"
-#include "arrow/util/visibility.h"
-
-namespace arrow {
-
-class Array;
-
-namespace compute {
-
-class FunctionContext;
-
-/// \brief Adds two arrays element-wise.
-///
-/// Adds two arrays of the same length. The output is an array with the same
-/// length and type as the inputs; the types of both input arrays must be
-/// equal.
-///
-/// For example, given lhs = [1, null, 3] and rhs = [4, 5, 6], the output
-/// will be [5, null, 7]
-///
-/// \param[in] ctx the FunctionContext
-/// \param[in] lhs the first array
-/// \param[in] rhs the second array
-/// \param[out] result the element-wise sum of the first and second arrays
-
-ARROW_EXPORT
-Status Add(FunctionContext* ctx, const Array& lhs, const Array& rhs,
-           std::shared_ptr<Array>* result);
-
-/// \brief BinaryKernel implementing Add operation
-class ARROW_EXPORT AddKernel : public BinaryKernel {
- public:
-  /// \brief BinaryKernel interface
-  ///
-  /// delegates to subclasses via Add()
-  Status Call(FunctionContext* ctx, const Datum& lhs, const Datum& rhs,
-              Datum* out) override = 0;
-
-  /// \brief output type of this kernel
-  std::shared_ptr<DataType> out_type() const override = 0;
-
-  /// \brief single-array implementation
-  virtual Status Add(FunctionContext* ctx, const std::shared_ptr<Array>& lhs,
-                     const std::shared_ptr<Array>& rhs,
-                     std::shared_ptr<Array>* result) = 0;
-
-  /// \brief factory for Add
-  ///
-  /// \param[in] value_type the value type of the arrays to add
-  /// \param[out] out the constructed AddKernel
-  static Status Make(const std::shared_ptr<DataType>& value_type,
-                     std::unique_ptr<AddKernel>* out);
-};
-}  // namespace compute
-}  // namespace arrow
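With add.h/add.cc deleted, elementwise addition goes through the new function registry rather than a hand-built kernel instance. A sketch of the replacement call path, assuming the arithmetic kernels are registered under the name "add" (scalar_arithmetic_test.cc in this patch exercises them) and using the name-first CallFunction form; exact signatures vary across versions:

// Sketch: the old AddKernel::Make + Call pattern becomes a registry lookup.
Result<std::shared_ptr<Array>> AddArrays(const std::shared_ptr<Array>& lhs,
                                         const std::shared_ptr<Array>& rhs) {
  ARROW_ASSIGN_OR_RAISE(Datum sum,
                        compute::CallFunction("add", {Datum(lhs), Datum(rhs)}));
  return sum.make_array();
}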
diff --git a/cpp/src/arrow/compute/kernels/aggregate.cc b/cpp/src/arrow/compute/kernels/aggregate.cc
deleted file mode 100644
index 90337588615..00000000000
--- a/cpp/src/arrow/compute/kernels/aggregate.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <memory>
-
-#include "arrow/compute/context.h"
-#include "arrow/compute/kernels/aggregate.h"
-
-namespace arrow {
-namespace compute {
-
-// Helper class that properly invokes destructor when state goes out of scope.
-class ManagedAggregateState {
- public:
-  ManagedAggregateState(std::shared_ptr<AggregateFunction>& desc,
-                        std::shared_ptr<Buffer>&& buffer)
-      : desc_(desc), state_(buffer) {
-    desc_->New(state_->mutable_data());
-  }
-
-  ~ManagedAggregateState() { desc_->Delete(state_->mutable_data()); }
-
-  void* mutable_data() { return state_->mutable_data(); }
-
-  static std::shared_ptr<ManagedAggregateState> Make(
-      std::shared_ptr<AggregateFunction>& desc, MemoryPool* pool) {
-    auto maybe_buf = AllocateBuffer(desc->Size(), pool);
-    if (!maybe_buf.ok()) {
-      return nullptr;
-    }
-    return std::make_shared<ManagedAggregateState>(desc, *std::move(maybe_buf));
-  }
-
- private:
-  std::shared_ptr<AggregateFunction> desc_;
-  std::shared_ptr<Buffer> state_;
-};
-
-Status AggregateUnaryKernel::Call(FunctionContext* ctx, const Datum& input,
-                                  Datum* out) {
-  if (!input.is_arraylike()) {
-    return Status::Invalid("AggregateKernel expects Array or ChunkedArray datum");
-  }
-  auto state = ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
-  if (!state) {
-    return Status::OutOfMemory("AggregateState allocation failed");
-  }
-
-  if (input.is_array()) {
-    auto array = input.make_array();
-    RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data()));
-  } else {
-    auto chunked_array = input.chunked_array();
-    for (int i = 0; i < chunked_array->num_chunks(); i++) {
-      auto tmp_state =
-          ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool());
-      if (!tmp_state) {
-        return Status::OutOfMemory("AggregateState allocation failed");
-      }
-      RETURN_NOT_OK(aggregate_function_->Consume(*chunked_array->chunk(i),
-                                                 tmp_state->mutable_data()));
-      RETURN_NOT_OK(
-          aggregate_function_->Merge(tmp_state->mutable_data(), state->mutable_data()));
-    }
-  }
-
-  return aggregate_function_->Finalize(state->mutable_data(), out);
-}
-
-std::shared_ptr<DataType> AggregateUnaryKernel::out_type() const {
-  return aggregate_function_->out_type();
-}
-
-}  // namespace compute
-}  // namespace arrow
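ManagedAggregateState existed to pair placement-new construction with a manual destructor call, since the state lives inside a plain Buffer on which delete is never run. The underlying C++ pattern, as a self-contained sketch (DemoState is hypothetical):

#include <new>

// Sketch: construct an object into caller-provided storage, then destroy it
// explicitly; no heap allocation or deallocation is involved.
struct DemoState {
  long count = 0;
};

void PlacementNewDemo() {
  alignas(DemoState) unsigned char storage[sizeof(DemoState)];
  DemoState* state = new (storage) DemoState;  // placement new
  state->count += 1;
  state->~DemoState();  // manual destructor call; storage itself is untouched
}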
diff --git a/cpp/src/arrow/compute/kernels/aggregate.h b/cpp/src/arrow/compute/kernels/aggregate.h
deleted file mode 100644
index f342e31a0b6..00000000000
--- a/cpp/src/arrow/compute/kernels/aggregate.h
+++ /dev/null
@@ -1,115 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#pragma once
-
-#include <memory>
-
-#include "arrow/compute/kernel.h"
-
-namespace arrow {
-
-class Array;
-class Status;
-
-namespace compute {
-
-class FunctionContext;
-struct Datum;
-
-/// AggregateFunction is an interface for aggregates.
-///
-/// An aggregate transforms an array into a single result, called a state,
-/// via the Consume method. States support the merge operation via the Merge
-/// method, and a state can be sealed into a final result via the Finalize
-/// method.
-///
-/// State ownership is handled by callers, thus the interface exposes 3 methods
-/// for the caller to manage memory:
-/// - Size
-/// - New (placement new constructor invocation)
-/// - Delete (state destructor)
-///
-/// Design inspired by ClickHouse aggregate functions.
-class AggregateFunction {
- public:
-  /// \brief Consume an array into a state.
-  virtual Status Consume(const Array& input, void* state) const = 0;
-
-  /// \brief Merge states.
-  virtual Status Merge(const void* src, void* dst) const = 0;
-
-  /// \brief Convert state into a final result.
-  virtual Status Finalize(const void* src, Datum* output) const = 0;
-
-  virtual ~AggregateFunction() {}
-
-  virtual std::shared_ptr<DataType> out_type() const = 0;
-
-  /// State management methods.
-  virtual int64_t Size() const = 0;
-  virtual void New(void* ptr) const = 0;
-  virtual void Delete(void* ptr) const = 0;
-};
-
-/// AggregateFunction partial implementation for statically typed states
-template <typename State>
-class AggregateFunctionStaticState : public AggregateFunction {
-  virtual Status Consume(const Array& input, State* state) const = 0;
-  virtual Status Merge(const State& src, State* dst) const = 0;
-  virtual Status Finalize(const State& src, Datum* output) const = 0;
-
-  Status Consume(const Array& input, void* state) const final {
-    return Consume(input, static_cast<State*>(state));
-  }
-
-  Status Merge(const void* src, void* dst) const final {
-    return Merge(*static_cast<const State*>(src), static_cast<State*>(dst));
-  }
-
-  /// \brief Convert state into a final result.
-  Status Finalize(const void* src, Datum* output) const final {
-    return Finalize(*static_cast<const State*>(src), output);
-  }
-
-  int64_t Size() const final { return sizeof(State); }
-
-  void New(void* ptr) const final {
-    // By using placement-new syntax, the constructor of the State is invoked
-    // in the memory location defined by the caller. This only supports State
-    // with a parameter-less constructor.
-    new (ptr) State;
-  }
-
-  void Delete(void* ptr) const final { static_cast<State*>(ptr)->~State(); }
-};
-
-/// \brief UnaryKernel implemented by an AggregateState
-class ARROW_EXPORT AggregateUnaryKernel : public UnaryKernel {
- public:
-  explicit AggregateUnaryKernel(std::shared_ptr<AggregateFunction>& aggregate)
-      : aggregate_function_(aggregate) {}
-
-  Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override;
-
-  std::shared_ptr<DataType> out_type() const override;
-
- private:
-  std::shared_ptr<AggregateFunction> aggregate_function_;
-};
-
-}  // namespace compute
-}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
new file mode 100644
index 00000000000..14f9be3f93e
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -0,0 +1,523 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/common.h"
+
+namespace arrow {
+namespace compute {
+
+namespace {
+
+struct ScalarAggregator : public KernelState {
+  virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
+  virtual void MergeFrom(KernelContext* ctx, const KernelState& src) = 0;
+  virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
+};
+
+void AggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
+  checked_cast<ScalarAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+void AggregateMerge(KernelContext* ctx, const KernelState& src, KernelState* dst) {
+  checked_cast<ScalarAggregator*>(dst)->MergeFrom(ctx, src);
+}
+
+void AggregateFinalize(KernelContext* ctx, Datum* out) {
+  checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, out);
+}
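These three dispatchers adapt the virtual ScalarAggregator interface to the std::function members of ScalarAggregateKernel, so every aggregate in this file can share one set of trampolines. A sketch of the wiring, assuming an init function such as CountInit (defined just below) is already declared; the signature shown is hypothetical:

// Sketch: one kernel object serving the count aggregate.
void AddCountKernelSketch() {
  ScalarAggregateKernel kernel(
      {InputType(ValueDescr::ARRAY)},  // any array-shaped input
      int64(),                         // counts are int64
      CountInit, AggregateConsume, AggregateMerge, AggregateFinalize);
  (void)kernel;  // registration with a Function object omitted here
}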
+
+// ----------------------------------------------------------------------
+// Count implementation
+
+struct CountImpl : public ScalarAggregator {
+  explicit CountImpl(CountOptions options)
+      : options(std::move(options)), non_nulls(0), nulls(0) {}
+
+  void Consume(KernelContext*, const ExecBatch& batch) override {
+    const ArrayData& input = *batch[0].array();
+    const int64_t nulls = input.GetNullCount();
+    this->nulls += nulls;
+    this->non_nulls += input.length - nulls;
+  }
+
+  void MergeFrom(KernelContext*, const KernelState& src) override {
+    const auto& other_state = checked_cast<const CountImpl&>(src);
+    this->non_nulls += other_state.non_nulls;
+    this->nulls += other_state.nulls;
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    const auto& state = checked_cast<const CountImpl&>(*ctx->state());
+    switch (state.options.count_mode) {
+      case CountOptions::COUNT_ALL:
+        *out = Datum(state.non_nulls);
+        break;
+      case CountOptions::COUNT_NULL:
+        *out = Datum(state.nulls);
+        break;
+      default:
+        ctx->SetStatus(Status::Invalid("Unknown CountOptions encountered"));
+        break;
+    }
+  }
+
+  CountOptions options;
+  int64_t non_nulls = 0;
+  int64_t nulls = 0;
+};
+
+std::unique_ptr<KernelState> CountInit(KernelContext*, const KernelInitArgs& args) {
+  return std::unique_ptr<KernelState>(
+      new CountImpl(static_cast<const CountOptions&>(*args.options)));
+}
+
+// ----------------------------------------------------------------------
+// Sum implementation
+
+template <typename ArrowType,
+          typename SumType = typename FindAccumulatorType<ArrowType>::Type>
+struct SumState {
+  using ThisType = SumState<ArrowType, SumType>;
+  using T = typename TypeTraits<ArrowType>::CType;
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+
+  // A small number of elements rounded to the next cacheline. This should
+  // amount to a maximum of 4 cachelines when dealing with 8-byte elements.
+  static constexpr int64_t kTinyThreshold = 32;
+  static_assert(kTinyThreshold >= (2 * CHAR_BIT) + 1,
+                "ConsumeSparse requires 3 bytes of null bitmap, and 17 is the "
+                "required minimum number of bits/elements to cover 3 bytes.");
+
+  ThisType operator+(const ThisType& rhs) const {
+    return ThisType(this->count + rhs.count, this->sum + rhs.sum);
+  }
+
+  ThisType& operator+=(const ThisType& rhs) {
+    this->count += rhs.count;
+    this->sum += rhs.sum;
+
+    return *this;
+  }
+
+ public:
+  void Consume(const Array& input) {
+    const ArrayType& array = static_cast<const ArrayType&>(input);
+    if (input.null_count() == 0) {
+      (*this) += ConsumeDense(array);
+    } else if (input.length() <= kTinyThreshold) {
+      // In order to simplify ConsumeSparse implementation (requires at least 3
+      // bytes of bitmap data), small arrays are handled differently.
+      (*this) += ConsumeTiny(array);
+    } else {
+      (*this) += ConsumeSparse(array);
+    }
+  }
+
+  size_t count = 0;
+  typename SumType::c_type sum = 0;
+
+ private:
+  ThisType ConsumeDense(const ArrayType& array) const {
+    ThisType local;
+    const auto values = array.raw_values();
+    const int64_t length = array.length();
+    for (int64_t i = 0; i < length; i++) {
+      local.sum += values[i];
+    }
+    local.count = length;
+    return local;
+  }
+
+  ThisType ConsumeTiny(const ArrayType& array) const {
+    ThisType local;
+
+    BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
+    const auto values = array.raw_values();
+    for (int64_t i = 0; i < array.length(); i++) {
+      if (reader.IsSet()) {
+        local.sum += values[i];
+        local.count++;
+      }
+      reader.Next();
+    }
+
+    return local;
+  }
+
+  // While this is not branchless, gcc needs this to be in a different function
+  // for it to generate a cmov, which ends up being slightly faster than
+  // multiplication but safe for handling NaN with doubles.
+  inline T MaskedValue(bool valid, T value) const { return valid ? value : 0; }
+
+  inline ThisType UnrolledSum(uint8_t bits, const T* values) const {
+    ThisType local;
+
+    if (bits < 0xFF) {
+      // Some nulls
+      for (size_t i = 0; i < 8; i++) {
+        local.sum += MaskedValue(bits & (1U << i), values[i]);
+      }
+      local.count += BitUtil::kBytePopcount[bits];
+    } else {
+      // No nulls
+      for (size_t i = 0; i < 8; i++) {
+        local.sum += values[i];
+      }
+      local.count += 8;
+    }
+
+    return local;
+  }
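UnrolledSum processes one bitmap byte (eight values) at a time: a popcount table yields the valid count, and the masked add avoids a per-element branch. The same idea as a standalone, dependency-free sketch:

#include <cstdint>

// Sketch: sum eight values under a one-byte validity mask.
int64_t SumMaskedByte(uint8_t valid_bits, const int32_t* values,
                      int64_t* valid_count) {
  int64_t sum = 0;
  for (int i = 0; i < 8; ++i) {
    const bool valid = (valid_bits >> i) & 1;
    sum += valid ? values[i] : 0;  // cmov-friendly masked add
    *valid_count += valid;        // equivalent to a popcount of valid_bits
  }
  return sum;
}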
+
+  ThisType ConsumeSparse(const ArrayType& array) const {
+    ThisType local;
+
+    // Sliced bitmaps on non-byte positions induce problems with the
+    // branchless unrolled technique. Thus extra padding is added on both the
+    // left and right side of the slice such that both ends are byte-aligned.
+    // The first and last bitmap bytes are properly masked to ignore extra
+    // values induced by the padding.
+    //
+    // The execution is divided in 3 sections.
+    //
+    // 1. Compute the sum of the first masked byte.
+    // 2. Compute the sum of the middle bytes.
+    // 3. Compute the sum of the last masked byte.
+
+    const int64_t length = array.length();
+    const int64_t offset = array.offset();
+
+    // The number of bytes covering the range; this includes partial bytes.
+    // This number is bounded by `<= (length / 8) + 2`, e.g. a possible extra
+    // byte on the left, and one on the right.
+    const int64_t covering_bytes = BitUtil::CoveringBytes(offset, length);
+    DCHECK_GE(covering_bytes, 3);
+
+    // Align values to the first batch of 8 elements. Note that raw_values() is
+    // already adjusted with the offset, thus we rewind a little to align to
+    // the closest 8-batch offset.
+    const auto values = array.raw_values() - (offset % 8);
+
+    // Align bitmap at the first consumable byte.
+    const auto bitmap = array.null_bitmap_data() + BitUtil::RoundDown(offset, 8) / 8;
+
+    // Consume the first (potentially partial) byte.
+    const uint8_t first_mask = BitUtil::kTrailingBitmask[offset % 8];
+    local += UnrolledSum(bitmap[0] & first_mask, values);
+
+    // Consume the (full) middle bytes. The loop iterates in units of
+    // batches of 8 values and 1 byte of bitmap.
+    for (int64_t i = 1; i < covering_bytes - 1; i++) {
+      local += UnrolledSum(bitmap[i], &values[i * 8]);
+    }
+
+    // Consume the last (potentially partial) byte.
+    const int64_t last_idx = covering_bytes - 1;
+    const uint8_t last_mask = BitUtil::kPrecedingWrappingBitmask[(offset + length) % 8];
+    local += UnrolledSum(bitmap[last_idx] & last_mask, &values[last_idx * 8]);
+
+    return local;
+  }
+};
+
+template <typename ArrowType>
+struct SumImpl : public ScalarAggregator {
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+  using ThisType = SumImpl<ArrowType>;
+  using SumType = typename FindAccumulatorType<ArrowType>::Type;
+  using OutputType = typename TypeTraits<SumType>::ScalarType;
+
+  void Consume(KernelContext*, const ExecBatch& batch) override {
+    this->state.Consume(ArrayType(batch[0].array()));
+  }
+
+  void MergeFrom(KernelContext*, const KernelState& src) override {
+    const auto& other = checked_cast<const ThisType&>(src);
+    this->state += other.state;
+  }
+
+  void Finalize(KernelContext*, Datum* out) override {
+    if (state.count == 0) {
+      out->value = std::make_shared<OutputType>();
+    } else {
+      out->value = MakeScalar(state.sum);
+    }
+  }
+
+  SumState<ArrowType> state;
+};
+
+template <typename ArrowType>
+struct MeanImpl : public SumImpl<ArrowType> {
+  void Finalize(KernelContext*, Datum* out) override {
+    const bool is_valid = this->state.count > 0;
+    const double divisor = static_cast<double>(is_valid ? this->state.count : 1UL);
+    const double mean = static_cast<double>(this->state.sum) / divisor;
+
+    if (!is_valid) {
+      out->value = std::make_shared<DoubleScalar>();
+    } else {
+      out->value = std::make_shared<DoubleScalar>(mean);
+    }
+  }
+};
+
+template