diff --git a/cpp/src/arrow/chunked_array.cc b/cpp/src/arrow/chunked_array.cc index 142bd0d8c89..0c954e72e50 100644 --- a/cpp/src/arrow/chunked_array.cc +++ b/cpp/src/arrow/chunked_array.cc @@ -145,6 +145,16 @@ bool ChunkedArray::ApproxEquals(const ChunkedArray& other, .ok(); } +Result> ChunkedArray::GetScalar(int64_t index) const { + for (const auto& chunk : chunks_) { + if (index < chunk->length()) { + return chunk->GetScalar(index); + } + index -= chunk->length(); + } + return Status::Invalid("index out of bounds"); +} + std::shared_ptr ChunkedArray::Slice(int64_t offset, int64_t length) const { ARROW_CHECK_LE(offset, length_) << "Slice offset greater than array length"; bool offset_equals_length = offset == length_; diff --git a/cpp/src/arrow/chunked_array.h b/cpp/src/arrow/chunked_array.h index 86d9b2b51fe..0bf0c66c1ad 100644 --- a/cpp/src/arrow/chunked_array.h +++ b/cpp/src/arrow/chunked_array.h @@ -130,6 +130,9 @@ class ARROW_EXPORT ChunkedArray { const std::shared_ptr& type() const { return type_; } + /// \brief Return a Scalar containing the value of this array at index + Result> GetScalar(int64_t index) const; + /// \brief Determine if two chunked arrays are equal. /// /// Two chunked arrays can be equal only if they have equal datatypes. 
diff --git a/cpp/src/arrow/chunked_array_test.cc b/cpp/src/arrow/chunked_array_test.cc index c5907549fe4..c41a4c2bd8b 100644 --- a/cpp/src/arrow/chunked_array_test.cc +++ b/cpp/src/arrow/chunked_array_test.cc @@ -22,6 +22,7 @@ #include #include "arrow/chunked_array.h" +#include "arrow/scalar.h" #include "arrow/status.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" @@ -241,4 +242,25 @@ TEST_F(TestChunkedArray, View) { AssertChunkedEqual(*expected, *result); } +TEST_F(TestChunkedArray, GetScalar) { + auto ty = int32(); + ArrayVector chunks{ArrayFromJSON(ty, "[6, 7, null]"), ArrayFromJSON(ty, "[]"), + ArrayFromJSON(ty, "[null]"), ArrayFromJSON(ty, "[3, 4, 5]")}; + ChunkedArray carr(chunks); + + auto check_scalar = [](const ChunkedArray& array, int64_t index, + const Scalar& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, array.GetScalar(index)); + AssertScalarsEqual(expected, *actual, /*verbose=*/true); + }; + + check_scalar(carr, 0, **MakeScalar(ty, 6)); + check_scalar(carr, 2, *MakeNullScalar(ty)); + check_scalar(carr, 3, *MakeNullScalar(ty)); + check_scalar(carr, 4, **MakeScalar(ty, 3)); + check_scalar(carr, 6, **MakeScalar(ty, 5)); + + ASSERT_RAISES(Invalid, carr.GetScalar(7)); +} + } // namespace arrow diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 34ee0599c3d..1fc6b787458 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -39,8 +39,11 @@ using internal::checked_cast; using internal::checked_pointer_cast; namespace internal { + using compute::DictionaryEncodeOptions; using compute::FilterOptions; +using compute::NullPlacement; + template <> struct EnumTraits : BasicEnumTraits return ""; } }; +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "NullPlacement"; } + static std::string value_name(NullPlacement value) { + switch (value) { + case NullPlacement::AtStart: + return "AtStart"; + case 
NullPlacement::AtEnd: + return "AtEnd"; + } + return ""; + } +}; + } // namespace internal namespace compute { @@ -106,11 +124,14 @@ static auto kDictionaryEncodeOptionsType = GetFunctionOptionsType(DataMember( "null_encoding_behavior", &DictionaryEncodeOptions::null_encoding_behavior)); static auto kArraySortOptionsType = GetFunctionOptionsType( - DataMember("order", &ArraySortOptions::order)); -static auto kSortOptionsType = - GetFunctionOptionsType(DataMember("sort_keys", &SortOptions::sort_keys)); + DataMember("order", &ArraySortOptions::order), + DataMember("null_placement", &ArraySortOptions::null_placement)); +static auto kSortOptionsType = GetFunctionOptionsType( + DataMember("sort_keys", &SortOptions::sort_keys), + DataMember("null_placement", &SortOptions::null_placement)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( - DataMember("pivot", &PartitionNthOptions::pivot)); + DataMember("pivot", &PartitionNthOptions::pivot), + DataMember("null_placement", &PartitionNthOptions::null_placement)); static auto kSelectKOptionsType = GetFunctionOptionsType( DataMember("k", &SelectKOptions::k), DataMember("sort_keys", &SelectKOptions::sort_keys)); @@ -131,16 +152,22 @@ DictionaryEncodeOptions::DictionaryEncodeOptions(NullEncodingBehavior null_encod null_encoding_behavior(null_encoding) {} constexpr char DictionaryEncodeOptions::kTypeName[]; -ArraySortOptions::ArraySortOptions(SortOrder order) - : FunctionOptions(internal::kArraySortOptionsType), order(order) {} +ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement) + : FunctionOptions(internal::kArraySortOptionsType), + order(order), + null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; -SortOptions::SortOptions(std::vector sort_keys) - : FunctionOptions(internal::kSortOptionsType), sort_keys(std::move(sort_keys)) {} +SortOptions::SortOptions(std::vector sort_keys, NullPlacement null_placement) + : FunctionOptions(internal::kSortOptionsType), 
+ sort_keys(std::move(sort_keys)), + null_placement(null_placement) {} constexpr char SortOptions::kTypeName[]; -PartitionNthOptions::PartitionNthOptions(int64_t pivot) - : FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {} +PartitionNthOptions::PartitionNthOptions(int64_t pivot, NullPlacement null_placement) + : FunctionOptions(internal::kPartitionNthOptionsType), + pivot(pivot), + null_placement(null_placement) {} constexpr char PartitionNthOptions::kTypeName[]; SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys) @@ -164,6 +191,14 @@ void RegisterVectorOptions(FunctionRegistry* registry) { // ---------------------------------------------------------------------- // Direct exec interface to kernels +Result> NthToIndices(const Array& values, + const PartitionNthOptions& options, + ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("partition_nth_indices", + {Datum(values)}, &options, ctx)); + return result.make_array(); +} + Result> NthToIndices(const Array& values, int64_t n, ExecContext* ctx) { PartitionNthOptions options(/*pivot=*/n); @@ -185,6 +220,14 @@ Result ReplaceWithMask(const Datum& values, const Datum& mask, return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); } +Result> SortIndices(const Array& values, + const ArraySortOptions& options, + ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE( + Datum result, CallFunction("array_sort_indices", {Datum(values)}, &options, ctx)); + return result.make_array(); +} + Result> SortIndices(const Array& values, SortOrder order, ExecContext* ctx) { ArraySortOptions options(order); @@ -193,6 +236,15 @@ Result> SortIndices(const Array& values, SortOrder order, return result.make_array(); } +Result> SortIndices(const ChunkedArray& chunked_array, + const ArraySortOptions& array_options, + ExecContext* ctx) { + SortOptions options({SortKey("", array_options.order)}, array_options.null_placement); + ARROW_ASSIGN_OR_RAISE( + Datum result, 
CallFunction("sort_indices", {Datum(chunked_array)}, &options, ctx)); + return result.make_array(); +} + Result> SortIndices(const ChunkedArray& chunked_array, SortOrder order, ExecContext* ctx) { SortOptions options({SortKey("not-used", order)}); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index c79c8fc9858..a91cf91df06 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -86,6 +86,15 @@ enum class SortOrder { Descending, }; +enum class NullPlacement { + /// Place nulls and NaNs before any non-null values. + /// NaNs will come after nulls. + AtStart, + /// Place nulls and NaNs after any non-null values. + /// NaNs will come before nulls. + AtEnd, +}; + /// \brief One sort key for PartitionNthIndices (TODO) and SortIndices class ARROW_EXPORT SortKey : public util::EqualityComparable { public: @@ -106,22 +115,28 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { class ARROW_EXPORT ArraySortOptions : public FunctionOptions { public: - explicit ArraySortOptions(SortOrder order = SortOrder::Ascending); + explicit ArraySortOptions(SortOrder order = SortOrder::Ascending, + NullPlacement null_placement = NullPlacement::AtEnd); constexpr static char const kTypeName[] = "ArraySortOptions"; static ArraySortOptions Defaults() { return ArraySortOptions(); } /// Sorting order SortOrder order; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; }; class ARROW_EXPORT SortOptions : public FunctionOptions { public: - explicit SortOptions(std::vector sort_keys = {}); + explicit SortOptions(std::vector sort_keys = {}, + NullPlacement null_placement = NullPlacement::AtEnd); constexpr static char const kTypeName[] = "SortOptions"; static SortOptions Defaults() { return SortOptions(); } /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; }; /// \brief SelectK options @@ -162,12 +177,15 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { /// \brief Partitioning options for NthToIndices class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { public: - explicit PartitionNthOptions(int64_t pivot); + explicit PartitionNthOptions(int64_t pivot, + NullPlacement null_placement = NullPlacement::AtEnd); PartitionNthOptions() : PartitionNthOptions(0) {} constexpr static char const kTypeName[] = "PartitionNthOptions"; /// The index into the equivalent sorted array of the partition pivot element. int64_t pivot; + /// Whether nulls and NaNs are partitioned at the start or at the end + NullPlacement null_placement; }; /// @} @@ -273,8 +291,7 @@ Result DropNull(const Datum& values, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result> DropNull(const Array& values, ExecContext* ctx = NULLPTR); -/// \brief Returns indices that partition an array around n-th -/// sorted element. +/// \brief Return indices that partition an array around n-th sorted element. /// /// Find index of n-th(0 based) smallest value and perform indirect /// partition of an array around that element. Output indices[0 ~ n-1] @@ -291,14 +308,27 @@ ARROW_EXPORT Result> NthToIndices(const Array& values, int64_t n, ExecContext* ctx = NULLPTR); -/// \brief Returns the indices that would select the first `k` elements of the array in -/// the specified order. +/// \brief Return indices that partition an array around n-th sorted element. +/// +/// This overload takes a PartitionNthOptions specifying the pivot index +/// and the null handling.
+/// +/// \param[in] values array to be partitioned +/// \param[in] options options including pivot index and null handling +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would partition an array +ARROW_EXPORT +Result> NthToIndices(const Array& values, + const PartitionNthOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Return indices that would select the first `k` elements. /// -// Perform an indirect sort of the datum, keeping only the first `k` elements. The output -// array will contain indices such that the item indicated by the k-th index will be in -// the position it would be if the datum were sorted by `options.sort_keys`. However, -// indices of null values will not be part of the output. The sort is not guaranteed to be -// stable. +/// Perform an indirect sort of the datum, keeping only the first `k` elements. The output +/// array will contain indices such that the item indicated by the k-th index will be in +/// the position it would be if the datum were sorted by `options.sort_keys`. However, +/// indices of null values will not be part of the output. The sort is not guaranteed to +/// be stable. /// /// \param[in] datum datum to be partitioned /// \param[in] options options @@ -309,8 +339,7 @@ Result> SelectKUnstable(const Datum& datum, const SelectKOptions& options, ExecContext* ctx = NULLPTR); -/// \brief Returns the indices that would sort an array in the -/// specified order. +/// \brief Return the indices that would sort an array. /// /// Perform an indirect sort of array. The output array will contain /// indices that would sort an array, which would be the same length @@ -330,8 +359,21 @@ Result> SortIndices(const Array& array, SortOrder order = SortOrder::Ascending, ExecContext* ctx = NULLPTR); -/// \brief Returns the indices that would sort a chunked array in the -/// specified order. +/// \brief Return the indices that would sort an array. 
+/// +/// This overload takes an ArraySortOptions specifying the sort order +/// and the null handling. +/// +/// \param[in] array array to sort +/// \param[in] options options including sort order and null handling +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result> SortIndices(const Array& array, + const ArraySortOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Return the indices that would sort a chunked array. /// /// Perform an indirect sort of chunked array. The output array will /// contain indices that would sort a chunked array, which would be @@ -351,14 +393,28 @@ Result> SortIndices(const ChunkedArray& chunked_array, SortOrder order = SortOrder::Ascending, ExecContext* ctx = NULLPTR); -/// \brief Returns the indices that would sort an input in the +/// \brief Return the indices that would sort a chunked array. +/// +/// This overload takes an ArraySortOptions specifying the sort order +/// and the null handling. +/// +/// \param[in] chunked_array chunked array to sort +/// \param[in] options options including sort order and null handling +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result> SortIndices(const ChunkedArray& chunked_array, + const ArraySortOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Return the indices that would sort an input in the /// specified order. Input is one of array, chunked array record batch /// or table. /// /// Perform an indirect sort of input. The output array will contain /// indices that would sort an input, which would be the same length -/// as input. Nulls will be stably partitioned to the end of the -/// output regardless of order. +/// as input. Nulls will be stably partitioned to the start or to the end +/// of the output depending on SortOptions::null_placement.
/// /// For example given input (table) = { /// "column1": [[null, 1], [ 3, null, 2, 1]], diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 2845bb0698c..1f88e0185d1 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -281,7 +281,7 @@ GroupByNode{"aggregate", inputs=[groupby: "project"], outputs=["filter"], keys=[ hash_count(multiply(i32, 2), {mode=NON_NULL}), ]} FilterNode{"filter", inputs=[target: "aggregate"], outputs=["order_by_sink"], filter=(sum(multiply(i32, 2)) > 10)} -OrderBySinkNode{"order_by_sink", inputs=[collected: "filter"], by={sort_keys=[sum(multiply(i32, 2)) ASC]}} +OrderBySinkNode{"order_by_sink", inputs=[collected: "filter"], by={sort_keys=[sum(multiply(i32, 2)) ASC], null_placement=AtEnd}} )a"); ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index a2bbc30ed42..f8be3b12a87 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -222,99 +222,181 @@ struct NullTraits> { static constexpr bool has_null_like_values = true; }; -// Move nulls (not null-like values) to end of array. Return where null starts. 
+struct NullPartitionResult { + uint64_t* non_nulls_begin; + uint64_t* non_nulls_end; + uint64_t* nulls_begin; + uint64_t* nulls_end; + + uint64_t* overall_begin() const { return std::min(nulls_begin, non_nulls_begin); } + + uint64_t* overall_end() const { return std::max(nulls_end, non_nulls_end); } + + int64_t non_null_count() const { return non_nulls_end - non_nulls_begin; } + + int64_t null_count() const { return nulls_end - nulls_begin; } + + static NullPartitionResult NoNulls(uint64_t* indices_begin, uint64_t* indices_end, + NullPlacement null_placement) { + if (null_placement == NullPlacement::AtStart) { + return {indices_begin, indices_end, indices_begin, indices_begin}; + } else { + return {indices_begin, indices_end, indices_end, indices_end}; + } + } + + static NullPartitionResult NullsAtEnd(uint64_t* indices_begin, uint64_t* indices_end, + uint64_t* midpoint) { + DCHECK_GE(midpoint, indices_begin); + DCHECK_LE(midpoint, indices_end); + return {indices_begin, midpoint, midpoint, indices_end}; + } + + static NullPartitionResult NullsAtStart(uint64_t* indices_begin, uint64_t* indices_end, + uint64_t* midpoint) { + DCHECK_GE(midpoint, indices_begin); + DCHECK_LE(midpoint, indices_end); + return {midpoint, indices_end, indices_begin, midpoint}; + } +}; + +// Move nulls (not null-like values) to start or end of array, per `null_placement`.
// // `offset` is used when this is called on a chunk of a chunked array template -uint64_t* PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indices_end, - const Array& values, int64_t offset) { +NullPartitionResult PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indices_end, + const Array& values, int64_t offset, + NullPlacement null_placement) { if (values.null_count() == 0) { - return indices_end; + return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement); } Partitioner partitioner; - return partitioner(indices_begin, indices_end, [&values, &offset](uint64_t ind) { - return !values.IsNull(ind - offset); - }); + if (null_placement == NullPlacement::AtStart) { + auto nulls_end = partitioner( + indices_begin, indices_end, + [&values, &offset](uint64_t ind) { return values.IsNull(ind - offset); }); + return NullPartitionResult::NullsAtStart(indices_begin, indices_end, nulls_end); + } else { + auto nulls_begin = partitioner( + indices_begin, indices_end, + [&values, &offset](uint64_t ind) { return !values.IsNull(ind - offset); }); + return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, nulls_begin); + } } // For chunked array. 
template -uint64_t* PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indices_end, - const std::vector& arrays, - int64_t null_count) { +NullPartitionResult PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indices_end, + const ChunkedArrayResolver& resolver, + int64_t null_count, NullPlacement null_placement) { if (null_count == 0) { - return indices_end; + return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement); } - ChunkedArrayResolver resolver(arrays); Partitioner partitioner; - return partitioner(indices_begin, indices_end, [&](uint64_t ind) { - const auto chunk = resolver.Resolve(ind); - return !chunk.IsNull(); - }); + if (null_placement == NullPlacement::AtStart) { + auto nulls_end = partitioner(indices_begin, indices_end, [&](uint64_t ind) { + const auto chunk = resolver.Resolve(ind); + return chunk.IsNull(); + }); + return NullPartitionResult::NullsAtStart(indices_begin, indices_end, nulls_end); + } else { + auto nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t ind) { + const auto chunk = resolver.Resolve(ind); + return !chunk.IsNull(); + }); + return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, nulls_begin); + } } -// Move non-null null-like values to end of array. Return where null-like starts. +// Move non-null null-like values to end of array. // // `offset` is used when this is called on a chunk of a chunked array template -enable_if_t::value, uint64_t*> +enable_if_t::value, NullPartitionResult> PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end, - const ArrayType& values, int64_t offset) { - return indices_end; + const ArrayType& values, int64_t offset, + NullPlacement null_placement) { + return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement); } -// For chunked array. 
template -enable_if_t::value, uint64_t*> +enable_if_t::value, NullPartitionResult> PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end, - const std::vector& arrays, int64_t null_count) { - return indices_end; + const ChunkedArrayResolver& resolver, NullPlacement null_placement) { + return NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement); } template -enable_if_t::value, uint64_t*> +enable_if_t::value, NullPartitionResult> PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end, - const ArrayType& values, int64_t offset) { + const ArrayType& values, int64_t offset, + NullPlacement null_placement) { Partitioner partitioner; - return partitioner(indices_begin, indices_end, [&values, &offset](uint64_t ind) { - return !std::isnan(values.GetView(ind - offset)); - }); + if (null_placement == NullPlacement::AtStart) { + auto null_likes_end = + partitioner(indices_begin, indices_end, [&values, &offset](uint64_t ind) { + return std::isnan(values.GetView(ind - offset)); + }); + return NullPartitionResult::NullsAtStart(indices_begin, indices_end, null_likes_end); + } else { + auto null_likes_begin = + partitioner(indices_begin, indices_end, [&values, &offset](uint64_t ind) { + return !std::isnan(values.GetView(ind - offset)); + }); + return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, null_likes_begin); + } } template -enable_if_t::value, uint64_t*> +enable_if_t::value, NullPartitionResult> PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end, - const std::vector& arrays, int64_t null_count) { + const ChunkedArrayResolver& resolver, NullPlacement null_placement) { Partitioner partitioner; - ChunkedArrayResolver resolver(arrays); - return partitioner(indices_begin, indices_end, [&](uint64_t ind) { - const auto chunk = resolver.Resolve(ind); - return !std::isnan(chunk.Value()); - }); + if (null_placement == NullPlacement::AtStart) { + auto null_likes_end = partitioner(indices_begin, indices_end, 
[&](uint64_t ind) { + const auto chunk = resolver.Resolve(ind); + return std::isnan(chunk.Value()); + }); + return NullPartitionResult::NullsAtStart(indices_begin, indices_end, null_likes_end); + } else { + auto null_likes_begin = partitioner(indices_begin, indices_end, [&](uint64_t ind) { + const auto chunk = resolver.Resolve(ind); + return !std::isnan(chunk.Value()); + }); + return NullPartitionResult::NullsAtEnd(indices_begin, indices_end, null_likes_begin); + } } -// Move nulls to end of array. Return where null starts. +// Move nulls and null-like values to start or end of array, per `null_placement`. // // `offset` is used when this is called on a chunk of a chunked array template -uint64_t* PartitionNulls(uint64_t* indices_begin, uint64_t* indices_end, - const ArrayType& values, int64_t offset) { - // Partition nulls at end, and null-like values just before - uint64_t* nulls_begin = - PartitionNullsOnly(indices_begin, indices_end, values, offset); - return PartitionNullLikes(indices_begin, nulls_begin, values, - offset); +NullPartitionResult PartitionNulls(uint64_t* indices_begin, uint64_t* indices_end, + const ArrayType& values, int64_t offset, + NullPlacement null_placement) { + // Nulls go at start (resp. end); null-likes go just after (resp. before) the nulls + NullPartitionResult p = PartitionNullsOnly(indices_begin, indices_end, + values, offset, null_placement); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); + return NullPartitionResult{q.non_nulls_begin, q.non_nulls_end, + std::min(q.nulls_begin, p.nulls_begin), + std::max(q.nulls_end, p.nulls_end)}; } // For chunked array.
template -uint64_t* PartitionNulls(uint64_t* indices_begin, uint64_t* indices_end, - const std::vector& arrays, int64_t null_count) { - // Partition nulls at end, and null-like values just before - uint64_t* nulls_begin = - PartitionNullsOnly(indices_begin, indices_end, arrays, null_count); - return PartitionNullLikes(indices_begin, nulls_begin, arrays, - null_count); +NullPartitionResult PartitionNulls(uint64_t* indices_begin, uint64_t* indices_end, + const ChunkedArrayResolver& resolver, + int64_t null_count, NullPlacement null_placement) { + // Nulls go at start (resp. end); null-likes go just after (resp. before) the nulls + NullPartitionResult p = PartitionNullsOnly( + indices_begin, indices_end, resolver, null_count, null_placement); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, resolver, null_placement); + return NullPartitionResult{q.non_nulls_begin, q.non_nulls_end, + std::min(q.nulls_begin, p.nulls_begin), + std::max(q.nulls_end, p.nulls_end)}; } // ---------------------------------------------------------------------- @@ -333,10 +415,11 @@ struct PartitionNthToIndices { if (ctx->state() == nullptr) { return Status::Invalid("NthToIndices requires PartitionNthOptions"); } + const auto& options = PartitionNthToIndicesState::Get(ctx); ArrayType arr(batch[0].array()); - int64_t pivot = PartitionNthToIndicesState::Get(ctx).pivot; + const int64_t pivot = options.pivot; if (pivot > arr.length()) { return Status::IndexError("NthToIndices index out of bound"); } @@ -347,11 +430,11 @@ struct PartitionNthToIndices { if (pivot == arr.length()) { return Status::OK(); } - auto nulls_begin = - PartitionNulls(out_begin, out_end, arr, 0); + const auto p = PartitionNulls( + out_begin, out_end, arr, 0, options.null_placement); auto nth_begin = out_begin + pivot; - if (nth_begin < nulls_begin) { - std::nth_element(out_begin, nth_begin, nulls_begin, + if (nth_begin >= p.non_nulls_begin && nth_begin < p.non_nulls_end) { +
std::nth_element(p.non_nulls_begin, nth_begin, p.non_nulls_end, [&arr](uint64_t left, uint64_t right) { const auto lval = GetView::LogicalValue(arr.GetView(left)); const auto rval = GetView::LogicalValue(arr.GetView(right)); @@ -399,23 +482,24 @@ class ArrayCompareSorter { using GetView = GetViewType; public: - // Returns where null starts. - // // `offset` is used when this is called on a chunk of a chunked array - uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, - int64_t offset, const ArraySortOptions& options) { - auto nulls_begin = PartitionNulls( - indices_begin, indices_end, values, offset); + NullPartitionResult Sort(uint64_t* indices_begin, uint64_t* indices_end, + const ArrayType& values, int64_t offset, + const ArraySortOptions& options) { + const auto p = PartitionNulls( + indices_begin, indices_end, values, offset, options.null_placement); if (options.order == SortOrder::Ascending) { std::stable_sort( - indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) { + p.non_nulls_begin, p.non_nulls_end, + [&values, &offset](uint64_t left, uint64_t right) { const auto lhs = GetView::LogicalValue(values.GetView(left - offset)); const auto rhs = GetView::LogicalValue(values.GetView(right - offset)); return lhs < rhs; }); } else { std::stable_sort( - indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) { + p.non_nulls_begin, p.non_nulls_end, + [&values, &offset](uint64_t left, uint64_t right) { const auto lhs = GetView::LogicalValue(values.GetView(left - offset)); const auto rhs = GetView::LogicalValue(values.GetView(right - offset)); // We don't use 'left > right' here to reduce required operator. @@ -423,7 +507,7 @@ class ArrayCompareSorter { return rhs < lhs; }); } - return nulls_begin; + return p; } }; @@ -443,9 +527,9 @@ class ArrayCountSorter { value_range_ = static_cast(max - min) + 1; } - // Returns where null starts. 
- uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, - int64_t offset, const ArraySortOptions& options) { + NullPartitionResult Sort(uint64_t* indices_begin, uint64_t* indices_end, + const ArrayType& values, int64_t offset, + const ArraySortOptions& options) const { // 32bit counter performs much better than 64bit one if (values.length() < (1LL << 32)) { return SortInternal(indices_begin, indices_end, values, offset, options); @@ -458,45 +542,65 @@ class ArrayCountSorter { c_type min_{0}; uint32_t value_range_{0}; - // Returns where null starts. - // // `offset` is used when this is called on a chunk of a chunked array template - uint64_t* SortInternal(uint64_t* indices_begin, uint64_t* indices_end, - const ArrayType& values, int64_t offset, - const ArraySortOptions& options) { + NullPartitionResult SortInternal(uint64_t* indices_begin, uint64_t* indices_end, + const ArrayType& values, int64_t offset, + const ArraySortOptions& options) const { const uint32_t value_range = value_range_; - // first slot reserved for prefix sum - std::vector counts(1 + value_range); + // first and last slot reserved for prefix sum (depending on sort order) + std::vector counts(2 + value_range); + NullPartitionResult p; if (options.order == SortOrder::Ascending) { - VisitRawValuesInline( - values, [&](c_type v) { ++counts[v - min_ + 1]; }, []() {}); + // counts will be increasing, starting with 0 and ending with (length - null_count) + CountValues(values, &counts[1]); for (uint32_t i = 1; i <= value_range; ++i) { counts[i] += counts[i - 1]; } - auto null_position = counts[value_range]; - auto nulls_begin = indices_begin + null_position; - int64_t index = offset; - VisitRawValuesInline( - values, [&](c_type v) { indices_begin[counts[v - min_]++] = index++; }, - [&]() { indices_begin[null_position++] = index++; }); - return nulls_begin; + + if (options.null_placement == NullPlacement::AtStart) { + p = NullPartitionResult::NullsAtStart(indices_begin, 
indices_end, + indices_end - counts[value_range]); + } else { + p = NullPartitionResult::NullsAtEnd(indices_begin, indices_end, + indices_begin + counts[value_range]); + } + EmitIndices(p, values, offset, &counts[0]); } else { - VisitRawValuesInline( - values, [&](c_type v) { ++counts[v - min_]; }, []() {}); + // counts will be decreasing, starting with (length - null_count) and ending with 0 + CountValues(values, &counts[0]); for (uint32_t i = value_range; i >= 1; --i) { counts[i - 1] += counts[i]; } - auto null_position = counts[0]; - auto nulls_begin = indices_begin + null_position; - int64_t index = offset; - VisitRawValuesInline( - values, [&](c_type v) { indices_begin[counts[v - min_ + 1]++] = index++; }, - [&]() { indices_begin[null_position++] = index++; }); - return nulls_begin; + + if (options.null_placement == NullPlacement::AtStart) { + p = NullPartitionResult::NullsAtStart(indices_begin, indices_end, + indices_end - counts[0]); + } else { + p = NullPartitionResult::NullsAtEnd(indices_begin, indices_end, + indices_begin + counts[0]); + } + EmitIndices(p, values, offset, &counts[1]); } + return p; + } + + template + void CountValues(const ArrayType& values, CounterType* counts) const { + VisitRawValuesInline( + values, [&](c_type v) { ++counts[v - min_]; }, []() {}); + } + + template + void EmitIndices(const NullPartitionResult& p, const ArrayType& values, int64_t offset, + CounterType* counts) const { + int64_t index = offset; + CounterType count_nulls = 0; + VisitRawValuesInline( + values, [&](c_type v) { p.non_nulls_begin[counts[v - min_]++] = index++; }, + [&]() { p.nulls_begin[count_nulls++] = index++; }); } }; @@ -507,20 +611,24 @@ class ArrayCountSorter { public: ArrayCountSorter() = default; - // Returns where null starts. 
// `offset` is used when this is called on a chunk of a chunked array - uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, - const BooleanArray& values, int64_t offset, - const ArraySortOptions& options) { - std::array counts{0, 0}; + NullPartitionResult Sort(uint64_t* indices_begin, uint64_t* indices_end, + const BooleanArray& values, int64_t offset, + const ArraySortOptions& options) { + std::array counts{0, 0, 0}; // false, true, null const int64_t nulls = values.null_count(); const int64_t ones = values.true_count(); const int64_t zeros = values.length() - ones - nulls; - int64_t null_position = values.length() - nulls; - int64_t index = offset; - const auto nulls_begin = indices_begin + null_position; + NullPartitionResult p; + if (options.null_placement == NullPlacement::AtStart) { + p = NullPartitionResult::NullsAtStart(indices_begin, indices_end, + indices_begin + nulls); + } else { + p = NullPartitionResult::NullsAtEnd(indices_begin, indices_end, + indices_end - nulls); + } if (options.order == SortOrder::Ascending) { // ones start after zeros @@ -529,10 +637,12 @@ class ArrayCountSorter { // zeros start after ones counts[0] = ones; } + + int64_t index = offset; VisitRawValuesInline( - values, [&](bool v) { indices_begin[counts[v]++] = index++; }, - [&]() { indices_begin[null_position++] = index++; }); - return nulls_begin; + values, [&](bool v) { p.non_nulls_begin[counts[v]++] = index++; }, + [&]() { p.nulls_begin[counts[2]++] = index++; }); + return p; } }; @@ -545,11 +655,10 @@ class ArrayCountOrCompareSorter { using c_type = typename ArrowType::c_type; public: - // Returns where null starts. 
- // // `offset` is used when this is called on a chunk of a chunked array - uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, - int64_t offset, const ArraySortOptions& options) { + NullPartitionResult Sort(uint64_t* indices_begin, uint64_t* indices_end, + const ArrayType& values, int64_t offset, + const ArraySortOptions& options) { if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) { c_type min, max; std::tie(min, max) = GetMinMax(*values.data()); @@ -692,29 +801,31 @@ class ChunkedArrayCompareSorter { using ArrayType = typename TypeTraits::ArrayType; public: - // Returns where null starts. - uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, - const std::vector& arrays, int64_t null_count, - const ArraySortOptions& options) { - auto nulls_begin = PartitionNulls( - indices_begin, indices_end, arrays, null_count); + NullPartitionResult Sort(uint64_t* indices_begin, uint64_t* indices_end, + const std::vector& arrays, int64_t null_count, + const ArraySortOptions& options) { + const auto p = PartitionNulls( + indices_begin, indices_end, ChunkedArrayResolver(arrays), null_count, + options.null_placement); ChunkedArrayResolver resolver(arrays); if (options.order == SortOrder::Ascending) { - std::stable_sort(indices_begin, nulls_begin, [&](uint64_t left, uint64_t right) { - const auto chunk_left = resolver.Resolve(left); - const auto chunk_right = resolver.Resolve(right); - return chunk_left.Value() < chunk_right.Value(); - }); + std::stable_sort(p.non_nulls_begin, p.non_nulls_end, + [&](uint64_t left, uint64_t right) { + const auto chunk_left = resolver.Resolve(left); + const auto chunk_right = resolver.Resolve(right); + return chunk_left.Value() < chunk_right.Value(); + }); } else { - std::stable_sort(indices_begin, nulls_begin, [&](uint64_t left, uint64_t right) { - const auto chunk_left = resolver.Resolve(left); - const auto chunk_right = resolver.Resolve(right); - // We don't use 
'left > right' here to reduce required operator. - // If we use 'right < left' here, '<' is only required. - return chunk_right.Value() < chunk_left.Value(); - }); + std::stable_sort(p.non_nulls_begin, p.non_nulls_end, + [&](uint64_t left, uint64_t right) { + const auto chunk_left = resolver.Resolve(left); + const auto chunk_right = resolver.Resolve(right); + // We don't use 'left > right' here to reduce required operator. + // If we use 'right < left' here, '<' is only required. + return chunk_right.Value() < chunk_left.Value(); + }); } - return nulls_begin; + return p; } }; @@ -727,7 +838,7 @@ class ChunkedArraySorter : public TypeVisitor { public: ChunkedArraySorter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const ChunkedArray& chunked_array, const SortOrder order, - bool can_use_array_sorter = true) + const NullPlacement null_placement, bool can_use_array_sorter = true) : TypeVisitor(), indices_begin_(indices_begin), indices_end_(indices_end), @@ -735,6 +846,7 @@ class ChunkedArraySorter : public TypeVisitor { physical_type_(GetPhysicalType(chunked_array.type())), physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), order_(order), + null_placement_(null_placement), can_use_array_sorter_(can_use_array_sorter), ctx_(ctx) {} @@ -751,22 +863,18 @@ class ChunkedArraySorter : public TypeVisitor { template Status SortInternal() { using ArrayType = typename TypeTraits::ArrayType; - ArraySortOptions options(order_); + ArraySortOptions options(order_, null_placement_); const auto num_chunks = chunked_array_.num_chunks(); if (num_chunks == 0) { return Status::OK(); } const auto arrays = GetArrayPointers(physical_chunks_); + if (can_use_array_sorter_) { // Sort each chunk independently and merge to sorted indices. // This is a serial implementation. 
ArraySorter sorter; - struct SortedChunk { - int64_t begin_offset; - int64_t end_offset; - int64_t nulls_offset; - }; - std::vector sorted(num_chunks); + std::vector sorted(num_chunks); // First sort all individual chunks int64_t begin_offset = 0; @@ -776,10 +884,9 @@ class ChunkedArraySorter : public TypeVisitor { const auto array = checked_cast(arrays[i]); end_offset += array->length(); null_count += array->null_count(); - uint64_t* nulls_begin = + sorted[i] = sorter.impl.Sort(indices_begin_ + begin_offset, indices_begin_ + end_offset, *array, begin_offset, options); - sorted[i] = {begin_offset, end_offset, nulls_begin - indices_begin_}; begin_offset = end_offset; } DCHECK_EQ(end_offset, indices_end_ - indices_begin_); @@ -801,17 +908,10 @@ class ChunkedArraySorter : public TypeVisitor { while (it < sorted.end() - 1) { const auto& left = *it++; const auto& right = *it++; - DCHECK_EQ(left.end_offset, right.begin_offset); - DCHECK_GE(left.nulls_offset, left.begin_offset); - DCHECK_LE(left.nulls_offset, left.end_offset); - DCHECK_GE(right.nulls_offset, right.begin_offset); - DCHECK_LE(right.nulls_offset, right.end_offset); - uint64_t* nulls_begin = Merge( - indices_begin_ + left.begin_offset, indices_begin_ + left.end_offset, - indices_begin_ + right.end_offset, indices_begin_ + left.nulls_offset, - indices_begin_ + right.nulls_offset, arrays, null_count, order_, - temp_indices); - *out_it++ = {left.begin_offset, right.end_offset, nulls_begin - indices_begin_}; + DCHECK_EQ(left.overall_end(), right.overall_begin()); + const auto merged = + Merge(left, right, arrays, null_count, temp_indices); + *out_it++ = merged; } if (it < sorted.end()) { *out_it++ = *it++; @@ -819,10 +919,10 @@ class ChunkedArraySorter : public TypeVisitor { sorted.erase(out_it, sorted.end()); } DCHECK_EQ(sorted.size(), 1); - DCHECK_EQ(sorted[0].begin_offset, 0); - DCHECK_EQ(sorted[0].end_offset, chunked_array_.length()); + DCHECK_EQ(sorted[0].overall_begin(), indices_begin_); + 
DCHECK_EQ(sorted[0].overall_end(), indices_end_); // Note that "nulls" can also include NaNs, hence the >= check - DCHECK_GE(chunked_array_.length() - sorted[0].nulls_offset, null_count); + DCHECK_GE(sorted[0].null_count(), null_count); } else { // Sort the chunked array directory. ChunkedArrayCompareSorter sorter; @@ -832,50 +932,106 @@ class ChunkedArraySorter : public TypeVisitor { return Status::OK(); } - // Merges two sorted indices arrays and returns where nulls starts. - // Where nulls starts is used when the next merge to detect the - // sorted indices locations. + // Merge two adjacent sorted indices arrays template - uint64_t* Merge(uint64_t* indices_begin, uint64_t* indices_middle, - uint64_t* indices_end, uint64_t* left_nulls_begin, - uint64_t* right_nulls_begin, const std::vector& arrays, - int64_t null_count, const SortOrder order, uint64_t* temp_indices) { + NullPartitionResult Merge(const NullPartitionResult& left, + const NullPartitionResult& right, + const std::vector& arrays, int64_t null_count, + uint64_t* temp_indices) { + if (null_placement_ == NullPlacement::AtStart) { + return MergeNullsAtStart(left, right, arrays, null_count, temp_indices); + } else { + return MergeNullsAtEnd(left, right, arrays, null_count, temp_indices); + } + } + + template + NullPartitionResult MergeNullsAtStart(const NullPartitionResult& left, + const NullPartitionResult& right, + const std::vector& arrays, + int64_t null_count, uint64_t* temp_indices) { + // Input layout: + // [left nulls .... left non-nulls .... right nulls .... right non-nulls] + DCHECK_EQ(left.nulls_end, left.non_nulls_begin); + DCHECK_EQ(left.non_nulls_end, right.nulls_begin); + DCHECK_EQ(right.nulls_end, right.non_nulls_begin); + + // Mutate the input, stably, to obtain the following layout: + // [left nulls .... right nulls .... left non-nulls .... 
right non-nulls] + std::rotate(left.non_nulls_begin, right.nulls_begin, right.nulls_end); + + const auto p = NullPartitionResult::NullsAtStart( + left.nulls_begin, right.non_nulls_end, + left.nulls_begin + left.null_count() + right.null_count()); + + // If the type has null-like values (such as NaN), ensure those plus regular + // nulls are partitioned in the right order. Note this assumes that all + // null-like values (e.g. NaN) are ordered equally. + if (NullTraits::has_null_like_values) { + PartitionNullsOnly(p.nulls_begin, p.nulls_end, + ChunkedArrayResolver(arrays), null_count, + null_placement_); + } + + // Merge the non-null values into temp area + DCHECK_EQ(right.non_nulls_begin - p.non_nulls_begin, left.non_null_count()); + DCHECK_EQ(p.non_nulls_end - right.non_nulls_begin, right.non_null_count()); + MergeNonNulls(p.non_nulls_begin, right.non_nulls_begin, p.non_nulls_end, + arrays, temp_indices); + return p; + } + + template + NullPartitionResult MergeNullsAtEnd(const NullPartitionResult& left, + const NullPartitionResult& right, + const std::vector& arrays, + int64_t null_count, uint64_t* temp_indices) { // Input layout: // [left non-nulls .... left nulls .... right non-nulls .... right nulls] - // ^ ^ ^ ^ - // | | | | - // indices_begin left_nulls_begin indices_middle right_nulls_begin - auto left_num_non_nulls = left_nulls_begin - indices_begin; - auto right_num_non_nulls = right_nulls_begin - indices_middle; + DCHECK_EQ(left.non_nulls_end, left.nulls_begin); + DCHECK_EQ(left.nulls_end, right.non_nulls_begin); + DCHECK_EQ(right.non_nulls_end, right.nulls_begin); // Mutate the input, stably, to obtain the following layout: // [left non-nulls .... right non-nulls .... left nulls .... 
right nulls] - // ^ ^ ^ ^ - // | | | | - // indices_begin indices_middle nulls_begin right_nulls_begin - std::rotate(left_nulls_begin, indices_middle, right_nulls_begin); - auto nulls_begin = indices_begin + left_num_non_nulls + right_num_non_nulls; + std::rotate(left.nulls_begin, right.non_nulls_begin, right.non_nulls_end); + + const auto p = NullPartitionResult::NullsAtEnd( + left.non_nulls_begin, right.nulls_end, + left.non_nulls_begin + left.non_null_count() + right.non_null_count()); + // If the type has null-like values (such as NaN), ensure those plus regular // nulls are partitioned in the right order. Note this assumes that all // null-like values (e.g. NaN) are ordered equally. if (NullTraits::has_null_like_values) { - PartitionNullsOnly(nulls_begin, indices_end, arrays, null_count); + PartitionNullsOnly(p.nulls_begin, p.nulls_end, + ChunkedArrayResolver(arrays), null_count, + null_placement_); } // Merge the non-null values into temp area - indices_middle = indices_begin + left_num_non_nulls; - indices_end = indices_middle + right_num_non_nulls; + DCHECK_EQ(left.non_nulls_end - p.non_nulls_begin, left.non_null_count()); + DCHECK_EQ(p.non_nulls_end - left.non_nulls_end, right.non_null_count()); + MergeNonNulls(p.non_nulls_begin, left.non_nulls_end, p.non_nulls_end, + arrays, temp_indices); + return p; + } + + template + void MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle, uint64_t* range_end, + const std::vector& arrays, uint64_t* temp_indices) { const ChunkedArrayResolver left_resolver(arrays); const ChunkedArrayResolver right_resolver(arrays); - if (order == SortOrder::Ascending) { - std::merge(indices_begin, indices_middle, indices_middle, indices_end, temp_indices, + + if (order_ == SortOrder::Ascending) { + std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, [&](uint64_t left, uint64_t right) { const auto chunk_left = left_resolver.Resolve(left); const auto chunk_right = right_resolver.Resolve(right); return 
chunk_left.Value() < chunk_right.Value(); }); } else { - std::merge(indices_begin, indices_middle, indices_middle, indices_end, temp_indices, + std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, [&](uint64_t left, uint64_t right) { const auto chunk_left = left_resolver.Resolve(left); const auto chunk_right = right_resolver.Resolve(right); @@ -886,8 +1042,7 @@ class ChunkedArraySorter : public TypeVisitor { }); } // Copy back temp area into main buffer - std::copy(temp_indices, temp_indices + (nulls_begin - indices_begin), indices_begin); - return nulls_begin; + std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin); } uint64_t* indices_begin_; @@ -896,6 +1051,7 @@ class ChunkedArraySorter : public TypeVisitor { const std::shared_ptr physical_type_; const ArrayVector physical_chunks_; const SortOrder order_; + const NullPlacement null_placement_; const bool can_use_array_sorter_; ExecContext* ctx_; }; @@ -949,43 +1105,45 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter { using ArrayType = typename TypeTraits::ArrayType; ConcreteRecordBatchColumnSorter(std::shared_ptr array, SortOrder order, + NullPlacement null_placement, RecordBatchColumnSorter* next_column = nullptr) : RecordBatchColumnSorter(next_column), owned_array_(std::move(array)), array_(checked_cast(*owned_array_)), order_(order), + null_placement_(null_placement), null_count_(array_.null_count()) {} void SortRange(uint64_t* indices_begin, uint64_t* indices_end) { using GetView = GetViewType; constexpr int64_t offset = 0; - uint64_t* nulls_begin; + NullPartitionResult p; if (null_count_ == 0) { - nulls_begin = indices_end; + p = NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement_); } else { // NOTE that null_count_ is merely an upper bound on the number of nulls // in this particular range. 
- nulls_begin = PartitionNullsOnly(indices_begin, indices_end, - array_, offset); - DCHECK_LE(indices_end - nulls_begin, null_count_); + p = PartitionNullsOnly(indices_begin, indices_end, array_, + offset, null_placement_); + DCHECK_LE(p.nulls_end - p.nulls_begin, null_count_); } - uint64_t* null_likes_begin = PartitionNullLikes( - indices_begin, nulls_begin, array_, offset); + const NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, array_, offset, null_placement_); // TODO This is roughly the same as ArrayCompareSorter. // Also, we would like to use a counting sort if possible. This requires // a counting sort compatible with indirect indexing. if (order_ == SortOrder::Ascending) { std::stable_sort( - indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) { + q.non_nulls_begin, q.non_nulls_end, [&](uint64_t left, uint64_t right) { const auto lhs = GetView::LogicalValue(array_.GetView(left - offset)); const auto rhs = GetView::LogicalValue(array_.GetView(right - offset)); return lhs < rhs; }); } else { std::stable_sort( - indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) { + q.non_nulls_begin, q.non_nulls_end, [&](uint64_t left, uint64_t right) { // We don't use 'left > right' here to reduce required operator. // If we use 'right < left' here, '<' is only required. const auto lhs = GetView::LogicalValue(array_.GetView(left - offset)); @@ -997,9 +1155,9 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter { if (next_column_ != nullptr) { // Visit all ranges of equal values in this column and sort them on // the next column. 
- SortNextColumn(null_likes_begin, nulls_begin); - SortNextColumn(nulls_begin, indices_end); - VisitConstantRanges(array_, indices_begin, null_likes_begin, + SortNextColumn(q.nulls_begin, q.nulls_end); + SortNextColumn(p.nulls_begin, p.nulls_end); + VisitConstantRanges(array_, q.non_nulls_begin, q.non_nulls_end, [&](uint64_t* range_start, uint64_t* range_end) { SortNextColumn(range_start, range_end); }); @@ -1017,6 +1175,7 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter { const std::shared_ptr owned_array_; const ArrayType& array_; const SortOrder order_; + const NullPlacement null_placement_; const int64_t null_count_; }; @@ -1038,7 +1197,7 @@ class RadixRecordBatchSorter { std::vector> column_sorts(sort_keys.size()); RecordBatchColumnSorter* next_column = nullptr; for (int64_t i = static_cast(sort_keys.size() - 1); i >= 0; --i) { - ColumnSortFactory factory(sort_keys[i], next_column); + ColumnSortFactory factory(sort_keys[i], options_, next_column); ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort()); next_column = column_sorts[i].get(); } @@ -1055,11 +1214,12 @@ class RadixRecordBatchSorter { }; struct ColumnSortFactory { - ColumnSortFactory(const ResolvedSortKey& sort_key, + ColumnSortFactory(const ResolvedSortKey& sort_key, const SortOptions& options, RecordBatchColumnSorter* next_column) : physical_type(GetPhysicalType(sort_key.array->type())), array(GetPhysicalArray(*sort_key.array, physical_type)), order(sort_key.order), + null_placement(options.null_placement), next_column(next_column) {} Result> MakeColumnSort() { @@ -1082,13 +1242,15 @@ class RadixRecordBatchSorter { template Status VisitGeneric(const Type&) { - result.reset(new ConcreteRecordBatchColumnSorter(array, order, next_column)); + result.reset(new ConcreteRecordBatchColumnSorter(array, order, null_placement, + next_column)); return Status::OK(); } std::shared_ptr physical_type; std::shared_ptr array; SortOrder order; + NullPlacement null_placement; 
RecordBatchColumnSorter* next_column; std::unique_ptr result; }; @@ -1118,8 +1280,9 @@ class RadixRecordBatchSorter { template class MultipleKeyComparator { public: - explicit MultipleKeyComparator(const std::vector& sort_keys) - : sort_keys_(sort_keys) {} + MultipleKeyComparator(const std::vector& sort_keys, + NullPlacement null_placement) + : sort_keys_(sort_keys), null_placement_(null_placement) {} Status status() const { return status_; } @@ -1192,14 +1355,14 @@ class MultipleKeyComparator { const auto chunk_left = sort_key.template GetChunk(current_left_); const auto chunk_right = sort_key.template GetChunk(current_right_); if (sort_key.null_count > 0) { - auto is_null_left = chunk_left.IsNull(); - auto is_null_right = chunk_right.IsNull(); + const bool is_null_left = chunk_left.IsNull(); + const bool is_null_right = chunk_right.IsNull(); if (is_null_left && is_null_right) { return 0; } else if (is_null_left) { - return 1; + return null_placement_ == NullPlacement::AtStart ? -1 : 1; } else if (is_null_right) { - return -1; + return null_placement_ == NullPlacement::AtStart ? 1 : -1; } } return CompareTypeValue(chunk_left, chunk_right, order); @@ -1235,14 +1398,14 @@ class MultipleKeyComparator { const SortOrder order) { const auto left = chunk_left.Value(); const auto right = chunk_right.Value(); - auto is_nan_left = std::isnan(left); - auto is_nan_right = std::isnan(right); + const bool is_nan_left = std::isnan(left); + const bool is_nan_right = std::isnan(right); if (is_nan_left && is_nan_right) { return 0; } else if (is_nan_left) { - return 1; + return null_placement_ == NullPlacement::AtStart ? -1 : 1; } else if (is_nan_right) { - return -1; + return null_placement_ == NullPlacement::AtStart ? 
1 : -1; } int32_t compared; if (left == right) { @@ -1259,6 +1422,7 @@ class MultipleKeyComparator { } const std::vector& sort_keys_; + const NullPlacement null_placement_; Status status_; int64_t current_left_; int64_t current_right_; @@ -1271,7 +1435,7 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { public: // Preprocessed sort key. struct ResolvedSortKey { - ResolvedSortKey(const std::shared_ptr& array, const SortOrder order) + ResolvedSortKey(const std::shared_ptr& array, SortOrder order) : type(GetPhysicalType(array->type())), owned_array(GetPhysicalArray(*array, type)), array(*owned_array), @@ -1299,7 +1463,8 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { : indices_begin_(indices_begin), indices_end_(indices_end), sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)), - comparator_(sort_keys_) {} + null_placement_(options.null_placement), + comparator_(sort_keys_, null_placement_) {} // This is optimized for the first sort key. The first sort key sort // is processed in this class. The second and following sort keys @@ -1334,106 +1499,74 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { template Status SortInternal() { using ArrayType = typename TypeTraits::ArrayType; + using GetView = GetViewType; auto& comparator = comparator_; const auto& first_sort_key = sort_keys_[0]; const ArrayType& array = checked_cast(first_sort_key.array); - auto nulls_begin = indices_end_; - nulls_begin = PartitionNullsInternal(first_sort_key); + const auto p = PartitionNullsInternal(first_sort_key); + // Sort first-key non-nulls - std::stable_sort(indices_begin_, nulls_begin, [&](uint64_t left, uint64_t right) { - // Both values are never null nor NaN - // (otherwise they've been partitioned away above). 
- const auto value_left = array.GetView(left); - const auto value_right = array.GetView(right); - if (value_left != value_right) { - bool compared = value_left < value_right; - if (first_sort_key.order == SortOrder::Ascending) { - return compared; - } else { - return !compared; - } - } - // If the left value equals to the right value, - // we need to compare the second and following - // sort keys. - return comparator.Compare(left, right, 1); - }); + std::stable_sort( + p.non_nulls_begin, p.non_nulls_end, [&](uint64_t left, uint64_t right) { + // Both values are never null nor NaN + // (otherwise they've been partitioned away above). + const auto value_left = GetView::LogicalValue(array.GetView(left)); + const auto value_right = GetView::LogicalValue(array.GetView(right)); + if (value_left != value_right) { + bool compared = value_left < value_right; + if (first_sort_key.order == SortOrder::Ascending) { + return compared; + } else { + return !compared; + } + } + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + }); return comparator_.status(); } - // Behaves like PatitionNulls() but this supports multiple sort keys. - // - // For non-float types. + // Behaves like PartitionNulls() but this supports multiple sort keys. 
template - enable_if_t::value, uint64_t*> PartitionNullsInternal( - const ResolvedSortKey& first_sort_key) { + NullPartitionResult PartitionNullsInternal(const ResolvedSortKey& first_sort_key) { using ArrayType = typename TypeTraits::ArrayType; - if (first_sort_key.null_count == 0) { - return indices_end_; - } const ArrayType& array = checked_cast(first_sort_key.array); - StablePartitioner partitioner; - auto nulls_begin = partitioner(indices_begin_, indices_end_, - [&](uint64_t index) { return !array.IsNull(index); }); - // Sort all nulls by second and following sort keys - // TODO: could we instead run an independent sort from the second key on - // this slice? - if (nulls_begin != indices_end_) { - auto& comparator = comparator_; - std::stable_sort(nulls_begin, indices_end_, - [&comparator](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - } - return nulls_begin; - } - // Behaves like PatitionNulls() but this supports multiple sort keys. - // - // For float types. 
- template - enable_if_t::value, uint64_t*> PartitionNullsInternal( - const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - const ArrayType& array = checked_cast(first_sort_key.array); - StablePartitioner partitioner; - uint64_t* nulls_begin; - if (first_sort_key.null_count == 0) { - nulls_begin = indices_end_; - } else { - nulls_begin = partitioner(indices_begin_, indices_end_, - [&](uint64_t index) { return !array.IsNull(index); }); - } - uint64_t* nans_and_nulls_begin = - partitioner(indices_begin_, nulls_begin, - [&](uint64_t index) { return !std::isnan(array.GetView(index)); }); + const auto p = PartitionNullsOnly(indices_begin_, indices_end_, + array, 0, null_placement_); + const auto q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, array, 0, null_placement_); + auto& comparator = comparator_; - if (nans_and_nulls_begin != nulls_begin) { + if (q.nulls_begin != q.nulls_end) { // Sort all NaNs by the second and following sort keys. // TODO: could we instead run an independent sort from the second key on // this slice? - std::stable_sort(nans_and_nulls_begin, nulls_begin, + std::stable_sort(q.nulls_begin, q.nulls_end, [&comparator](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); } - if (nulls_begin != indices_end_) { + if (p.nulls_begin != p.nulls_end) { // Sort all nulls by the second and following sort keys. // TODO: could we instead run an independent sort from the second key on // this slice? - std::stable_sort(nulls_begin, indices_end_, + std::stable_sort(p.nulls_begin, p.nulls_end, [&comparator](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); } - return nans_and_nulls_begin; + return q; } uint64_t* indices_begin_; uint64_t* indices_end_; Status status_; std::vector sort_keys_; + NullPlacement null_placement_; Comparator comparator_; }; @@ -1457,7 +1590,8 @@ class TableRadixSorter { // existing indices. 
const auto can_use_array_sorter = (i == 0); ChunkedArraySorter sorter(ctx, indices_begin, indices_end, *chunked_array.get(), - sort_key.order, can_use_array_sorter); + sort_key.order, options.null_placement, + can_use_array_sorter); ARROW_RETURN_NOT_OK(sorter.Sort()); } return Status::OK(); @@ -1506,7 +1640,8 @@ class MultipleKeyTableSorter : public TypeVisitor { : indices_begin_(indices_begin), indices_end_(indices_end), sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), - comparator_(sort_keys_) {} + null_placement_(options.null_placement), + comparator_(sort_keys_, null_placement_) {} // This is optimized for the first sort key. The first sort key sort // is processed in this class. The second and following sort keys @@ -1545,93 +1680,64 @@ class MultipleKeyTableSorter : public TypeVisitor { auto& comparator = comparator_; const auto& first_sort_key = sort_keys_[0]; - auto nulls_begin = indices_end_; - nulls_begin = PartitionNullsInternal(first_sort_key); - std::stable_sort(indices_begin_, nulls_begin, [&](uint64_t left, uint64_t right) { - // Both values are never null nor NaN. - auto chunk_left = first_sort_key.GetChunk(left); - auto chunk_right = first_sort_key.GetChunk(right); - auto value_left = chunk_left.Value(); - auto value_right = chunk_right.Value(); - if (value_left == value_right) { - // If the left value equals to the right value, - // we need to compare the second and following - // sort keys. - return comparator.Compare(left, right, 1); - } else { - auto compared = value_left < value_right; - if (first_sort_key.order == SortOrder::Ascending) { - return compared; - } else { - return !compared; - } - } - }); + const auto p = PartitionNullsInternal(first_sort_key); + + std::stable_sort(p.non_nulls_begin, p.non_nulls_end, + [&](uint64_t left, uint64_t right) { + // Both values are never null nor NaN. 
+ auto chunk_left = first_sort_key.GetChunk(left); + auto chunk_right = first_sort_key.GetChunk(right); + auto value_left = chunk_left.Value(); + auto value_right = chunk_right.Value(); + if (value_left == value_right) { + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + } else { + auto compared = value_left < value_right; + if (first_sort_key.order == SortOrder::Ascending) { + return compared; + } else { + return !compared; + } + } + }); return comparator_.status(); } // Behaves like PatitionNulls() but this supports multiple sort keys. // - // For non-float types. template - enable_if_t::value, uint64_t*> PartitionNullsInternal( - const ResolvedSortKey& first_sort_key) { + NullPartitionResult PartitionNullsInternal(const ResolvedSortKey& first_sort_key) { using ArrayType = typename TypeTraits::ArrayType; - if (first_sort_key.null_count == 0) { - return indices_end_; - } - StablePartitioner partitioner; - auto nulls_begin = - partitioner(indices_begin_, indices_end_, [&first_sort_key](uint64_t index) { - const auto chunk = first_sort_key.GetChunk(index); - return !chunk.IsNull(); - }); - DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count); - auto& comparator = comparator_; - std::stable_sort(nulls_begin, indices_end_, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - return nulls_begin; - } - // Behaves like PatitionNulls() but this supports multiple sort keys. - // - // For float types. 
- template - enable_if_t::value, uint64_t*> PartitionNullsInternal( - const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - StablePartitioner partitioner; - uint64_t* nulls_begin; - if (first_sort_key.null_count == 0) { - nulls_begin = indices_end_; - } else { - nulls_begin = partitioner(indices_begin_, indices_end_, [&](uint64_t index) { - const auto chunk = first_sort_key.GetChunk(index); - return !chunk.IsNull(); - }); - } - DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count); - uint64_t* nans_begin = partitioner(indices_begin_, nulls_begin, [&](uint64_t index) { - const auto chunk = first_sort_key.GetChunk(index); - return !std::isnan(chunk.Value()); - }); + const auto p = PartitionNullsOnly( + indices_begin_, indices_end_, first_sort_key.resolver, first_sort_key.null_count, + null_placement_); + DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); + + const auto q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, null_placement_); + auto& comparator = comparator_; // Sort all NaNs by the second and following sort keys. - std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) { + std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); // Sort all nulls by the second and following sort keys. 
- std::stable_sort(nulls_begin, indices_end_, [&](uint64_t left, uint64_t right) { + std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); - return nans_begin; + + return q; } uint64_t* indices_begin_; uint64_t* indices_end_; Status status_; std::vector sort_keys_; + NullPlacement null_placement_; Comparator comparator_; }; @@ -1643,10 +1749,12 @@ const auto kDefaultSortOptions = SortOptions::Defaults(); const FunctionDoc sort_indices_doc( "Return the indices that would sort an array, record batch or table", ("This function computes an array of indices that define a stable sort\n" - "of the input array, record batch or table. Null values are considered\n" - "greater than any other value and are therefore sorted at the end of the\n" - "input. For floating-point types, NaNs are considered greater than any\n" - "other non-null value, but smaller than null values."), + "of the input array, record batch or table. By default, nNull values are\n" + "considered greater than any other value and are therefore sorted at the\n" + "end of the input. 
For floating-point types, NaNs are considered greater\n" + "than any other non-null value, but smaller than null values.\n" + "\n" + "The handling of nulls and NaNs can be changed in SortOptions."), {"input"}, "SortOptions"); class SortIndicesMetaFunction : public MetaFunction { @@ -1688,7 +1796,7 @@ class SortIndicesMetaFunction : public MetaFunction { if (!options.sort_keys.empty()) { order = options.sort_keys[0].order; } - ArraySortOptions array_options(order); + ArraySortOptions array_options(order, options.null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); } @@ -1711,7 +1819,8 @@ class SortIndicesMetaFunction : public MetaFunction { auto out_end = out_begin + length; std::iota(out_begin, out_end, 0); - ChunkedArraySorter sorter(ctx, out_begin, out_end, chunked_array, order); + ChunkedArraySorter sorter(ctx, out_begin, out_end, chunked_array, order, + options.null_placement); ARROW_RETURN_NOT_OK(sorter.Sort()); return Datum(out); } @@ -1885,8 +1994,11 @@ class ArraySelecter : public TypeVisitor { if (k_ > arr.length()) { k_ = arr.length(); } - auto end_iter = PartitionNulls(indices_begin, - indices_end, arr, 0); + + const auto p = PartitionNulls( + indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); + const auto end_iter = p.non_nulls_end; + auto kth_begin = std::min(indices_begin + k_, end_iter); SelectKComparator comparator; @@ -1997,8 +2109,10 @@ class ChunkedArraySelecter : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - auto end_iter = PartitionNulls( - indices_begin, indices_end, arr, 0); + const auto p = PartitionNulls( + indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); + const auto end_iter = p.non_nulls_end; + auto kth_begin = std::min(indices_begin + k_, end_iter); uint64_t* iter = indices_begin; for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { @@ -2055,7 +2169,7 @@ class RecordBatchSelecter : public 
TypeVisitor { k_(options.k), output_(output), sort_keys_(ResolveSortKeys(record_batch, options.sort_keys)), - comparator_(sort_keys_) {} + comparator_(sort_keys_, NullPlacement::AtEnd) {} Status Run() { return sort_keys_[0].type->Accept(this); } @@ -2115,8 +2229,10 @@ class RecordBatchSelecter : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - auto end_iter = PartitionNulls(indices_begin, - indices_end, arr, 0); + const auto p = PartitionNulls( + indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); + const auto end_iter = p.non_nulls_end; + auto kth_begin = std::min(indices_begin + k_, end_iter); HeapContainer heap(indices_begin, kth_begin, cmp); @@ -2163,7 +2279,7 @@ class TableSelecter : public TypeVisitor { k_(options.k), output_(output), sort_keys_(ResolveSortKeys(table, options.sort_keys)), - comparator_(sort_keys_) {} + comparator_(sort_keys_, NullPlacement::AtEnd) {} Status Run() { return sort_keys_[0].type->Accept(this); } @@ -2188,65 +2304,33 @@ class TableSelecter : public TypeVisitor { return resolved; } - // Behaves like PatitionNulls() but this supports multiple sort keys. - // - // For non-float types. + // Behaves like PartitionNulls() but this supports multiple sort keys. 
template - enable_if_t::value, uint64_t*> PartitionNullsInternal( - uint64_t* indices_begin, uint64_t* indices_end, - const ResolvedSortKey& first_sort_key) { + NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, + uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { using ArrayType = typename TypeTraits::ArrayType; - if (first_sort_key.null_count == 0) { - return indices_end; - } - StablePartitioner partitioner; - auto nulls_begin = - partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t index) { - const auto chunk = - first_sort_key.GetChunk(static_cast(index)); - return !chunk.IsNull(); - }); - DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); - auto& comparator = comparator_; - std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - return nulls_begin; - } - // Behaves like PatitionNulls() but this supports multiple sort keys. - // - // For float types. 
- template - enable_if_t::value, uint64_t*> PartitionNullsInternal( - uint64_t* indices_begin, uint64_t* indices_end, - const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - StablePartitioner partitioner; - uint64_t* nulls_begin; - if (first_sort_key.null_count == 0) { - nulls_begin = indices_end; - } else { - nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t index) { - const auto chunk = first_sort_key.GetChunk(index); - return !chunk.IsNull(); - }); - } - DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); - uint64_t* nans_begin = partitioner(indices_begin, nulls_begin, [&](uint64_t index) { - const auto chunk = first_sort_key.GetChunk(index); - return !std::isnan(chunk.Value()); - }); + const auto p = PartitionNullsOnly( + indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count, + NullPlacement::AtEnd); + DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); + + const auto q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, + NullPlacement::AtEnd); + auto& comparator = comparator_; // Sort all NaNs by the second and following sort keys. - std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) { + std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); // Sort all nulls by the second and following sort keys. 
- std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { + std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); - return nans_begin; + + return q; } template @@ -2282,8 +2366,9 @@ class TableSelecter : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - auto end_iter = + const auto p = this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); + const auto end_iter = p.non_nulls_end; auto kth_begin = std::min(indices_begin + k_, end_iter); HeapContainer heap(indices_begin, kth_begin, cmp); @@ -2405,10 +2490,12 @@ const auto kDefaultArraySortOptions = ArraySortOptions::Defaults(); const FunctionDoc array_sort_indices_doc( "Return the indices that would sort an array", ("This function computes an array of indices that define a stable sort\n" - "of the input array. Null values are considered greater than any\n" - "other value and are therefore sorted at the end of the array.\n" + "of the input array. 
By default, Null values are considered greater\n" + "than any other value and are therefore sorted at the end of the array.\n" "For floating-point types, NaNs are considered greater than any\n" - "other non-null value, but smaller than null values."), + "other non-null value, but smaller than null values.\n" + "\n" + "The handling of nulls and NaNs can be changed in ArraySortOptions."), {"array"}, "ArraySortOptions"); const FunctionDoc partition_nth_indices_doc( @@ -2420,12 +2507,13 @@ const FunctionDoc partition_nth_indices_doc( "of the input in sorted order, and all indices before the `N`'th point\n" "to elements in the input less or equal to elements at or after the `N`'th.\n" "\n" - "Null values are considered greater than any other value and are\n" - "therefore partitioned towards the end of the array.\n" + "By default, null values are considered greater than any other value\n" + "and are therefore partitioned towards the end of the array.\n" "For floating-point types, NaNs are considered greater than any\n" "other non-null value, but smaller than null values.\n" "\n" - "The pivot index `N` must be given in PartitionNthOptions."), + "The pivot index `N` must be given in PartitionNthOptions.\n" + "The handling of nulls and NaNs can also be changed in PartitionNthOptions."), {"array"}, "PartitionNthOptions"); } // namespace diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 54ac6e47485..d2206d4ebf0 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -15,9 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+#include #include #include #include +#include +#include #include #include @@ -25,12 +28,14 @@ #include "arrow/array/concatenate.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/result.h" #include "arrow/table.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/testing/util.h" #include "arrow/type_traits.h" +#include "arrow/util/logging.h" namespace arrow { @@ -38,6 +43,20 @@ using internal::checked_cast; using internal::checked_pointer_cast; namespace compute { + +std::vector AllOrders() { + return {SortOrder::Ascending, SortOrder::Descending}; +} + +std::vector AllNullPlacements() { + return {NullPlacement::AtEnd, NullPlacement::AtStart}; +} + +std::ostream& operator<<(std::ostream& os, NullPlacement null_placement) { + os << (null_placement == NullPlacement::AtEnd ? "AtEnd" : "AtStart"); + return os; +} + // ---------------------------------------------------------------------- // Tests for NthToIndices @@ -56,59 +75,89 @@ Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) { } template -class NthComparator { - public: - bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) { - if (array.IsNull(rhs)) return true; - if (array.IsNull(lhs)) return false; - const auto lval = GetLogicalValue(array, lhs); - const auto rval = GetLogicalValue(array, rhs); - if (is_floating_type::value) { - // NaNs ordered after non-NaNs - if (rval != rval) return true; - if (lval != lval) return false; - } - return lval <= rval; +struct ThreeWayComparator { + SortOrder order; + NullPlacement null_placement; + + int operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const { + return (*this)(array, array, lhs, rhs); } -}; -template -class SortComparator { - public: - bool operator()(const ArrayType& array, SortOrder order, uint64_t lhs, uint64_t rhs) { - if (array.IsNull(rhs) && array.IsNull(lhs)) return lhs < 
rhs; - if (array.IsNull(rhs)) return true; - if (array.IsNull(lhs)) return false; - const auto lval = GetLogicalValue(array, lhs); - const auto rval = GetLogicalValue(array, rhs); + // Return -1 if L < R, 0 if L == R, 1 if L > R + int operator()(const ArrayType& left, const ArrayType& right, uint64_t lhs, + uint64_t rhs) const { + const bool lhs_is_null = left.IsNull(lhs); + const bool rhs_is_null = right.IsNull(rhs); + if (lhs_is_null && rhs_is_null) return 0; + if (lhs_is_null) { + return null_placement == NullPlacement::AtStart ? -1 : 1; + } + if (rhs_is_null) { + return null_placement == NullPlacement::AtStart ? 1 : -1; + } + const auto lval = GetLogicalValue(left, lhs); + const auto rval = GetLogicalValue(right, rhs); if (is_floating_type::value) { const bool lhs_isnan = lval != lval; const bool rhs_isnan = rval != rval; - if (lhs_isnan && rhs_isnan) return lhs < rhs; - if (rhs_isnan) return true; - if (lhs_isnan) return false; + if (lhs_isnan && rhs_isnan) return 0; + if (lhs_isnan) { + return null_placement == NullPlacement::AtStart ? -1 : 1; + } + if (rhs_isnan) { + return null_placement == NullPlacement::AtStart ? 1 : -1; + } } - if (lval == rval) return lhs < rhs; - if (order == SortOrder::Ascending) { - return lval < rval; + if (lval == rval) return 0; + if (lval < rval) { + return order == SortOrder::Ascending ? -1 : 1; } else { - return lval > rval; + return order == SortOrder::Ascending ? 
1 : -1; } } }; +template +struct NthComparator { + ThreeWayComparator three_way; + + explicit NthComparator(NullPlacement null_placement) + : three_way({SortOrder::Ascending, null_placement}) {} + + // Return true iff L <= R + bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const { + // lhs <= rhs + return three_way(array, lhs, rhs) <= 0; + } +}; + +template +struct SortComparator { + ThreeWayComparator three_way; + + explicit SortComparator(SortOrder order, NullPlacement null_placement) + : three_way({order, null_placement}) {} + + bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) const { + const int r = three_way(array, lhs, rhs); + if (r != 0) return r < 0; + return lhs < rhs; + } +}; + template class TestNthToIndicesBase : public TestBase { using ArrayType = typename TypeTraits::ArrayType; protected: - void Validate(const ArrayType& array, int n, UInt64Array& offsets) { + void Validate(const ArrayType& array, int n, NullPlacement null_placement, + UInt64Array& offsets) { if (n >= array.length()) { for (int i = 0; i < array.length(); ++i) { - ASSERT_TRUE(offsets.Value(i) == (uint64_t)i); + ASSERT_TRUE(offsets.Value(i) == static_cast(i)); } } else { - NthComparator compare; + NthComparator compare{null_placement}; uint64_t nth = offsets.Value(n); for (int i = 0; i < n; ++i) { @@ -122,15 +171,24 @@ class TestNthToIndicesBase : public TestBase { } } - void AssertNthToIndicesArray(const std::shared_ptr values, int n) { - ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, NthToIndices(*values, n)); + void AssertNthToIndicesArray(const std::shared_ptr& values, int n, + NullPlacement null_placement) { + ARROW_SCOPED_TRACE("n = ", n, ", null_placement = ", null_placement); + ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, + NthToIndices(*values, PartitionNthOptions(n, null_placement))); // null_count field should have been initialized to 0, for convenience ASSERT_EQ(offsets->data()->null_count, 0); ValidateOutput(*offsets); - 
Validate(*checked_pointer_cast(values), n, + Validate(*checked_pointer_cast(values), n, null_placement, *checked_pointer_cast(offsets)); } + void AssertNthToIndicesArray(const std::shared_ptr& values, int n) { + for (auto null_placement : AllNullPlacements()) { + AssertNthToIndicesArray(values, n, null_placement); + } + } + void AssertNthToIndicesJson(const std::string& values, int n) { AssertNthToIndicesArray(ArrayFromJSON(GetType(), values), n); } @@ -192,6 +250,12 @@ TYPED_TEST(TestNthToIndicesForReal, Real) { this->AssertNthToIndicesJson("[null, 2, NaN, 3, 1]", 4); this->AssertNthToIndicesJson("[NaN, 2, null, 3, 1]", 3); this->AssertNthToIndicesJson("[NaN, 2, null, 3, 1]", 4); + + this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 0); + this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 1); + this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 2); + this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 3); + this->AssertNthToIndicesJson("[NaN, 2, NaN, 3, 1]", 4); } TYPED_TEST(TestNthToIndicesForIntegral, Integral) { @@ -263,8 +327,10 @@ TYPED_TEST(TestNthToIndicesRandom, RandomValues) { template void AssertSortIndices(const std::shared_ptr& input, SortOrder order, + NullPlacement null_placement, const std::shared_ptr& expected) { - ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, order)); + ArraySortOptions options(order, null_placement); + ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options)); ValidateOutput(*actual); AssertArraysEqual(*expected, *actual, /*verbose=*/true); } @@ -277,11 +343,22 @@ void AssertSortIndices(const std::shared_ptr& input, const SortOptions& optio AssertArraysEqual(*expected, *actual, /*verbose=*/true); } -// `Options` may be both SortOptions or SortOrder -template -void AssertSortIndices(const std::shared_ptr& input, Options&& options, +template +void AssertSortIndices(const std::shared_ptr& input, const SortOptions& options, const std::string& expected) { - AssertSortIndices(input, std::forward(options), + 
AssertSortIndices(input, options, ArrayFromJSON(uint64(), expected)); +} + +template +void AssertSortIndices(const std::shared_ptr& input, SortOrder order, + NullPlacement null_placement, const std::string& expected) { + AssertSortIndices(input, order, null_placement, ArrayFromJSON(uint64(), expected)); +} + +void AssertSortIndices(const std::shared_ptr& type, const std::string& values, + SortOrder order, NullPlacement null_placement, + const std::string& expected) { + AssertSortIndices(ArrayFromJSON(type, values), order, null_placement, ArrayFromJSON(uint64(), expected)); } @@ -290,14 +367,14 @@ class TestArraySortIndicesBase : public TestBase { virtual std::shared_ptr type() = 0; virtual void AssertSortIndices(const std::string& values, SortOrder order, + NullPlacement null_placement, const std::string& expected) { - auto type = this->type(); - arrow::compute::AssertSortIndices(ArrayFromJSON(type, values), order, - ArrayFromJSON(uint64(), expected)); + arrow::compute::AssertSortIndices(this->type(), values, order, null_placement, + expected); } virtual void AssertSortIndices(const std::string& values, const std::string& expected) { - AssertSortIndices(values, SortOrder::Ascending, expected); + AssertSortIndices(values, SortOrder::Ascending, NullPlacement::AtEnd, expected); } }; @@ -338,104 +415,194 @@ class TestArraySortIndicesForFixedSizeBinary : public TestArraySortIndicesBase { }; TYPED_TEST(TestArraySortIndicesForReal, SortReal) { - this->AssertSortIndices("[]", "[]"); - - this->AssertSortIndices("[3.4, 2.6, 6.3]", "[1, 0, 2]"); - this->AssertSortIndices("[1.1, 2.4, 3.5, 4.3, 5.1, 6.8, 7.3]", "[0, 1, 2, 3, 4, 5, 6]"); - this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", "[6, 5, 4, 3, 2, 1, 0]"); - this->AssertSortIndices("[10.4, 12, 4.2, 50, 50.3, 32, 11]", "[2, 0, 6, 1, 5, 3, 4]"); + for (auto null_placement : AllNullPlacements()) { + for (auto order : AllOrders()) { + this->AssertSortIndices("[]", order, null_placement, "[]"); + 
this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]"); + } + this->AssertSortIndices("[3.4, 2.6, 6.3]", SortOrder::Ascending, null_placement, + "[1, 0, 2]"); + this->AssertSortIndices("[1.1, 2.4, 3.5, 4.3, 5.1, 6.8, 7.3]", SortOrder::Ascending, + null_placement, "[0, 1, 2, 3, 4, 5, 6]"); + this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement, + "[6, 5, 4, 3, 2, 1, 0]"); + this->AssertSortIndices("[10.4, 12, 4.2, 50, 50.3, 32, 11]", SortOrder::Ascending, + null_placement, "[2, 0, 6, 1, 5, 3, 4]"); + } this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Ascending, - "[1, 4, 2, 5, 0, 3]"); + NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]"); + this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Ascending, + NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]"); this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Descending, - "[5, 2, 4, 1, 0, 3]"); + NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]"); + this->AssertSortIndices("[null, 1, 3.3, null, 2, 5.3]", SortOrder::Descending, + NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]"); this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Ascending, - "[3, 4, 0, 1, 2, 5]"); + NullPlacement::AtEnd, "[3, 4, 0, 1, 2, 5]"); + this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Ascending, + NullPlacement::AtStart, "[5, 2, 3, 4, 0, 1]"); + this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Descending, + NullPlacement::AtEnd, "[1, 0, 4, 3, 2, 5]"); this->AssertSortIndices("[3, 4, NaN, 1, 2, null]", SortOrder::Descending, - "[1, 0, 4, 3, 2, 5]"); - this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Ascending, "[4, 1, 3, 0, 2]"); + NullPlacement::AtStart, "[5, 2, 1, 0, 4, 3]"); + + this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Ascending, + NullPlacement::AtEnd, "[4, 1, 3, 0, 2]"); + this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Ascending, + NullPlacement::AtStart, "[0, 2, 4, 1, 3]"); + 
this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Descending, + NullPlacement::AtEnd, "[3, 1, 4, 0, 2]"); this->AssertSortIndices("[NaN, 2, NaN, 3, 1]", SortOrder::Descending, - "[3, 1, 4, 0, 2]"); - this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Ascending, "[1, 2, 0, 3]"); + NullPlacement::AtStart, "[0, 2, 3, 1, 4]"); + + this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Ascending, + NullPlacement::AtEnd, "[1, 2, 0, 3]"); + this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Ascending, + NullPlacement::AtStart, "[0, 3, 1, 2]"); + this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Descending, + NullPlacement::AtEnd, "[1, 2, 0, 3]"); this->AssertSortIndices("[null, NaN, NaN, null]", SortOrder::Descending, - "[1, 2, 0, 3]"); + NullPlacement::AtStart, "[0, 3, 1, 2]"); } TYPED_TEST(TestArraySortIndicesForIntegral, SortIntegral) { - this->AssertSortIndices("[]", "[]"); - - this->AssertSortIndices("[3, 2, 6]", "[1, 0, 2]"); - this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", "[0, 1, 2, 3, 4, 5, 6]"); - this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", "[6, 5, 4, 3, 2, 1, 0]"); - - this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending, - "[2, 0, 6, 1, 5, 3, 4]"); - this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Descending, - "[3, 4, 5, 1, 6, 0, 2]"); + for (auto null_placement : AllNullPlacements()) { + for (auto order : AllOrders()) { + this->AssertSortIndices("[]", order, null_placement, "[]"); + this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]"); + } + this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", SortOrder::Ascending, null_placement, + "[0, 1, 2, 3, 4, 5, 6]"); + this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement, + "[6, 5, 4, 3, 2, 1, 0]"); + + this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending, + null_placement, "[2, 0, 6, 1, 5, 3, 4]"); + this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", 
SortOrder::Descending, + null_placement, "[3, 4, 5, 1, 6, 0, 2]"); + } + // Values with a small range (use a counting sort) this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending, - "[1, 4, 2, 5, 0, 3]"); + NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]"); + this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending, + NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]"); this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending, - "[5, 2, 4, 1, 0, 3]"); + NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]"); + this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending, + NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]"); } TYPED_TEST(TestArraySortIndicesForBool, SortBool) { - this->AssertSortIndices("[]", "[]"); - - this->AssertSortIndices("[true, true, false]", "[2, 0, 1]"); - this->AssertSortIndices("[false, false, false, true, true, true, true]", - "[0, 1, 2, 3, 4, 5, 6]"); - this->AssertSortIndices("[true, true, true, true, false, false, false]", - "[4, 5, 6, 0, 1, 2, 3]"); - - this->AssertSortIndices("[false, true, false, true, true, false, false]", - SortOrder::Ascending, "[0, 2, 5, 6, 1, 3, 4]"); - this->AssertSortIndices("[false, true, false, true, true, false, false]", - SortOrder::Descending, "[1, 3, 4, 0, 2, 5, 6]"); + for (auto null_placement : AllNullPlacements()) { + for (auto order : AllOrders()) { + this->AssertSortIndices("[]", order, null_placement, "[]"); + this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]"); + } + this->AssertSortIndices("[true, true, false]", SortOrder::Ascending, null_placement, + "[2, 0, 1]"); + this->AssertSortIndices("[false, false, false, true, true, true, true]", + SortOrder::Ascending, null_placement, + "[0, 1, 2, 3, 4, 5, 6]"); + this->AssertSortIndices("[true, true, true, true, false, false, false]", + SortOrder::Ascending, null_placement, + "[4, 5, 6, 0, 1, 2, 3]"); + + this->AssertSortIndices("[false, true, false, true, true, false, false]", + 
SortOrder::Ascending, null_placement, + "[0, 2, 5, 6, 1, 3, 4]"); + this->AssertSortIndices("[false, true, false, true, true, false, false]", + SortOrder::Descending, null_placement, + "[1, 3, 4, 0, 2, 5, 6]"); + } this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Ascending, - "[2, 4, 1, 5, 0, 3]"); + NullPlacement::AtEnd, "[2, 4, 1, 5, 0, 3]"); + this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Ascending, + NullPlacement::AtStart, "[0, 3, 2, 4, 1, 5]"); this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Descending, - "[1, 5, 2, 4, 0, 3]"); + NullPlacement::AtEnd, "[1, 5, 2, 4, 0, 3]"); + this->AssertSortIndices("[null, true, false, null, false, true]", SortOrder::Descending, + NullPlacement::AtStart, "[0, 3, 1, 5, 2, 4]"); } TYPED_TEST(TestArraySortIndicesForTemporal, SortTemporal) { - this->AssertSortIndices("[]", "[]"); - - this->AssertSortIndices("[3, 2, 6]", "[1, 0, 2]"); - this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", "[0, 1, 2, 3, 4, 5, 6]"); - this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", "[6, 5, 4, 3, 2, 1, 0]"); - - this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending, - "[2, 0, 6, 1, 5, 3, 4]"); - this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Descending, - "[3, 4, 5, 1, 6, 0, 2]"); + for (auto null_placement : AllNullPlacements()) { + for (auto order : AllOrders()) { + this->AssertSortIndices("[]", order, null_placement, "[]"); + this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]"); + } + this->AssertSortIndices("[3, 2, 6]", SortOrder::Ascending, null_placement, + "[1, 0, 2]"); + this->AssertSortIndices("[1, 2, 3, 4, 5, 6, 7]", SortOrder::Ascending, null_placement, + "[0, 1, 2, 3, 4, 5, 6]"); + this->AssertSortIndices("[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement, + "[6, 5, 4, 3, 2, 1, 0]"); + + this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending, + null_placement, 
"[2, 0, 6, 1, 5, 3, 4]"); + this->AssertSortIndices("[10, 12, 4, 50, 50, 32, 11]", SortOrder::Descending, + null_placement, "[3, 4, 5, 1, 6, 0, 2]"); + } this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending, - "[1, 4, 2, 5, 0, 3]"); + NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]"); + this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Ascending, + NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]"); this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending, - "[5, 2, 4, 1, 0, 3]"); + NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]"); + this->AssertSortIndices("[null, 1, 3, null, 2, 5]", SortOrder::Descending, + NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]"); } TYPED_TEST(TestArraySortIndicesForStrings, SortStrings) { - this->AssertSortIndices("[]", "[]"); - - this->AssertSortIndices(R"(["a", "b", "c"])", "[0, 1, 2]"); - this->AssertSortIndices(R"(["foo", "bar", "baz"])", "[1,2,0]"); - this->AssertSortIndices(R"(["testing", "sort", "for", "strings"])", "[2, 1, 3, 0]"); + for (auto null_placement : AllNullPlacements()) { + for (auto order : AllOrders()) { + this->AssertSortIndices("[]", order, null_placement, "[]"); + this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]"); + } + this->AssertSortIndices(R"(["a", "b", "c"])", SortOrder::Ascending, null_placement, + "[0, 1, 2]"); + this->AssertSortIndices(R"(["foo", "bar", "baz"])", SortOrder::Ascending, + null_placement, "[1, 2, 0]"); + this->AssertSortIndices(R"(["testing", "sort", "for", "strings"])", + SortOrder::Ascending, null_placement, "[2, 1, 3, 0]"); + } - this->AssertSortIndices(R"(["c", "b", "a", "b"])", SortOrder::Ascending, - "[2, 1, 3, 0]"); - this->AssertSortIndices(R"(["c", "b", "a", "b"])", SortOrder::Descending, - "[0, 1, 3, 2]"); + const char* input = R"([null, "c", "b", null, "a", "b"])"; + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd, + "[4, 2, 5, 1, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Ascending, 
NullPlacement::AtStart, + "[0, 3, 4, 2, 5, 1]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd, + "[1, 2, 5, 4, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart, + "[0, 3, 1, 2, 5, 4]"); } TEST_F(TestArraySortIndicesForFixedSizeBinary, SortFixedSizeBinary) { - this->AssertSortIndices("[]", "[]"); + for (auto null_placement : AllNullPlacements()) { + for (auto order : AllOrders()) { + this->AssertSortIndices("[]", order, null_placement, "[]"); + this->AssertSortIndices("[null, null]", order, null_placement, "[0, 1]"); + } + this->AssertSortIndices(R"(["def", "abc", "ghi"])", SortOrder::Ascending, + null_placement, "[1, 0, 2]"); + this->AssertSortIndices(R"(["def", "abc", "ghi"])", SortOrder::Descending, + null_placement, "[2, 0, 1]"); + } - this->AssertSortIndices(R"(["def", "abc", "ghi"])", "[1, 0, 2]"); - this->AssertSortIndices(R"(["def", "abc", "ghi"])", SortOrder::Descending, "[2, 0, 1]"); + const char* input = R"([null, "ccc", "bbb", null, "aaa", "bbb"])"; + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd, + "[4, 2, 5, 1, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart, + "[0, 3, 4, 2, 5, 1]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd, + "[1, 2, 5, 4, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart, + "[0, 3, 1, 2, 5, 4]"); } template @@ -447,13 +614,46 @@ class TestArraySortIndicesForInt8 : public TestArraySortIndices {}; TYPED_TEST_SUITE(TestArraySortIndicesForInt8, Int8Type); TYPED_TEST(TestArraySortIndicesForUInt8, SortUInt8) { - this->AssertSortIndices("[255, null, 0, 255, 10, null, 128, 0]", + const char* input = "[255, null, 0, 255, 10, null, 128, 0]"; + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd, "[2, 7, 4, 6, 0, 3, 1, 5]"); + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart, + "[1, 5, 2, 
7, 4, 6, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd, + "[0, 3, 6, 4, 2, 7, 1, 5]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart, + "[1, 5, 0, 3, 6, 4, 2, 7]"); } TYPED_TEST(TestArraySortIndicesForInt8, SortInt8) { - this->AssertSortIndices("[null, 10, 127, 0, -128, -128, null]", - "[4, 5, 3, 1, 2, 0, 6]"); + const char* input = "[127, null, -128, 127, 0, null, 10, -128]"; + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd, + "[2, 7, 4, 6, 0, 3, 1, 5]"); + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart, + "[1, 5, 2, 7, 4, 6, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd, + "[0, 3, 6, 4, 2, 7, 1, 5]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart, + "[1, 5, 0, 3, 6, 4, 2, 7]"); +} + +template +class TestArraySortIndicesForInt64 : public TestArraySortIndices {}; +TYPED_TEST_SUITE(TestArraySortIndicesForInt64, Int64Type); + +TYPED_TEST(TestArraySortIndicesForInt64, SortInt64) { + // Values with a large range (use a comparison-based sort) + const char* input = + "[null, -2000000000000000, 3000000000000000," + " null, -1000000000000000, 5000000000000000]"; + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd, + "[1, 4, 2, 5, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart, + "[0, 3, 1, 4, 2, 5]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd, + "[5, 2, 4, 1, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart, + "[0, 3, 5, 2, 4, 1]"); } template @@ -464,10 +664,15 @@ class TestArraySortIndicesForDecimal : public TestArraySortIndicesBase { TYPED_TEST_SUITE(TestArraySortIndicesForDecimal, DecimalArrowTypes); TYPED_TEST(TestArraySortIndicesForDecimal, DecimalSortTestTypes) { - this->AssertSortIndices(R"(["123.45", null, "-123.45", "456.78", 
"-456.78"])", - "[4, 2, 0, 3, 1]"); - this->AssertSortIndices(R"(["123.45", null, "-123.45", "456.78", "-456.78"])", - SortOrder::Descending, "[3, 0, 2, 4, 1]"); + const char* input = R"(["123.45", null, "-123.45", "456.78", "-456.78", null])"; + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtEnd, + "[4, 2, 0, 3, 1, 5]"); + this->AssertSortIndices(input, SortOrder::Ascending, NullPlacement::AtStart, + "[1, 5, 4, 2, 0, 3]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtEnd, + "[3, 0, 2, 4, 1, 5]"); + this->AssertSortIndices(input, SortOrder::Descending, NullPlacement::AtStart, + "[1, 5, 3, 0, 2, 4]"); } TEST(TestArraySortIndices, TemporalTypeParameters) { @@ -482,28 +687,31 @@ TEST(TestArraySortIndices, TemporalTypeParameters) { types.push_back(time32(TimeUnit::MILLI)); types.push_back(time32(TimeUnit::SECOND)); for (const auto& ty : types) { - AssertSortIndices(ArrayFromJSON(ty, "[]"), SortOrder::Ascending, - ArrayFromJSON(uint64(), "[]")); - - AssertSortIndices(ArrayFromJSON(ty, "[3, 2, 6]"), SortOrder::Ascending, - ArrayFromJSON(uint64(), "[1, 0, 2]")); - AssertSortIndices(ArrayFromJSON(ty, "[1, 2, 3, 4, 5, 6, 7]"), SortOrder::Ascending, - ArrayFromJSON(uint64(), "[0, 1, 2, 3, 4, 5, 6]")); - AssertSortIndices(ArrayFromJSON(ty, "[7, 6, 5, 4, 3, 2, 1]"), SortOrder::Ascending, - ArrayFromJSON(uint64(), "[6, 5, 4, 3, 2, 1, 0]")); - - AssertSortIndices(ArrayFromJSON(ty, "[10, 12, 4, 50, 50, 32, 11]"), - SortOrder::Ascending, - ArrayFromJSON(uint64(), "[2, 0, 6, 1, 5, 3, 4]")); - AssertSortIndices(ArrayFromJSON(ty, "[10, 12, 4, 50, 50, 32, 11]"), - SortOrder::Descending, - ArrayFromJSON(uint64(), "[3, 4, 5, 1, 6, 0, 2]")); - - AssertSortIndices(ArrayFromJSON(ty, "[null, 1, 3, null, 2, 5]"), SortOrder::Ascending, - ArrayFromJSON(uint64(), "[1, 4, 2, 5, 0, 3]")); - AssertSortIndices(ArrayFromJSON(ty, "[null, 1, 3, null, 2, 5]"), - SortOrder::Descending, - ArrayFromJSON(uint64(), "[5, 2, 4, 1, 0, 3]")); + for (auto 
null_placement : AllNullPlacements()) { + for (auto order : AllOrders()) { + AssertSortIndices(ty, "[]", order, null_placement, "[]"); + AssertSortIndices(ty, "[null, null]", order, null_placement, "[0, 1]"); + } + AssertSortIndices(ty, "[3, 2, 6]", SortOrder::Ascending, null_placement, + "[1, 0, 2]"); + AssertSortIndices(ty, "[1, 2, 3, 4, 5, 6, 7]", SortOrder::Ascending, null_placement, + "[0, 1, 2, 3, 4, 5, 6]"); + AssertSortIndices(ty, "[7, 6, 5, 4, 3, 2, 1]", SortOrder::Ascending, null_placement, + "[6, 5, 4, 3, 2, 1, 0]"); + + AssertSortIndices(ty, "[10, 12, 4, 50, 50, 32, 11]", SortOrder::Ascending, + null_placement, "[2, 0, 6, 1, 5, 3, 4]"); + AssertSortIndices(ty, "[10, 12, 4, 50, 50, 32, 11]", SortOrder::Descending, + null_placement, "[3, 4, 5, 1, 6, 0, 2]"); + } + AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Ascending, + NullPlacement::AtEnd, "[1, 4, 2, 5, 0, 3]"); + AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Ascending, + NullPlacement::AtStart, "[0, 3, 1, 4, 2, 5]"); + AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Descending, + NullPlacement::AtEnd, "[5, 2, 4, 1, 0, 3]"); + AssertSortIndices(ty, "[null, 1, 3, null, 2, 5]", SortOrder::Descending, + NullPlacement::AtStart, "[0, 3, 5, 2, 4, 1]"); } } @@ -522,13 +730,14 @@ using SortIndicesableTypes = Decimal128Type, BooleanType>; template -void ValidateSorted(const ArrayType& array, UInt64Array& offsets, SortOrder order) { +void ValidateSorted(const ArrayType& array, UInt64Array& offsets, SortOrder order, + NullPlacement null_placement) { ValidateOutput(array); - SortComparator compare; + SortComparator compare{order, null_placement}; for (int i = 1; i < array.length(); i++) { uint64_t lhs = offsets.Value(i - 1); uint64_t rhs = offsets.Value(i); - ASSERT_TRUE(compare(array, order, lhs, rhs)); + ASSERT_TRUE(compare(array, lhs, rhs)); } } @@ -542,11 +751,16 @@ TYPED_TEST(TestArraySortIndicesRandom, SortRandomValues) { int length = 100; for (int test = 0; test < 
times; test++) { for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { - for (auto order : {SortOrder::Ascending, SortOrder::Descending}) { - auto array = rand.Generate(length, null_probability); - ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, SortIndices(*array, order)); - ValidateSorted(*checked_pointer_cast(array), - *checked_pointer_cast(offsets), order); + auto array = rand.Generate(length, null_probability); + for (auto order : AllOrders()) { + for (auto null_placement : AllNullPlacements()) { + ArraySortOptions options(order, null_placement); + ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, + SortIndices(*array, options)); + ValidateSorted(*checked_pointer_cast(array), + *checked_pointer_cast(offsets), order, + null_placement); + } } } } @@ -566,11 +780,16 @@ TYPED_TEST(TestArraySortIndicesRandomCount, SortRandomValuesCount) { int range = 2000; for (int test = 0; test < times; test++) { for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { - for (auto order : {SortOrder::Ascending, SortOrder::Descending}) { - auto array = rand.Generate(length, range, null_probability); - ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, SortIndices(*array, order)); - ValidateSorted(*checked_pointer_cast(array), - *checked_pointer_cast(offsets), order); + auto array = rand.Generate(length, range, null_probability); + for (auto order : AllOrders()) { + for (auto null_placement : AllNullPlacements()) { + ArraySortOptions options(order, null_placement); + ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, + SortIndices(*array, options)); + ValidateSorted(*checked_pointer_cast(array), + *checked_pointer_cast(offsets), order, + null_placement); + } } } } @@ -587,11 +806,16 @@ TYPED_TEST(TestArraySortIndicesRandomCompare, SortRandomValuesCompare) { int length = 100; for (int test = 0; test < times; test++) { for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { - for (auto order : {SortOrder::Ascending, SortOrder::Descending}) { - auto array = rand.Generate(length, null_probability); - 
ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, SortIndices(*array, order)); - ValidateSorted(*checked_pointer_cast(array), - *checked_pointer_cast(offsets), order); + auto array = rand.Generate(length, null_probability); + for (auto order : AllOrders()) { + for (auto null_placement : AllNullPlacements()) { + ArraySortOptions options(order, null_placement); + ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, + SortIndices(*array, options)); + ValidateSorted(*checked_pointer_cast(array), + *checked_pointer_cast(offsets), order, + null_placement); + } } } } @@ -606,8 +830,14 @@ TEST_F(TestChunkedArraySortIndices, Null) { "[3, null, 2]", "[1]", }); - AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 5, 4, 2, 0, 3]"); - AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 4, 1, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd, + "[1, 5, 4, 2, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart, + "[0, 3, 1, 5, 4, 2]"); + AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd, + "[2, 4, 1, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart, + "[0, 3, 2, 4, 1, 5]"); } TEST_F(TestChunkedArraySortIndices, NaN) { @@ -616,8 +846,14 @@ TEST_F(TestChunkedArraySortIndices, NaN) { "[3, null, NaN]", "[NaN, 1]", }); - AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 6, 2, 4, 5, 0, 3]"); - AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 1, 6, 4, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd, + "[1, 6, 2, 4, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart, + "[0, 3, 4, 5, 1, 6, 2]"); + AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd, + "[2, 1, 6, 4, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart, + "[0, 3, 4, 5, 2, 1, 6]"); } // Tests for temporal 
types @@ -635,8 +871,12 @@ TYPED_TEST(TestChunkedArraySortIndicesForTemporal, NoNull) { "[3, 2, 1]", "[5, 0]", }); - AssertSortIndices(chunked_array, SortOrder::Ascending, "[0, 6, 1, 4, 3, 2, 5]"); - AssertSortIndices(chunked_array, SortOrder::Descending, "[5, 2, 3, 1, 4, 0, 6]"); + for (auto null_placement : AllNullPlacements()) { + AssertSortIndices(chunked_array, SortOrder::Ascending, null_placement, + "[0, 6, 1, 4, 3, 2, 5]"); + AssertSortIndices(chunked_array, SortOrder::Descending, null_placement, + "[5, 2, 3, 1, 4, 0, 6]"); + } } // Tests for decimal types @@ -651,8 +891,14 @@ TYPED_TEST(TestChunkedArraySortIndicesForDecimal, Basics) { auto type = this->GetType(); auto chunked_array = ChunkedArrayFromJSON( type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", null])"}); - AssertSortIndices(chunked_array, SortOrder::Ascending, "[4, 1, 0, 3, 2, 5]"); - AssertSortIndices(chunked_array, SortOrder::Descending, "[3, 0, 1, 4, 2, 5]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtEnd, + "[4, 1, 0, 3, 2, 5]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, NullPlacement::AtStart, + "[2, 5, 4, 1, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtEnd, + "[3, 0, 1, 4, 2, 5]"); + AssertSortIndices(chunked_array, SortOrder::Descending, NullPlacement::AtStart, + "[2, 5, 3, 0, 1, 4]"); } // Base class for testing against random chunked array. @@ -665,21 +911,26 @@ class TestChunkedArrayRandomBase : public TestBase { // All tests uses this. void TestSortIndices(int length) { using ArrayType = typename TypeTraits::ArrayType; - // We can use INSTANTIATE_TEST_SUITE_P() instead of using fors in a test. 
+ for (auto null_probability : {0.0, 0.1, 0.5, 0.9, 1.0}) { - for (auto order : {SortOrder::Ascending, SortOrder::Descending}) { - for (auto num_chunks : {1, 2, 5, 10, 40}) { - std::vector> arrays; - for (int i = 0; i < num_chunks; ++i) { - auto array = this->GenerateArray(length / num_chunks, null_probability); - arrays.push_back(array); + for (auto num_chunks : {1, 2, 5, 10, 40}) { + std::vector> arrays; + for (int i = 0; i < num_chunks; ++i) { + auto array = this->GenerateArray(length / num_chunks, null_probability); + arrays.push_back(array); + } + ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(arrays)); + // Concatenate chunks to use existing ValidateSorted() for array. + ASSERT_OK_AND_ASSIGN(auto concatenated_array, Concatenate(arrays)); + + for (auto order : AllOrders()) { + for (auto null_placement : AllNullPlacements()) { + ArraySortOptions options(order, null_placement); + ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(*chunked_array, options)); + ValidateSorted( + *checked_pointer_cast(concatenated_array), + *checked_pointer_cast(offsets), order, null_placement); } - ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(arrays)); - ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(*chunked_array, order)); - // Concatenates chunks to use existing ValidateSorted() for array. 
- ASSERT_OK_AND_ASSIGN(auto concatenated_array, Concatenate(arrays)); - ValidateSorted(*checked_pointer_cast(concatenated_array), - *checked_pointer_cast(offsets), order); } } } @@ -739,9 +990,6 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { {field("a", uint8())}, {field("b", uint32())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - auto batch = RecordBatchFromJSON(schema, R"([{"a": 3, "b": 5}, {"a": 1, "b": 3}, @@ -751,7 +999,14 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { {"a": 1, "b": 5}, {"a": 1, "b": 3} ])"); - AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); + + for (auto null_placement : AllNullPlacements()) { + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}, + null_placement); + + AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); + } } TEST_F(TestRecordBatchSortIndices, Null) { @@ -759,18 +1014,22 @@ TEST_F(TestRecordBatchSortIndices, Null) { {field("a", uint8())}, {field("b", uint32())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - auto batch = RecordBatchFromJSON(schema, R"([{"a": null, "b": 5}, {"a": 1, "b": 3}, {"a": 3, "b": null}, {"a": null, "b": null}, {"a": 2, "b": 5}, - {"a": 1, "b": 5} + {"a": 1, "b": 5}, + {"a": 3, "b": 5} ])"); - AssertSortIndices(batch, options, "[5, 1, 4, 2, 0, 3]"); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + + SortOptions options(sort_keys, NullPlacement::AtEnd); + AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); } TEST_F(TestRecordBatchSortIndices, NaN) { @@ -778,9 +1037,6 @@ TEST_F(TestRecordBatchSortIndices, NaN) { {field("a", float32())}, {field("b", float64())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", 
SortOrder::Descending)}); - auto batch = RecordBatchFromJSON(schema, R"([{"a": 3, "b": 5}, {"a": 1, "b": NaN}, @@ -791,7 +1047,13 @@ TEST_F(TestRecordBatchSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(batch, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } TEST_F(TestRecordBatchSortIndices, NaNAndNull) { @@ -799,9 +1061,6 @@ TEST_F(TestRecordBatchSortIndices, NaNAndNull) { {field("a", float32())}, {field("b", float64())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - auto batch = RecordBatchFromJSON(schema, R"([{"a": null, "b": 5}, {"a": 1, "b": 3}, @@ -812,7 +1071,13 @@ TEST_F(TestRecordBatchSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } TEST_F(TestRecordBatchSortIndices, Boolean) { @@ -820,9 +1085,6 @@ TEST_F(TestRecordBatchSortIndices, Boolean) { {field("a", boolean())}, {field("b", boolean())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - auto batch = RecordBatchFromJSON(schema, R"([{"a": true, "b": null}, {"a": false, "b": null}, @@ -833,7 +1095,13 @@ TEST_F(TestRecordBatchSortIndices, Boolean) { {"a": false, "b": null}, {"a": null, "b": true} ])"); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + + SortOptions 
options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(batch, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } TEST_F(TestRecordBatchSortIndices, MoreTypes) { @@ -842,10 +1110,6 @@ TEST_F(TestRecordBatchSortIndices, MoreTypes) { {field("b", large_utf8())}, {field("c", fixed_size_binary(3))}, }); - SortOptions options({SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending), - SortKey("c", SortOrder::Ascending)}); - auto batch = RecordBatchFromJSON(schema, R"([{"a": 3, "b": "05", "c": "aaa"}, {"a": 1, "b": "031", "c": "bbb"}, @@ -854,7 +1118,14 @@ TEST_F(TestRecordBatchSortIndices, MoreTypes) { {"a": 2, "b": "05", "c": "aaa"}, {"a": 1, "b": "05", "c": "bbb"} ])"); - AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]"); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending), + SortKey("c", SortOrder::Ascending)}; + + for (auto null_placement : AllNullPlacements()) { + SortOptions options(sort_keys, null_placement); + AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]"); + } } TEST_F(TestRecordBatchSortIndices, Decimal) { @@ -862,9 +1133,6 @@ TEST_F(TestRecordBatchSortIndices, Decimal) { {field("a", decimal128(3, 1))}, {field("b", decimal256(4, 2))}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - auto batch = RecordBatchFromJSON(schema, R"([{"a": "12.3", "b": "12.34"}, {"a": "45.6", "b": "12.34"}, @@ -872,7 +1140,13 @@ TEST_F(TestRecordBatchSortIndices, Decimal) { {"a": "-12.3", "b": null}, {"a": "-12.3", "b": "-45.67"} ])"); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, 
"[3, 4, 0, 2, 1]"); } // Test basic cases for table. @@ -883,8 +1157,8 @@ TEST_F(TestTableSortIndices, Null) { {field("a", uint8())}, {field("b", uint32())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr table; table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, @@ -892,9 +1166,13 @@ TEST_F(TestTableSortIndices, Null) { {"a": 3, "b": null}, {"a": null, "b": null}, {"a": 2, "b": 5}, - {"a": 1, "b": 5} + {"a": 1, "b": 5}, + {"a": 3, "b": 5} ])"}); - AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]"); + SortOptions options(sort_keys, NullPlacement::AtEnd); + AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); // Same data, several chunks table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, @@ -903,9 +1181,13 @@ TEST_F(TestTableSortIndices, Null) { ])", R"([{"a": null, "b": null}, {"a": 2, "b": 5}, - {"a": 1, "b": 5} + {"a": 1, "b": 5}, + {"a": 3, "b": 5} ])"}); - AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]"); + options.null_placement = NullPlacement::AtEnd; + AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); } TEST_F(TestTableSortIndices, NaN) { @@ -913,9 +1195,10 @@ TEST_F(TestTableSortIndices, NaN) { {field("a", float32())}, {field("b", float64())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; + table = TableFromJSON(schema, {R"([{"a": 3, "b": 5}, {"a": 1, "b": NaN}, {"a": 3, "b": 4}, @@ -925,7 +1208,10 @@ TEST_F(TestTableSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); // Same data, several chunks table = TableFromJSON(schema, {R"([{"a": 3, "b": 5}, @@ -938,7 +1224,10 @@ TEST_F(TestTableSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); + options.null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } TEST_F(TestTableSortIndices, NaNAndNull) { @@ -946,9 +1235,10 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {field("a", float32())}, {field("b", float64())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, {"a": 1, "b": 3}, {"a": 3, "b": null}, @@ -958,7 +1248,10 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); // Same data, several chunks table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, @@ -971,6 +1264,7 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); + options.null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); } @@ -979,18 +1273,23 @@ TEST_F(TestTableSortIndices, Boolean) { {field("a", boolean())}, {field("b", boolean())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto table = TableFromJSON(schema, {R"([{"a": true, "b": null}, - {"a": false, "b": null}, - {"a": true, "b": true}, - {"a": false, "b": true}])", + {"a": false, "b": null}, + {"a": true, "b": true}, + {"a": false, "b": true} + ])", R"([{"a": true, "b": false}, - {"a": null, "b": false}, - {"a": false, "b": null}, - {"a": null, "b": true} - ])"}); + {"a": null, "b": false}, + {"a": false, "b": null}, + {"a": null, "b": true} + ])"}); + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(table, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } TEST_F(TestTableSortIndices, BinaryLike) { @@ -998,8 +1297,9 @@ TEST_F(TestTableSortIndices, BinaryLike) { {field("a", large_utf8())}, {field("b", fixed_size_binary(3))}, }); - SortOptions options( - {SortKey("a", SortOrder::Descending), 
SortKey("b", SortOrder::Ascending)}); + const std::vector sort_keys{SortKey("a", SortOrder::Descending), + SortKey("b", SortOrder::Ascending)}; + auto table = TableFromJSON(schema, {R"([{"a": "one", "b": null}, {"a": "two", "b": "aaa"}, {"a": "three", "b": "bbb"}, @@ -1010,7 +1310,10 @@ TEST_F(TestTableSortIndices, BinaryLike) { {"a": "three", "b": "bbb"}, {"a": "four", "b": "aaa"} ])"}); + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(table, options, "[1, 5, 2, 6, 4, 0, 7, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[1, 5, 2, 6, 0, 4, 7, 3]"); } TEST_F(TestTableSortIndices, Decimal) { @@ -1018,8 +1321,8 @@ TEST_F(TestTableSortIndices, Decimal) { {field("a", decimal128(3, 1))}, {field("b", decimal256(4, 2))}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema, {R"([{"a": "12.3", "b": "12.34"}, {"a": "45.6", "b": "12.34"}, @@ -1028,7 +1331,10 @@ TEST_F(TestTableSortIndices, Decimal) { R"([{"a": "-12.3", "b": null}, {"a": "-12.3", "b": "-45.67"} ])"}); + SortOptions options(sort_keys, NullPlacement::AtEnd); AssertSortIndices(table, options, "[4, 3, 0, 2, 1]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[3, 4, 0, 2, 1]"); } // Tests for temporal types @@ -1041,72 +1347,53 @@ TYPED_TEST_SUITE(TestTableSortIndicesForTemporal, TemporalArrowTypes); TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) { auto type = this->GetType(); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema({ {field("a", type)}, {field("b", type)}, }), - {"[" - "{\"a\": 0, \"b\": 5}," - "{\"a\": 1, \"b\": 3}," - "{\"a\": 3, \"b\": 0}," - "{\"a\": 2, \"b\": 1}," - "{\"a\": 1, \"b\": 3}," - "{\"a\": 5, 
\"b\": 0}," - "{\"a\": 0, \"b\": 4}," - "{\"a\": 1, \"b\": 2}" - "]"}); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]"); + {R"([{"a": 0, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": 0}, + {"a": 2, "b": 1}, + {"a": 1, "b": 3}, + {"a": 5, "b": 0}, + {"a": 0, "b": 4}, + {"a": 1, "b": 2} + ])"}); + for (auto null_placement : AllNullPlacements()) { + SortOptions options(sort_keys, null_placement); + AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]"); + } } // For random table tests. -using RandomParam = std::tuple; +using RandomParam = std::tuple; + class TestTableSortIndicesRandom : public testing::TestWithParam { - // Compares two records in the same table. - class Comparator : public TypeVisitor { + // Compares two records in a column + class ColumnComparator : public TypeVisitor { public: - Comparator(const Table& table, const SortOptions& options) : options_(options) { - for (const auto& sort_key : options_.sort_keys) { - sort_columns_.emplace_back(table.GetColumnByName(sort_key.name).get(), - sort_key.order); - } - } + ColumnComparator(SortOrder order, NullPlacement null_placement) + : order_(order), null_placement_(null_placement) {} - // Returns true if the left record is less or equals to the right - // record, false otherwise. - // - // This supports null and NaN. 
- bool operator()(uint64_t lhs, uint64_t rhs) { + int operator()(const Array& left, const Array& right, uint64_t lhs, uint64_t rhs) { + left_ = &left; + right_ = &right; lhs_ = lhs; rhs_ = rhs; - for (const auto& pair : sort_columns_) { - const auto& chunked_array = *pair.first; - lhs_array_ = FindTargetArray(chunked_array, lhs, &lhs_index_); - rhs_array_ = FindTargetArray(chunked_array, rhs, &rhs_index_); - if (rhs_array_->IsNull(rhs_index_) && lhs_array_->IsNull(lhs_index_)) continue; - if (rhs_array_->IsNull(rhs_index_)) return true; - if (lhs_array_->IsNull(lhs_index_)) return false; - status_ = lhs_array_->type()->Accept(this); - if (compared_ == 0) continue; - // If either value is NaN, it must sort after the other regardless of order - if (pair.second == SortOrder::Ascending || lhs_isnan_ || rhs_isnan_) { - return compared_ < 0; - } else { - return compared_ > 0; - } - } - return lhs < rhs; + ARROW_CHECK_OK(left.type()->Accept(this)); + return compared_; } - Status status() const { return status_; } - #define VISIT(TYPE) \ Status Visit(const TYPE##Type& type) override { \ compared_ = CompareType(); \ return Status::OK(); \ } + VISIT(Boolean) VISIT(Int8) VISIT(Int16) VISIT(Int32) @@ -1118,12 +1405,57 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { VISIT(Float) VISIT(Double) VISIT(String) + VISIT(LargeString) VISIT(Decimal128) + VISIT(Decimal256) #undef VISIT - private: - // Finds the target chunk and index in the target chunk from an + template + int CompareType() { + using ArrayType = typename TypeTraits::ArrayType; + ThreeWayComparator three_way{order_, null_placement_}; + return three_way(checked_cast(*left_), + checked_cast(*right_), lhs_, rhs_); + } + + const SortOrder order_; + const NullPlacement null_placement_; + const Array* left_; + const Array* right_; + uint64_t lhs_; + uint64_t rhs_; + int compared_; + }; + + // Compares two records in the same table. 
+ class Comparator { + public: + Comparator(const Table& table, const SortOptions& options) : options_(options) { + for (const auto& sort_key : options_.sort_keys) { + sort_columns_.emplace_back(table.GetColumnByName(sort_key.name).get(), + sort_key.order); + } + } + + // Return true if the left record is less or equals to the right record, + // false otherwise. + bool operator()(uint64_t lhs, uint64_t rhs) { + for (const auto& pair : sort_columns_) { + ColumnComparator comparator(pair.second, options_.null_placement); + const auto& chunked_array = *pair.first; + int64_t lhs_index, rhs_index; + const Array* lhs_array = FindTargetArray(chunked_array, lhs, &lhs_index); + const Array* rhs_array = FindTargetArray(chunked_array, rhs, &rhs_index); + int compared = comparator(*lhs_array, *rhs_array, lhs_index, rhs_index); + if (compared != 0) { + return compared < 0; + } + } + return lhs < rhs; + } + + // Find the target chunk and index in the target chunk from an // index in chunked array. const Array* FindTargetArray(const ChunkedArray& chunked_array, int64_t i, int64_t* chunk_index) { @@ -1138,117 +1470,127 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { return nullptr; } - // Compares two values in the same chunked array. Values are never - // null but may be NaN. - // - // Returns true if the left value is less or equals to the right - // value, false otherwise. 
- template - int CompareType() { - using ArrayType = typename TypeTraits::ArrayType; - auto lhs_value = - GetLogicalValue(checked_cast(*lhs_array_), lhs_index_); - auto rhs_value = - GetLogicalValue(checked_cast(*rhs_array_), rhs_index_); - if (is_floating_type::value) { - lhs_isnan_ = lhs_value != lhs_value; - rhs_isnan_ = rhs_value != rhs_value; - if (lhs_isnan_ && rhs_isnan_) return 0; - // NaN is considered greater than non-NaN - if (rhs_isnan_) return -1; - if (lhs_isnan_) return 1; - } else { - lhs_isnan_ = rhs_isnan_ = false; - } - if (lhs_value == rhs_value) { - return 0; - } else if (lhs_value > rhs_value) { - return 1; - } else { - return -1; - } - } - const SortOptions& options_; std::vector> sort_columns_; - int64_t lhs_; - const Array* lhs_array_; - int64_t lhs_index_; - int64_t rhs_; - const Array* rhs_array_; - int64_t rhs_index_; - bool lhs_isnan_, rhs_isnan_; - int compared_; - Status status_; }; public: - // Validates the sorted indexes are really sorted. + // Validates the sorted indices are really sorted. 
void Validate(const Table& table, const SortOptions& options, UInt64Array& offsets) { ValidateOutput(offsets); Comparator comparator{table, options}; for (int i = 1; i < table.num_rows(); i++) { uint64_t lhs = offsets.Value(i - 1); uint64_t rhs = offsets.Value(i); - ASSERT_OK(comparator.status()); - ASSERT_TRUE(comparator(lhs, rhs)) << "lhs = " << lhs << ", rhs = " << rhs; + if (!comparator(lhs, rhs)) { + std::stringstream ss; + ss << "Rows not ordered at consecutive sort indices:"; + ss << "\nFirst row (index = " << lhs << "): "; + PrintRow(table, lhs, &ss); + ss << "\nSecond row (index = " << rhs << "): "; + PrintRow(table, rhs, &ss); + FAIL() << ss.str(); + } + } + } + + void PrintRow(const Table& table, uint64_t index, std::ostream* os) { + *os << "{"; + const auto& columns = table.columns(); + for (size_t i = 0; i < columns.size(); ++i) { + if (i != 0) { + *os << ", "; + } + ASSERT_OK_AND_ASSIGN(auto scal, columns[i]->GetScalar(index)); + *os << scal->ToString(); } + *os << "}"; } }; TEST_P(TestTableSortIndicesRandom, Sort) { const auto first_sort_key_name = std::get<0>(GetParam()); - const auto null_probability = std::get<1>(GetParam()); + const auto n_sort_keys = std::get<1>(GetParam()); + const auto null_probability = std::get<2>(GetParam()); + const auto nan_probability = (1.0 - null_probability) / 4; const auto seed = 0x61549225; + ARROW_SCOPED_TRACE("n_sort_keys = ", n_sort_keys); + ARROW_SCOPED_TRACE("null_probability = ", null_probability); + + ::arrow::random::RandomArrayGenerator rng(seed); + + // Of these, "uint8", "boolean" and "string" should have many duplicates const FieldVector fields = { - {field("uint8", uint8())}, {field("uint16", uint16())}, - {field("uint32", uint32())}, {field("uint64", uint64())}, - {field("int8", int8())}, {field("int16", int16())}, - {field("int32", int32())}, {field("int64", int64())}, - {field("float", float32())}, {field("double", float64())}, - {field("string", utf8())}, {field("decimal128", decimal128(18, 3))}, + 
{field("uint8", uint8())}, + {field("int16", int16())}, + {field("int32", int32())}, + {field("uint64", uint64())}, + {field("float", float32())}, + {field("boolean", boolean())}, + {field("string", utf8())}, + {field("large_string", large_utf8())}, + {field("decimal128", decimal128(25, 3))}, + {field("decimal256", decimal256(42, 6))}, }; const auto length = 200; ArrayVector columns = { - Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, 0.0), - Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, 0.0), - Random(seed).Generate(length, 0.0), - Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, 0.0), - Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, null_probability, 1 - null_probability), - Random(seed).Generate(length, 0.0, null_probability), - Random(seed).Generate(length, null_probability), - Random(seed, fields[11]->type()).Generate(length, null_probability), + rng.UInt8(length, 0, 10, null_probability), + rng.Int16(length, -1000, 12000, /*null_probability=*/0.0), + rng.Int32(length, -123456789, 987654321, null_probability), + rng.UInt64(length, 1, 1234567890123456789ULL, /*null_probability=*/0.0), + rng.Float32(length, -1.0f, 1.0f, null_probability, nan_probability), + rng.Boolean(length, /*true_probability=*/0.3, null_probability), + rng.StringWithRepeats(length, /*unique=*/length / 10, /*min_length=*/5, + /*max_length=*/15, null_probability), + rng.LargeString(length, /*min_length=*/5, /*max_length=*/15, + /*null_probability=*/0.0), + rng.Decimal128(fields[8]->type(), length, null_probability), + rng.Decimal256(fields[9]->type(), length, /*null_probability=*/0.0), }; const auto table = Table::Make(schema(fields), columns, length); - // Generate random sort keys + // Generate random sort keys, making sure no column is included twice std::default_random_engine engine(seed); std::uniform_int_distribution<> distribution(0); - const auto 
n_sort_keys = 7; + + auto generate_order = [&]() { + return (distribution(engine) & 1) ? SortOrder::Ascending : SortOrder::Descending; + }; + std::vector sort_keys; - const auto first_sort_key_order = - (distribution(engine) % 2) == 0 ? SortOrder::Ascending : SortOrder::Descending; - sort_keys.emplace_back(first_sort_key_name, first_sort_key_order); - for (int i = 1; i < n_sort_keys; ++i) { - const auto& field = *fields[distribution(engine) % fields.size()]; - const auto order = - (distribution(engine) % 2) == 0 ? SortOrder::Ascending : SortOrder::Descending; - sort_keys.emplace_back(field.name(), order); + sort_keys.reserve(fields.size()); + for (const auto& field : fields) { + if (field->name() != first_sort_key_name) { + sort_keys.emplace_back(field->name(), generate_order()); + } + } + std::shuffle(sort_keys.begin(), sort_keys.end(), engine); + sort_keys.emplace(sort_keys.begin(), first_sort_key_name, generate_order()); + sort_keys.erase(sort_keys.begin() + n_sort_keys, sort_keys.end()); + ASSERT_EQ(sort_keys.size(), n_sort_keys); + + std::stringstream ss; + for (const auto& sort_key : sort_keys) { + ss << sort_key.name << (sort_key.order == SortOrder::Ascending ? 
" ASC" : " DESC"); + ss << ", "; } + ARROW_SCOPED_TRACE("sort_keys = ", ss.str()); + SortOptions options(sort_keys); // Test with different table chunkings for (const int64_t num_chunks : {1, 2, 20}) { + ARROW_SCOPED_TRACE("Table sorting: num_chunks = ", num_chunks); TableBatchReader reader(*table); reader.set_chunksize((length + num_chunks - 1) / num_chunks); ASSERT_OK_AND_ASSIGN(auto chunked_table, Table::FromRecordBatchReader(&reader)); - ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*chunked_table), options)); - Validate(*table, options, *checked_pointer_cast(offsets)); + for (auto null_placement : AllNullPlacements()) { + ARROW_SCOPED_TRACE("null_placement = ", null_placement); + options.null_placement = null_placement; + ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*chunked_table), options)); + Validate(*table, options, *checked_pointer_cast(offsets)); + } } // Also validate RecordBatch sorting @@ -1256,22 +1598,33 @@ TEST_P(TestTableSortIndicesRandom, Sort) { RecordBatchVector batches; ASSERT_OK(reader.ReadAll(&batches)); ASSERT_EQ(batches.size(), 1); - ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*batches[0]), options)); - Validate(*table, options, *checked_pointer_cast(offsets)); + ARROW_SCOPED_TRACE("Record batch sorting"); + for (auto null_placement : AllNullPlacements()) { + ARROW_SCOPED_TRACE("null_placement = ", null_placement); + options.null_placement = null_placement; + ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*batches[0]), options)); + Validate(*table, options, *checked_pointer_cast(offsets)); + } } -static const auto first_sort_keys = - testing::Values("uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", - "int64", "float", "double", "string", "decimal128"); +// Some first keys will have duplicates, others not +static const auto first_sort_keys = testing::Values("uint8", "int16", "uint64", "float", + "boolean", "string", "decimal128"); + +// Different numbers of sort keys may trigger different 
algorithms +static const auto num_sort_keys = testing::Values(1, 3, 7, 9); INSTANTIATE_TEST_SUITE_P(NoNull, TestTableSortIndicesRandom, - testing::Combine(first_sort_keys, testing::Values(0.0))); + testing::Combine(first_sort_keys, num_sort_keys, + testing::Values(0.0))); -INSTANTIATE_TEST_SUITE_P(MayNull, TestTableSortIndicesRandom, - testing::Combine(first_sort_keys, testing::Values(0.1, 0.5))); +INSTANTIATE_TEST_SUITE_P(SomeNulls, TestTableSortIndicesRandom, + testing::Combine(first_sort_keys, num_sort_keys, + testing::Values(0.1, 0.5))); INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom, - testing::Combine(first_sort_keys, testing::Values(1.0))); + testing::Combine(first_sort_keys, num_sort_keys, + testing::Values(1.0))); } // namespace compute } // namespace arrow diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index b10c7a120b2..bd115c917d8 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1448,21 +1448,26 @@ These functions select and return a subset of their input. Sorts and partitions ~~~~~~~~~~~~~~~~~~~~ -In these functions, nulls are considered greater than any other value -(they will be sorted or partitioned at the end of the array). -Floating-point NaN values are considered greater than any other non-null -value, but smaller than nulls. +By default, in these functions, nulls are considered greater than any other value +(they will be sorted or partitioned at the end of the array). Floating-point +NaN values are considered greater than any other non-null value, but smaller +than nulls. This behaviour can be changed using the ``null_placement`` setting +in the respective option classes. + +.. note:: + Binary- and String-like inputs are ordered lexicographically as bytestrings, + even for String types. 
+-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ | Function name | Arity | Input types | Output type | Options class | Notes | +=======================+============+=========================================================+===================+================================+================+ -| partition_nth_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`PartitionNthOptions` | \(1) \(3) | +| partition_nth_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`PartitionNthOptions` | \(1) | +-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ -| array_sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`ArraySortOptions` | \(2) \(4) \(3) | +| array_sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`ArraySortOptions` | \(2) \(3) | +-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ -| select_k_unstable | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SelectKOptions` | \(5) \(6) \(3) | +| select_k_unstable | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SelectKOptions` | \(4) \(5) | +-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ -| sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SortOptions` | \(2) \(5) \(3) | +| sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SortOptions` | \(2) \(4) | 
+-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ * \(1) The output is an array of indices into the input array, that define @@ -1475,16 +1480,13 @@ value, but smaller than nulls. * \(2) The output is an array of indices into the input, that define a stable sort of the input. -* \(3) Input values are ordered lexicographically as bytestrings (even - for String arrays). - -* \(4) The input must be an array. The default order is ascending. +* \(3) The input must be an array. The default order is ascending. -* \(5) The input can be an array, chunked array, record batch or +* \(4) The input can be an array, chunked array, record batch or table. If the input is a record batch or table, one or more sort keys must be specified. -* \(6) The output is an array of indices into the input, that define a +* \(5) The output is an array of indices into the input, that define a non-stable sort of the input. .. 
_cpp-compute-vector-structural-transforms: diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 33f3cbca685..44afd4ba0bf 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -897,16 +897,6 @@ class TakeOptions(_TakeOptions): self._set_options(boundscheck) -cdef class _PartitionNthOptions(FunctionOptions): - def _set_options(self, pivot): - self.wrapped.reset(new CPartitionNthOptions(pivot)) - - -class PartitionNthOptions(_PartitionNthOptions): - def __init__(self, pivot): - self._set_options(pivot) - - cdef class _MakeStructOptions(FunctionOptions): def _set_options(self, field_names, field_nullability, field_metadata): cdef: @@ -1140,29 +1130,50 @@ cdef CSortOrder unwrap_sort_order(order) except *: _raise_invalid_function_option(order, "sort order") +cdef CNullPlacement unwrap_null_placement(null_placement) except *: + if null_placement == "at_start": + return CNullPlacement_AtStart + elif null_placement == "at_end": + return CNullPlacement_AtEnd + _raise_invalid_function_option(null_placement, "null placement") + + +cdef class _PartitionNthOptions(FunctionOptions): + def _set_options(self, pivot, null_placement): + self.wrapped.reset(new CPartitionNthOptions( + pivot, unwrap_null_placement(null_placement))) + + +class PartitionNthOptions(_PartitionNthOptions): + def __init__(self, pivot, *, null_placement="at_end"): + self._set_options(pivot, null_placement) + + cdef class _ArraySortOptions(FunctionOptions): - def _set_options(self, order): - self.wrapped.reset(new CArraySortOptions(unwrap_sort_order(order))) + def _set_options(self, order, null_placement): + self.wrapped.reset(new CArraySortOptions( + unwrap_sort_order(order), unwrap_null_placement(null_placement))) class ArraySortOptions(_ArraySortOptions): - def __init__(self, order="ascending"): - self._set_options(order) + def __init__(self, order="ascending", *, null_placement="at_end"): + self._set_options(order, null_placement) cdef class 
_SortOptions(FunctionOptions): - def _set_options(self, sort_keys): + def _set_options(self, sort_keys, null_placement): cdef vector[CSortKey] c_sort_keys for name, order in sort_keys: c_sort_keys.push_back( CSortKey(tobytes(name), unwrap_sort_order(order)) ) - self.wrapped.reset(new CSortOptions(c_sort_keys)) + self.wrapped.reset(new CSortOptions( + c_sort_keys, unwrap_null_placement(null_placement))) class SortOptions(_SortOptions): - def __init__(self, sort_keys): - self._set_options(sort_keys) + def __init__(self, sort_keys, *, null_placement="at_end"): + self._set_options(sort_keys, null_placement) cdef class _SelectKOptions(FunctionOptions): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index f45b07a2e43..7cb740d708f 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2069,11 +2069,6 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CIndexOptions(shared_ptr[CScalar] value) shared_ptr[CScalar] value - cdef cppclass CPartitionNthOptions \ - "arrow::compute::PartitionNthOptions"(CFunctionOptions): - CPartitionNthOptions(int64_t pivot) - int64_t pivot - cdef cppclass CMakeStructOptions \ "arrow::compute::MakeStructOptions"(CFunctionOptions): CMakeStructOptions(vector[c_string] n, @@ -2090,10 +2085,23 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CSortOrder_Descending \ "arrow::compute::SortOrder::Descending" + ctypedef enum CNullPlacement" arrow::compute::NullPlacement": + CNullPlacement_AtStart \ + "arrow::compute::NullPlacement::AtStart" + CNullPlacement_AtEnd \ + "arrow::compute::NullPlacement::AtEnd" + + cdef cppclass CPartitionNthOptions \ + "arrow::compute::PartitionNthOptions"(CFunctionOptions): + CPartitionNthOptions(int64_t pivot, CNullPlacement) + int64_t pivot + CNullPlacement null_placement + cdef cppclass CArraySortOptions \ "arrow::compute::ArraySortOptions"(CFunctionOptions): - CArraySortOptions(CSortOrder 
order) + CArraySortOptions(CSortOrder, CNullPlacement) CSortOrder order + CNullPlacement null_placement cdef cppclass CSortKey" arrow::compute::SortKey": CSortKey(c_string name, CSortOrder order) @@ -2102,8 +2110,9 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CSortOptions \ "arrow::compute::SortOptions"(CFunctionOptions): - CSortOptions(vector[CSortKey] sort_keys) + CSortOptions(vector[CSortKey] sort_keys, CNullPlacement) vector[CSortKey] sort_keys + CNullPlacement null_placement cdef cppclass CSelectKOptions \ "arrow::compute::SelectKOptions"(CFunctionOptions): diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index a4538046724..75501e11192 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -129,7 +129,7 @@ def test_option_class_equality(): pc.ModeOptions(), pc.NullOptions(), pc.PadOptions(5), - pc.PartitionNthOptions(1), + pc.PartitionNthOptions(1, null_placement="at_start"), pc.QuantileOptions(), pc.ReplaceSliceOptions(0, 1, "a"), pc.ReplaceSubstringOptions("a", "b"), @@ -139,7 +139,7 @@ def test_option_class_equality(): pc.SelectKOptions(0, sort_keys=[("b", "ascending")]), pc.SetLookupOptions(pa.array([1])), pc.SliceOptions(0, 1, 1), - pc.SortOptions([("dummy", "descending")]), + pc.SortOptions([("dummy", "descending")], null_placement="at_start"), pc.SplitOptions(), pc.SplitPatternOptions("pattern"), pc.StrftimeOptions(), @@ -174,7 +174,8 @@ def test_option_class_equality(): assert option1 != option2 assert repr(pc.IndexOptions(pa.scalar(1))) == "IndexOptions(value=int64:1)" - assert repr(pc.ArraySortOptions()) == "ArraySortOptions(order=Ascending)" + assert repr(pc.ArraySortOptions()) == \ + "ArraySortOptions(order=Ascending, null_placement=AtEnd)" def test_list_functions(): @@ -1847,17 +1848,44 @@ def test_index(): assert arr.index(1, start=1, end=2).as_py() == -1 +def check_partition_nth(data, indices, pivot, null_placement): + 
indices = indices.to_pylist() + assert len(indices) == len(data) + assert sorted(indices) == list(range(len(data))) + until_pivot = [data[indices[i]] for i in range(pivot)] + after_pivot = [data[indices[i]] for i in range(pivot, len(data))] + p = data[indices[pivot]] + if p is None: + if null_placement == "at_start": + assert all(v is None for v in until_pivot) + else: + assert all(v is None for v in after_pivot) + else: + if null_placement == "at_start": + assert all(v is None or v <= p for v in until_pivot) + assert all(v >= p for v in after_pivot) + else: + assert all(v <= p for v in until_pivot) + assert all(v is None or v >= p for v in after_pivot) + + def test_partition_nth(): data = list(range(100, 140)) random.shuffle(data) pivot = 10 - indices = pc.partition_nth_indices(data, pivot=pivot).to_pylist() - assert len(indices) == len(data) - assert sorted(indices) == list(range(len(data))) - assert all(data[indices[i]] <= data[indices[pivot]] - for i in range(pivot)) - assert all(data[indices[i]] >= data[indices[pivot]] - for i in range(pivot, len(data))) + indices = pc.partition_nth_indices(data, pivot=pivot) + check_partition_nth(data, indices, pivot, "at_end") + + +def test_partition_nth_null_placement(): + data = list(range(10)) + [None] * 10 + random.shuffle(data) + + for pivot in (0, 7, 13, 19): + for null_placement in ("at_start", "at_end"): + indices = pc.partition_nth_indices(data, pivot=pivot, + null_placement=null_placement) + check_partition_nth(data, indices, pivot, null_placement) def test_select_k_array(): @@ -1949,6 +1977,9 @@ def test_array_sort_indices(): assert result.to_pylist() == [3, 0, 1, 2] result = pc.array_sort_indices(arr, order="descending") assert result.to_pylist() == [1, 0, 3, 2] + result = pc.array_sort_indices(arr, order="descending", + null_placement="at_start") + assert result.to_pylist() == [2, 1, 0, 3] with pytest.raises(ValueError, match="not a valid sort order"): pc.array_sort_indices(arr, order="nonscending") @@ -1962,22 
+1993,39 @@ def test_sort_indices_array(): assert result.to_pylist() == [3, 0, 1, 2] result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")]) assert result.to_pylist() == [1, 0, 3, 2] + result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")], + null_placement="at_start") + assert result.to_pylist() == [2, 1, 0, 3] + # Using SortOptions result = pc.sort_indices( arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")]) ) assert result.to_pylist() == [1, 0, 3, 2] + result = pc.sort_indices( + arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")], + null_placement="at_start") + ) + assert result.to_pylist() == [2, 1, 0, 3] def test_sort_indices_table(): - table = pa.table({"a": [1, 1, 0], "b": [1, 0, 1]}) + table = pa.table({"a": [1, 1, None, 0], "b": [1, 0, 0, 1]}) result = pc.sort_indices(table, sort_keys=[("a", "ascending")]) - assert result.to_pylist() == [2, 0, 1] + assert result.to_pylist() == [3, 0, 1, 2] + result = pc.sort_indices(table, sort_keys=[("a", "ascending")], + null_placement="at_start") + assert result.to_pylist() == [2, 3, 0, 1] result = pc.sort_indices( - table, sort_keys=[("a", "ascending"), ("b", "ascending")] + table, sort_keys=[("a", "descending"), ("b", "ascending")] + ) + assert result.to_pylist() == [1, 0, 3, 2] + result = pc.sort_indices( + table, sort_keys=[("a", "descending"), ("b", "ascending")], + null_placement="at_start" ) - assert result.to_pylist() == [2, 1, 0] + assert result.to_pylist() == [2, 1, 0, 3] with pytest.raises(ValueError, match="Must specify one or more sort keys"): pc.sort_indices(table)