diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index e1e409d0a7d..3517d0cf041 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -410,6 +410,8 @@ if(ARROW_COMPUTE) compute/exec/tpch_node.cc compute/exec/union_node.cc compute/exec/util.cc + compute/exec/window_functions/merge_tree.cc + compute/exec/window_functions/window_rank.cc compute/function.cc compute/function_internal.cc compute/kernel.cc diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index 4ce73359d0f..d83258a9722 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -45,6 +45,11 @@ add_arrow_compute_test(util_test SOURCES util_test.cc task_util_test.cc) +add_arrow_compute_test(window_functions_test + PREFIX + "arrow-compute" + SOURCES + window_functions/window_test.cc) add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 8f6b3de2e1f..0c332d21774 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -167,6 +167,20 @@ class TempVectorHolder { uint32_t num_elements_; }; +#define TEMP_VECTOR(type, name) \ + auto name##_buf = arrow::util::TempVectorHolder( \ + temp_vector_stack, arrow::util::MiniBatch::kMiniBatchLength); \ + auto name = name##_buf.mutable_data(); + +#define BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, num_rows) \ + for (int64_t batch_begin = 0; batch_begin < num_rows; \ + batch_begin += arrow::util::MiniBatch::kMiniBatchLength) { \ + int64_t batch_length = \ + std::min(static_cast(num_rows) - batch_begin, \ + static_cast(arrow::util::MiniBatch::kMiniBatchLength)); + +#define END_MINI_BATCH_FOR } + class bit_util { public: static void bits_to_indexes(int bit_to_search, int64_t hardware_flags, @@ -365,13 +379,14 @@ struct ARROW_EXPORT TableSinkNodeConsumer : public SinkNodeConsumer { /// Modify an Expression with pre-order and post-order visitation. /// `pre` will be invoked on each Expression. `pre` will visit Calls before their /// arguments, `post_call` will visit Calls (and no other Expressions) after their -/// arguments. Visitors should return the Identical expression to indicate no change; this -/// will prevent unnecessary construction in the common case where a modification is not -/// possible/necessary/... +/// arguments. Visitors should return the Identical expression to indicate no change; +/// this will prevent unnecessary construction in the common case where a modification +/// is not possible/necessary/... /// -/// If an argument was modified, `post_call` visits a reconstructed Call with the modified -/// arguments but also receives a pointer to the unmodified Expression as a second -/// argument. If no arguments were modified the unmodified Expression* will be nullptr. +/// If an argument was modified, `post_call` visits a reconstructed Call with the +/// modified arguments but also receives a pointer to the unmodified Expression as a +/// second argument. If no arguments were modified the unmodified Expression* will be +/// nullptr. 
template Result ModifyExpression(Expression expr, const PreVisit& pre, const PostVisitCall& post_call) { @@ -409,5 +424,38 @@ Result ModifyExpression(Expression expr, const PreVisit& pre, return post_call(std::move(expr), NULLPTR); } +struct ThreadContext { + int64_t thread_index; + util::TempVectorStack* temp_vector_stack; + int64_t hardware_flags; +}; + +struct ParallelForStream { + using TaskCallback = std::function; + + void InsertParallelFor(int64_t num_tasks, TaskCallback task_callback) { + parallel_fors_.push_back(std::make_pair(num_tasks, task_callback)); + } + + void InsertTaskSingle(TaskCallback task_callback) { + parallel_fors_.push_back(std::make_pair(static_cast(1), task_callback)); + } + + // If any of the tasks returns an error status then all the remaining parallel + // fors in the stream will not be executed and the first error status within + // the failing parallel for loop step will be returned. + // + Status RunOnSingleThread(ThreadContext& thread_context) { + for (size_t i = 0; i < parallel_fors_.size(); ++i) { + for (int64_t j = 0; j < parallel_fors_[i].first; ++j) { + ARROW_RETURN_NOT_OK(parallel_fors_[i].second(j, thread_context)); + } + } + return Status::OK(); + } + + std::vector> parallel_fors_; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/window_functions/bit_vector_navigator.h b/cpp/src/arrow/compute/exec/window_functions/bit_vector_navigator.h new file mode 100644 index 00000000000..f5013e19e65 --- /dev/null +++ b/cpp/src/arrow/compute/exec/window_functions/bit_vector_navigator.h @@ -0,0 +1,513 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include "arrow/compute/exec/util.h" +#include "arrow/util/bit_util.h" + +namespace arrow { +namespace compute { + +// Storage for a bit vector to be used with BitVectorNavigator and its variants. +// +// Supports weaved bit vectors. 
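+//
+// Illustrative usage sketch (the sizes and bit positions are arbitrary; the
+// counter-building loop mirrors the pattern used by MarkTieBegins() below):
+//
+//   BitVectorWithCounts bits;
+//   bits.Resize(/*num_bits_per_child=*/200);
+//   BitVectorNavigator nav = bits.GetNavigator();
+//   nav.SetBit(3);
+//   nav.SetBit(130);
+//   for (int64_t block = 0; block < nav.block_count(); ++block) {
+//     nav.BuildMidCounts(block);
+//   }
+//   nav.BuildTopCounts(0, nav.block_count());
+//
+// Rank(i) counts set bits in [0, i), RankNext(i) counts set bits in [0, i],
+// and Select(r) returns the position of the (r + 1)-th set bit, so here
+// nav.Rank(130) == 1, nav.RankNext(130) == 2 and nav.Select(1) == 130.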
+// +class BitVectorWithCountsBase { + template + friend class BitVectorNavigatorImp; + + public: + BitVectorWithCountsBase() : num_children_(0), num_bits_per_child_(0) {} + + void Resize(int64_t num_bits_per_child, int64_t num_children = 1) { + ARROW_DCHECK(num_children > 0 && num_bits_per_child > 0); + num_children_ = num_children; + num_bits_per_child_ = num_bits_per_child; + int64_t num_words = + bit_util::CeilDiv(num_bits_per_child, kBitsPerWord) * num_children; + bits_.resize(num_words); + mid_counts_.resize(num_words); + int64_t num_blocks = + bit_util::CeilDiv(num_bits_per_child, kBitsPerBlock) * num_children; + top_counts_.resize(num_blocks); + } + + void ClearBits() { memset(bits_.data(), 0, bits_.size() * sizeof(bits_[0])); } + + // Word is 64 adjacent bits + // + static constexpr int64_t kBitsPerWord = 64; + // Block is 65536 adjacent bits + // (that means that 16-bit counters can be used within the block) + // +#ifndef NDEBUG + static constexpr int kLogBitsPerBlock = 7; +#else + static constexpr int kLogBitsPerBlock = 16; +#endif + static constexpr int64_t kBitsPerBlock = 1LL << kLogBitsPerBlock; + + protected: + int64_t num_children_; + int64_t num_bits_per_child_; + // TODO: Replace vectors with ResizableBuffers. Return error status from + // Resize on out-of-memory. + // + std::vector bits_; + std::vector top_counts_; + std::vector mid_counts_; +}; + +template +class BitVectorNavigatorImp { + public: + BitVectorNavigatorImp() : container_(NULLPTR) {} + + BitVectorNavigatorImp(BitVectorWithCountsBase* container, int64_t child_index) + : container_(container), child_index_(child_index) {} + + int64_t block_count() const { + return bit_util::CeilDiv(container_->num_bits_per_child_, + BitVectorWithCountsBase::kBitsPerBlock); + } + + int64_t word_count() const { + return bit_util::CeilDiv(container_->num_bits_per_child_, + BitVectorWithCountsBase::kBitsPerWord); + } + + int64_t bit_count() const { return container_->num_bits_per_child_; } + + int64_t pop_count() const { + int64_t last_block = block_count() - 1; + int64_t last_word = word_count() - 1; + int num_bits_last_word = + static_cast((bit_count() - 1) % BitVectorWithCountsBase::kBitsPerWord + 1); + uint64_t last_word_mask = ~0ULL >> (64 - num_bits_last_word); + return container_->top_counts_[apply_stride_and_offset(last_block)] + + container_->mid_counts_[apply_stride_and_offset(last_word)] + + ARROW_POPCOUNT64(container_->bits_[apply_stride_and_offset(last_word)] & + last_word_mask); + } + + const uint8_t* GetBytes() const { + return reinterpret_cast(container_->bits_.data()); + } + + void BuildMidCounts(int64_t block_index) { + ARROW_DCHECK(block_index >= 0 && + block_index < static_cast(container_->mid_counts_.size())); + constexpr int64_t words_per_block = + BitVectorWithCountsBase::kBitsPerBlock / BitVectorWithCountsBase::kBitsPerWord; + int64_t word_begin = block_index * words_per_block; + int64_t word_end = std::min(word_count(), word_begin + words_per_block); + + const uint64_t* words = container_->bits_.data(); + uint16_t* counters = container_->mid_counts_.data(); + + uint16_t count = 0; + for (int64_t word_index = word_begin; word_index < word_end; ++word_index) { + counters[apply_stride_and_offset(word_index)] = count; + count += static_cast( + ARROW_POPCOUNT64(words[apply_stride_and_offset(word_index)])); + } + } + + void BuildTopCounts(int64_t block_index_begin, int64_t block_index_end, + int64_t initial_count = 0) { + const uint64_t* words = container_->bits_.data(); + int64_t* counters = 
container_->top_counts_.data(); + const uint16_t* mid_counters = container_->mid_counts_.data(); + + int64_t count = initial_count; + + for (int64_t block_index = block_index_begin; block_index < block_index_end - 1; + ++block_index) { + counters[apply_stride_and_offset(block_index)] = count; + + constexpr int64_t words_per_block = + BitVectorWithCountsBase::kBitsPerBlock / BitVectorWithCountsBase::kBitsPerWord; + + int64_t word_begin = block_index * words_per_block; + int64_t word_end = std::min(word_count(), word_begin + words_per_block); + + count += mid_counters[apply_stride_and_offset(word_end - 1)]; + count += ARROW_POPCOUNT64(words[apply_stride_and_offset(word_end - 1)]); + } + counters[apply_stride_and_offset(block_index_end - 1)] = count; + } + + // Position of the nth bit set (input argument zero corresponds to the first + // bit set). + // + int64_t Select(int64_t rank) const { + if (rank < 0) { + return BeforeFirstBit(); + } + if (rank >= pop_count()) { + return AfterLastBit(); + } + + constexpr int64_t bits_per_block = BitVectorWithCountsBase::kBitsPerBlock; + constexpr int64_t bits_per_word = BitVectorWithCountsBase::kBitsPerWord; + constexpr int64_t words_per_block = bits_per_block / bits_per_word; + const int64_t* top_counters = container_->top_counts_.data(); + const uint16_t* mid_counters = container_->mid_counts_.data(); + const uint64_t* words = container_->bits_.data(); + + // Binary search in top level counters. + // + // Equivalent of std::upper_bound() - 1, but not using iterators. + // + int64_t begin = 0; + int64_t end = block_count(); + while (end - begin > 1) { + int64_t middle = (begin + end) / 2; + int reject_left_half = + (rank >= top_counters[apply_stride_and_offset(middle)]) ? 1 : 0; + begin = begin + (middle - begin) * reject_left_half; + end = middle + (end - middle) * reject_left_half; + } + + int64_t block_index = begin; + rank -= top_counters[apply_stride_and_offset(begin)]; + + // Continue with binary search in intermediate level counters of the + // selected block. + // + begin = block_index * words_per_block; + end = std::min(word_count(), begin + words_per_block); + while (end - begin > 1) { + int64_t middle = (begin + end) / 2; + int reject_left_half = + (rank >= mid_counters[apply_stride_and_offset(middle)]) ? 1 : 0; + begin = begin + (middle - begin) * reject_left_half; + end = middle + (end - middle) * reject_left_half; + } + + int64_t word_index = begin; + rank -= mid_counters[apply_stride_and_offset(begin)]; + + // Continue with binary search in the selected word. + // + uint64_t word = words[apply_stride_and_offset(word_index)]; + int pop_count_prefix = 0; + int bit_count_prefix = 0; + const uint64_t masks[6] = {0xFFFFFFFFULL, 0xFFFFULL, 0xFFULL, 0xFULL, 0x3ULL, 0x1ULL}; + int bit_count_left_half = 32; + for (int i = 0; i < 6; ++i) { + int pop_count_left_half = + static_cast(ARROW_POPCOUNT64((word >> bit_count_prefix) & masks[i])); + int reject_left_half = (rank >= pop_count_prefix + pop_count_left_half) ? 1 : 0; + pop_count_prefix += reject_left_half * pop_count_left_half; + bit_count_prefix += reject_left_half * bit_count_left_half; + bit_count_left_half /= 2; + } + + return word_index * bits_per_word + bit_count_prefix; + } + + void Select(int64_t rank_begin, int64_t rank_end, int64_t* selects, + const ThreadContext& thread_ctx) const { + ARROW_DCHECK(rank_begin <= rank_end); + + // For ranks out of the range represented in the bit vector return + // BeforeFirstBit() or AfterLastBit(). 
+ // + if (rank_begin < 0) { + int64_t num_ranks_to_skip = + std::min(rank_end, static_cast(0)) - rank_begin; + for (int64_t i = 0LL; i < num_ranks_to_skip; ++i) { + selects[i] = BeforeFirstBit(); + } + selects += num_ranks_to_skip; + rank_begin += num_ranks_to_skip; + } + + int64_t rank_max = pop_count() - 1; + if (rank_end > rank_max + 1) { + int64_t num_ranks_to_skip = rank_end - std::max(rank_begin, rank_max + 1); + for (int64_t i = 0LL; i < num_ranks_to_skip; ++i) { + selects[rank_end - num_ranks_to_skip + i] = AfterLastBit(); + } + rank_end -= num_ranks_to_skip; + } + + // If there are no more ranks left then we are done. + // + if (rank_begin == rank_end) { + return; + } + + auto temp_vector_stack = thread_ctx.temp_vector_stack; // For TEMP_VECTOR + TEMP_VECTOR(uint16_t, ids); + int num_ids; + TEMP_VECTOR(uint64_t, temp_words); + + int64_t select_begin = Select(rank_begin); + int64_t select_end = Select(rank_end - 1) + 1; + + constexpr int64_t bits_per_word = BitVectorWithCountsBase::kBitsPerWord; + const uint64_t* words = container_->bits_.data(); + + // Split processing into mini batches, in order to use small buffers on + // the stack (and in CPU cache) for intermediate vectors. + // + BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, select_end - select_begin) + + int64_t bit_begin = select_begin + batch_begin; + int64_t word_begin = bit_begin / bits_per_word; + int64_t word_end = + (select_begin + batch_begin + batch_length - 1) / bits_per_word + 1; + + // Copy words from interleaved bit vector to the temporary buffer that will + // have them in a contiguous block of memory. + // + for (int64_t word_index = word_begin; word_index < word_end; ++word_index) { + temp_words[word_index - word_begin] = words[apply_stride_and_offset(word_index)]; + } + + // Find positions of all bits set in current mini-batch of bits + // + util::bit_util::bits_to_indexes( + /*bit_to_search=*/1, thread_ctx.hardware_flags, static_cast(batch_length), + reinterpret_cast(temp_words), &num_ids, ids, + static_cast(bit_begin % bits_per_word)); + + // Output positions of bits set. + // + for (int i = 0; i < num_ids; ++i) { + selects[i] = bit_begin + ids[i]; + } + selects += num_ids; + + END_MINI_BATCH_FOR + } + + template + int64_t RankImp(int64_t bit_index) const { + const int64_t* top_counters = container_->top_counts_.data(); + const uint16_t* mid_counters = container_->mid_counts_.data(); + const uint64_t* words = container_->bits_.data(); + constexpr int64_t bits_per_block = BitVectorWithCountsBase::kBitsPerBlock; + constexpr int64_t bits_per_word = BitVectorWithCountsBase::kBitsPerWord; + uint64_t bit_mask = INCLUSIVE_RANK + ? (~0ULL >> (bits_per_word - 1 - (bit_index % bits_per_word))) + : ((1ULL << (bit_index % bits_per_word)) - 1ULL); + return top_counters[apply_stride_and_offset(bit_index / bits_per_block)] + + mid_counters[apply_stride_and_offset(bit_index / bits_per_word)] + + ARROW_POPCOUNT64(words[apply_stride_and_offset(bit_index / bits_per_word)] & + bit_mask); + } + + // Number of bits in the range [0, bit_index - 1] that are set. 
+ // + int64_t Rank(int64_t bit_index) const { + return RankImp(bit_index); + } + + void Rank(int64_t bit_index_begin, int64_t bit_index_end, int64_t* ranks) const { + const uint64_t* words = container_->bits_.data(); + constexpr int64_t bits_per_word = BitVectorWithCountsBase::kBitsPerWord; + + int64_t rank = Rank(bit_index_begin); + uint64_t word = words[apply_stride_and_offset(bit_index_begin / bits_per_word)]; + for (int64_t bit_index = bit_index_begin; bit_index < bit_index_end; ++bit_index) { + if (bit_index % bits_per_word == 0) { + word = words[apply_stride_and_offset(bit_index / bits_per_word)]; + } + ranks[bit_index - bit_index_begin] = rank; + rank += (word >> (bit_index % bits_per_word)) & 1; + } + } + + // Number of bits in the range [0, bit_index] that are set. + // + int64_t RankNext(int64_t bit_index) const { + return RankImp(bit_index); + } + + uint64_t GetBit(int64_t bit_index) const { + constexpr int64_t bits_per_word = BitVectorWithCountsBase::kBitsPerWord; + return (GetWord(bit_index / bits_per_word) >> (bit_index % bits_per_word)) & 1ULL; + } + + uint64_t GetWord(int64_t word_index) const { + const uint64_t* words = container_->bits_.data(); + return words[apply_stride_and_offset(word_index)]; + } + + void SetBit(int64_t bit_index) { + constexpr int64_t bits_per_word = BitVectorWithCountsBase::kBitsPerWord; + int64_t word_index = bit_index / bits_per_word; + SetWord(word_index, GetWord(word_index) | (1ULL << (bit_index % bits_per_word))); + } + + void SetWord(int64_t word_index, uint64_t word_value) { + uint64_t* words = container_->bits_.data(); + words[apply_stride_and_offset(word_index)] = word_value; + } + + // Constants returned from select query when the rank is outside of the + // range of ranks represented in the bit vector. + // + int64_t BeforeFirstBit() const { return -1LL; } + + int64_t AfterLastBit() const { return bit_count(); } + + // Populate bit vector and counters marking the first position in each group + // of ties for the sequence of values. + // + template + void MarkTieBegins(int64_t length, const T* sorted) { + container_->Resize(length); + + // We start from position 1, in order to not check (i==0) condition inside + // the loop. First position always starts a new group. + // + uint64_t word = 1ULL; + for (int64_t i = 1; i < length; ++i) { + uint64_t bit_value = (sorted[i - 1] != sorted[i]) ? 1ULL : 0ULL; + word |= bit_value << (i & 63); + if ((i & 63) == 63) { + SetWord(i / 64, word); + word = 0ULL; + } + } + if (length % 64 > 0) { + SetWord(length / 64, word); + } + + // Generate population counters for the bit vector. 
+ // + for (int64_t block_index = 0; block_index < block_count(); ++block_index) { + BuildMidCounts(block_index); + } + BuildTopCounts(0, block_count()); + } + + void DebugPrintCountersToFile(FILE* fout) const { + int64_t num_words = bit_util::CeilDiv(container_->num_bits_per_child_, + BitVectorWithCountsBase::kBitsPerWord); + int64_t num_blocks = bit_util::CeilDiv(container_->num_bits_per_child_, + BitVectorWithCountsBase::kBitsPerBlock); + fprintf(fout, "\nmid_counts: "); + for (int64_t word_index = 0; word_index < num_words; ++word_index) { + fprintf( + fout, "%d ", + static_cast(container_->mid_counts_[apply_stride_and_offset(word_index)])); + } + fprintf(fout, "\ntop_counts: "); + for (int64_t block_index = 0; block_index < num_blocks; ++block_index) { + fprintf(fout, "%d ", + static_cast( + container_->top_counts_[apply_stride_and_offset(block_index)])); + } + } + + private: + int64_t apply_stride_and_offset(int64_t index) const { + if (SINGLE_CHILD_BIT_VECTOR) { + return index; + } + int64_t stride = container_->num_children_; + int64_t offset = child_index_; + return index * stride + offset; + } + + BitVectorWithCountsBase* container_; + int64_t child_index_; +}; + +using BitVectorNavigator = BitVectorNavigatorImp; +using BitWeaverNavigator = BitVectorNavigatorImp; + +class BitVectorWithCounts : public BitVectorWithCountsBase { + public: + BitVectorNavigator GetNavigator() { + ARROW_DCHECK(num_children_ == 1); + return BitVectorNavigator(this, 0); + } + BitWeaverNavigator GetChildNavigator(int64_t child_index) { + ARROW_DCHECK(child_index >= 0 && child_index < num_children_); + return BitWeaverNavigator(this, child_index); + } +}; + +class BitMatrixWithCounts { + public: + ~BitMatrixWithCounts() { + for (size_t i = 0; i < bands_.size(); ++i) { + if (bands_[i]) { + delete bands_[i]; + } + } + } + + BitMatrixWithCounts() : band_size_(0), bit_count_(0), num_rows_allocated_(0) {} + + void Init(int band_size, int64_t bit_count) { + ARROW_DCHECK(band_size > 0 && bit_count > 0); + ARROW_DCHECK(band_size_ == 0); + band_size_ = band_size; + bit_count_ = bit_count; + num_rows_allocated_ = 0; + } + + void AddRow(int row_index) { + // Make a room in a lookup table for row with this index if needed. + // + int row_index_end = static_cast(row_navigators_.size()); + if (row_index >= row_index_end) { + row_navigators_.resize(row_index + 1); + } + + // Check if we need to allocate a new band. + // + int num_bands = static_cast(bands_.size()); + if (num_rows_allocated_ == num_bands * band_size_) { + bands_.push_back(new BitVectorWithCountsBase()); + bands_.back()->Resize(bit_count_, band_size_); + } + + // Initialize BitWeaverNavigator for that row. 
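+    // Each band packs band_size_ rows into one weaved BitVectorWithCountsBase,
+    // so the new row lives in band num_rows_allocated_ / band_size_ at lane
+    // num_rows_allocated_ % band_size_ within that band.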
+ // + row_navigators_[row_index] = + BitWeaverNavigator(bands_[num_rows_allocated_ / band_size_], + static_cast(num_rows_allocated_ % band_size_)); + + ++num_rows_allocated_; + } + + BitWeaverNavigator& GetMutableRow(int row_index) { return row_navigators_[row_index]; } + + const BitWeaverNavigator& GetRow(int row_index) const { + return row_navigators_[row_index]; + } + + private: + int band_size_; + int64_t bit_count_; + int num_rows_allocated_; + std::vector bands_; + std::vector row_navigators_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/window_functions/merge_tree.cc b/cpp/src/arrow/compute/exec/window_functions/merge_tree.cc new file mode 100644 index 00000000000..d8e2ca1b572 --- /dev/null +++ b/cpp/src/arrow/compute/exec/window_functions/merge_tree.cc @@ -0,0 +1,1085 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec/window_functions/merge_tree.h" + +namespace arrow { +namespace compute { + +bool MergeTree::IsPermutation(int64_t length, const int64_t* values) { + std::vector present(length, false); + for (int64_t i = 0; i < length; ++i) { + auto value = values[i]; + if (value < 0LL || value >= length || present[value]) { + return false; + } + present[value] = true; + } + return true; +} + +int64_t MergeTree::NodeBegin(int level, int64_t pos) const { + return pos & ~((1LL << level) - 1); +} + +int64_t MergeTree::NodeEnd(int level, int64_t pos) const { + return std::min(NodeBegin(level, pos) + (static_cast(1) << level), length_); +} + +void MergeTree::CascadeBegin(int from_level, int64_t begin, int64_t* lbegin, + int64_t* rbegin) const { + ARROW_DCHECK(begin >= 0 && begin < length_); + ARROW_DCHECK(from_level >= 1); + auto& split_bits = bit_matrix_.GetRow(from_level); + auto node_begin = NodeBegin(from_level, begin); + auto node_begin_plus_whole = node_begin + (1LL << from_level); + auto node_begin_plus_half = node_begin + (1LL << (from_level - 1)); + int64_t node_popcnt = split_bits.Rank(begin) - node_begin / 2; + *rbegin = node_begin_plus_half + node_popcnt; + *lbegin = begin - node_popcnt; + *lbegin = + (*lbegin == node_begin_plus_half || *lbegin == length_) ? kEmptyRange : *lbegin; + *rbegin = + (*rbegin == node_begin_plus_whole || *rbegin == length_) ? 
kEmptyRange : *rbegin; +} + +void MergeTree::CascadeEnd(int from_level, int64_t end, int64_t* lend, + int64_t* rend) const { + ARROW_DCHECK(end > 0 && end <= length_); + ARROW_DCHECK(from_level >= 1); + auto& split_bits = bit_matrix_.GetRow(from_level); + auto node_begin = NodeBegin(from_level, end - 1); + auto node_begin_plus_half = node_begin + (1LL << (from_level - 1)); + int64_t node_popcnt = split_bits.RankNext(end - 1) - node_begin / 2; + *rend = node_begin_plus_half + node_popcnt; + *lend = end - node_popcnt; + *rend = (*rend == node_begin_plus_half) ? kEmptyRange : *rend; + *lend = (*lend == node_begin) ? kEmptyRange : *lend; +} + +int64_t MergeTree::CascadePos(int from_level, int64_t pos) const { + ARROW_DCHECK(pos >= 0 && pos < length_); + ARROW_DCHECK(from_level >= 1); + auto& split_bits = bit_matrix_.GetRow(from_level); + auto node_begin = NodeBegin(from_level, pos); + auto node_begin_plus_half = node_begin + (1LL << (from_level - 1)); + int64_t node_popcnt = split_bits.Rank(pos) - node_begin / 2; + return split_bits.GetBit(pos) ? node_begin_plus_half + node_popcnt : pos - node_popcnt; +} + +MergeTree::NodeSubsetType MergeTree::NodeIntersect(int level, int64_t pos, int64_t begin, + int64_t end) { + auto node_begin = NodeBegin(level, pos); + auto node_end = NodeEnd(level, pos); + return (node_begin >= begin && node_end <= end) ? NodeSubsetType::FULL + : (node_begin < end && node_end > begin) ? NodeSubsetType::PARTIAL + : NodeSubsetType::EMPTY; +} + +template +void MergeTree::SplitSubsetImp(const BitWeaverNavigator& split_bits, int source_level, + const T* source_level_vector, T* target_level_vector, + int64_t read_begin, int64_t read_end, + int64_t write_begin_bit0, int64_t write_begin_bit1, + ThreadContext& thread_ctx) { + ARROW_DCHECK(source_level >= 1); + + if (read_end == read_begin) { + return; + } + + int64_t write_begin[2]; + write_begin[0] = write_begin_bit0; + write_begin[1] = write_begin_bit1; + int64_t write_offset[2]; + write_offset[0] = write_offset[1] = 0; + int target_level = source_level - 1; + int64_t target_node_mask = (1LL << target_level) - 1LL; + if (MULTIPLE_SOURCE_NODES) { + // In case of processing multiple input nodes, + // we must align write_begin to the target level node boundary, + // so that the target node index calculation inside the main loop behaves + // correctly. + // + write_offset[0] = write_begin[0] & target_node_mask; + write_offset[1] = write_begin[1] & target_node_mask; + write_begin[0] &= ~target_node_mask; + write_begin[1] &= ~target_node_mask; + } + + uint64_t split_bits_batch[util::MiniBatch::kMiniBatchLength / 64 + 1]; + int num_ids_batch; + auto temp_vector_stack = thread_ctx.temp_vector_stack; + TEMP_VECTOR(uint16_t, ids_batch); + + // Split processing into mini batches, in order to use small buffers on + // the stack (and in CPU cache) for intermediate vectors. + // + BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, read_end - read_begin) + + // Copy bit vector words related to the current batch on the stack. + // + // Bit vector words from multiple levels are interleaved in memory, that + // is why we make a copy here to form a contiguous block. + // + int64_t word_index_base = (read_begin + batch_begin) / 64; + for (int64_t word_index = word_index_base; + word_index <= (read_begin + (batch_begin + batch_length) - 1) / 64; ++word_index) { + split_bits_batch[word_index - word_index_base] = split_bits.GetWord(word_index); + } + + for (int bit = 0; bit <= 1; ++bit) { + // Convert bits to lists of bit indices for each bit value. 
+ // + util::bit_util::bits_to_indexes( + bit, thread_ctx.hardware_flags, static_cast(batch_length), + reinterpret_cast(split_bits_batch), &num_ids_batch, ids_batch, + /*bit_offset=*/(read_begin + batch_begin) % 64); + + // For each bit index on the list, calculate position in the input array + // and position in the output array, then make a copy of the value. + // + for (int64_t i = 0; i < num_ids_batch; ++i) { + int64_t read_pos = read_begin + batch_begin + ids_batch[i]; + int64_t write_pos = write_offset[bit] + i; + if (MULTIPLE_SOURCE_NODES) { + // We may need to jump from one target node to the next in case of + // processing multiple source nodes. + // Update write position accordingly + // + write_pos = write_pos + (write_pos & ~target_node_mask); + } + write_pos += write_begin[bit]; + target_level_vector[write_pos] = source_level_vector[read_pos]; + } + + // Advance the write cursor for current bit value (bit 0 or 1). + // + write_offset[bit] += num_ids_batch; + } + + END_MINI_BATCH_FOR +} + +template +void MergeTree::SplitSubset(int source_level, const T* source_level_vector, + T* target_level_vector, int64_t read_begin, int64_t read_end, + ThreadContext& thread_ctx) { + auto& split_bits = bit_matrix_.GetRow(source_level); + int64_t source_node_length = (1LL << source_level); + bool single_node = (read_end - read_begin) <= source_node_length; + + // Calculate initial output positions for bits 0 and bits 1 respectively + // and call a helper function to do the remaining processing. + // + int64_t source_node_begin = NodeBegin(source_level, read_begin); + int64_t target_node_length = (1LL << (source_level - 1)); + int64_t write_begin[2]; + write_begin[1] = split_bits.Rank(read_begin); + write_begin[0] = read_begin - write_begin[1]; + write_begin[0] += source_node_begin / 2; + write_begin[1] += source_node_begin / 2 + target_node_length; + + if (single_node) { + // The case when the entire input subset is contained within a single + // node in the source level. 
+ // + SplitSubsetImp(split_bits, source_level, source_level_vector, + target_level_vector, read_begin, read_end, write_begin[0], + write_begin[1], thread_ctx); + } else { + SplitSubsetImp(split_bits, source_level, source_level_vector, + target_level_vector, read_begin, read_end, write_begin[0], + write_begin[1], thread_ctx); + } +} + +void MergeTree::SetMorselLoglen(int morsel_loglen) { morsel_loglen_ = morsel_loglen; } + +uint64_t MergeTree::GetWordUnaligned(const BitWeaverNavigator& source, int64_t bit_index, + int num_bits) { + ARROW_DCHECK(num_bits > 0 && num_bits <= 64); + int64_t word_index = bit_index / 64; + int64_t word_offset = bit_index % 64; + uint64_t word = source.GetWord(word_index) >> word_offset; + if (word_offset + num_bits > 64) { + word |= source.GetWord(word_index + 1) << (64 - word_offset); + } + word &= (~0ULL >> (64 - num_bits)); + return word; +} + +void MergeTree::UpdateWord(BitWeaverNavigator& target, int64_t bit_index, int num_bits, + uint64_t bits) { + ARROW_DCHECK(num_bits > 0 && num_bits <= 64); + ARROW_DCHECK(bit_index % 64 + num_bits <= 64); + int64_t word_index = bit_index / 64; + int64_t word_offset = bit_index % 64; + uint64_t mask = (~0ULL >> (64 - num_bits)) << word_offset; + bits = ((bits << word_offset) & mask); + target.SetWord(word_index, (target.GetWord(word_index) & ~mask) | bits); +} + +void MergeTree::BitMemcpy(const BitWeaverNavigator& source, BitWeaverNavigator& target, + int64_t source_begin, int64_t source_end, + int64_t target_begin) { + int64_t num_bits = source_end - source_begin; + if (num_bits == 0) { + return; + } + + int64_t target_end = target_begin + num_bits; + int64_t target_word_begin = target_begin / 64; + int64_t target_word_end = (target_end - 1) / 64 + 1; + int64_t target_offset = target_begin % 64; + + // Process the first and the last target word. + // + if (target_word_end - target_word_begin == 1) { + // There is only one output word + // + uint64_t input = GetWordUnaligned(source, source_begin, static_cast(num_bits)); + UpdateWord(target, target_begin, static_cast(num_bits), input); + return; + } else { + // First output word + // + int num_bits_first_word = static_cast(64 - target_offset); + uint64_t input = GetWordUnaligned(source, source_begin, num_bits_first_word); + UpdateWord(target, target_begin, num_bits_first_word, input); + + // Last output word + // + int num_bits_last_word = (target_end % 64 == 0) ? 64 : (target_end % 64); + input = GetWordUnaligned(source, source_end - num_bits_last_word, num_bits_last_word); + UpdateWord(target, target_end - num_bits_last_word, num_bits_last_word, input); + } + + // Index of source word containing the last bit that needs to be copied to + // the first target word. 
+ // + int64_t source_word_begin = + (source_begin + (target_word_begin * 64 + 63) - target_begin) / 64; + + // The case of aligned bit sequences + // + if (target_offset == (source_begin % 64)) { + for (int64_t target_word = target_word_begin + 1; target_word < target_word_end - 1; + ++target_word) { + int64_t source_word = source_word_begin + (target_word - target_word_begin); + target.SetWord(target_word, source.GetWord(source_word)); + } + return; + } + + int64_t first_unprocessed_source_bit = source_begin + (64 - target_offset); + + // Number of bits from a single input word carried from one output word to + // the next + // + int num_carry_bits = 64 - first_unprocessed_source_bit % 64; + ARROW_DCHECK(num_carry_bits > 0 && num_carry_bits < 64); + + // Carried bits + // + uint64_t carry = GetWordUnaligned(source, first_unprocessed_source_bit, num_carry_bits); + + // Process target words between the first and the last. + // + for (int64_t target_word = target_word_begin + 1; target_word < target_word_end - 1; + ++target_word) { + int64_t source_word = source_word_begin + (target_word - target_word_begin); + uint64_t input = source.GetWord(source_word); + uint64_t output = carry | (input << num_carry_bits); + target.SetWord(target_word, output); + carry = input >> (64 - num_carry_bits); + } +} + +void MergeTree::GetChildrenBoundaries(const BitWeaverNavigator& split_bits, + int64_t num_source_nodes, + int64_t* source_node_begins, + int64_t* target_node_begins) { + for (int64_t source_node_index = 0; source_node_index < num_source_nodes; + ++source_node_index) { + int64_t node_begin = source_node_begins[source_node_index]; + int64_t node_end = source_node_begins[source_node_index + 1]; + target_node_begins[2 * source_node_index + 0] = node_begin; + if (node_begin == node_end) { + target_node_begins[2 * source_node_index + 1] = node_begin; + } else { + int64_t num_bits_1 = + split_bits.RankNext(node_end - 1) - split_bits.Rank(node_begin); + int64_t num_bits_0 = (node_end - node_begin) - num_bits_1; + target_node_begins[2 * source_node_index + 1] = node_begin + num_bits_0; + } + } + int64_t num_target_nodes = 2 * num_source_nodes; + target_node_begins[num_target_nodes] = source_node_begins[num_source_nodes]; +} + +void MergeTree::BuildUpperSliceMorsel(int level_begin, int64_t* permutation_of_X, + int64_t* temp_permutation_of_X, + int64_t morsel_index, ThreadContext& thread_ctx) { + int64_t morsel_length = 1LL << morsel_loglen_; + int64_t morsel_begin = morsel_index * morsel_length; + int64_t morsel_end = std::min(length_, morsel_begin + morsel_length); + + ARROW_DCHECK((morsel_begin & (BitVectorWithCounts::kBitsPerBlock - 1)) == 0); + ARROW_DCHECK((morsel_end & (BitVectorWithCounts::kBitsPerBlock - 1)) == 0 || + morsel_end == length_); + ARROW_DCHECK(morsel_end > morsel_begin); + + int level_end = morsel_loglen_; + ARROW_DCHECK(level_begin > level_end); + + std::vector node_begins[2]; + // Begin level may have multiple nodes but the morsel is contained in + // just one. + // + node_begins[0].resize(2); + node_begins[0][0] = morsel_begin; + node_begins[0][1] = morsel_end; + + for (int level = level_begin; level > level_end; --level) { + // Setup pointers to ping-pong buffers (for permutation of X). 
+ // + int64_t* source_Xs; + int64_t* target_Xs; + if ((level_begin - level) % 2 == 0) { + source_Xs = permutation_of_X; + target_Xs = temp_permutation_of_X; + } else { + source_Xs = temp_permutation_of_X; + target_Xs = permutation_of_X; + } + + // Fill the bit vector + // + for (int64_t word_index = morsel_begin / 64; + word_index < bit_util::CeilDiv(morsel_end, 64); ++word_index) { + uint64_t word = 0; + int num_bits = (word_index == (morsel_end / 64)) ? (morsel_end % 64) : 64; + for (int i = 0; i < num_bits; ++i) { + int64_t X = source_Xs[word_index * 64 + i]; + uint64_t bit = ((X >> (level - 1)) & 1ULL); + word |= (bit << i); + } + bit_matrix_upper_slices_.GetMutableRow(level).SetWord(word_index, word); + } + + // Fill the population counters + // + int64_t block_index_begin = + (morsel_begin >> BitVectorWithCountsBase::kLogBitsPerBlock); + int64_t block_index_end = + ((morsel_end - 1) >> BitVectorWithCountsBase::kLogBitsPerBlock) + 1; + for (int64_t block_index = block_index_begin; block_index < block_index_end; + ++block_index) { + bit_matrix_upper_slices_.GetMutableRow(level).BuildMidCounts(block_index); + } + bit_matrix_upper_slices_.GetMutableRow(level).BuildTopCounts(block_index_begin, + block_index_end); + + // Setup pointers to ping-pong buffers (for node boundaries from previous + // and current level). + // + int64_t num_source_nodes = (1LL << (level_begin - level)); + int64_t num_target_nodes = 2 * num_source_nodes; + int64_t* source_node_begins; + int64_t* target_node_begins; + if ((level_begin - level) % 2 == 0) { + source_node_begins = node_begins[0].data(); + node_begins[1].resize(num_target_nodes + 1); + target_node_begins = node_begins[1].data(); + } else { + source_node_begins = node_begins[1].data(); + node_begins[0].resize(num_target_nodes + 1); + target_node_begins = node_begins[0].data(); + } + + // Compute boundaries of the children nodes (cummulative sum of children + // sizes). + // + GetChildrenBoundaries(bit_matrix_upper_slices_.GetRow(level), num_source_nodes, + source_node_begins, target_node_begins); + + // Split vector of Xs, one parent node at a time. + // Each parent node gets split into two children nodes. + // Parent and child nodes can have arbitrary sizes, including zero. + // + for (int64_t source_node_index = 0; source_node_index < num_source_nodes; + ++source_node_index) { + SplitSubsetImp( + bit_matrix_upper_slices_.GetRow(level), level, source_Xs, target_Xs, + source_node_begins[source_node_index], + source_node_begins[source_node_index + 1], + target_node_begins[2 * source_node_index + 0], + target_node_begins[2 * source_node_index + 1], thread_ctx); + } + } +} + +void MergeTree::CombineUpperSlicesMorsel(int level_begin, int64_t output_morsel, + int64_t* input_permutation_of_X, + int64_t* output_permutation_of_X, + ThreadContext& thread_ctx) { + int level_end = morsel_loglen_; + ARROW_DCHECK(level_begin > level_end); + + int64_t morsel_length = 1LL << morsel_loglen_; + int64_t output_morsel_begin = output_morsel * morsel_length; + int64_t output_morsel_end = std::min(length_, output_morsel_begin + morsel_length); + + int64_t begin_level_node_length = (1LL << level_begin); + + // Copy bits for begin level bit vector. 
+ // + ARROW_DCHECK(output_morsel_begin % 64 == 0); + for (int64_t word_index = output_morsel_begin / 64; + word_index <= (output_morsel_end - 1) / 64; ++word_index) { + bit_matrix_.GetMutableRow(level_begin) + .SetWord(word_index, + bit_matrix_upper_slices_.GetRow(level_begin).GetWord(word_index)); + } + + // For each node of the top level + // (every input morsel is contained in one such node): + // + for (int64_t begin_level_node = 0; + begin_level_node < bit_util::CeilDiv(length_, begin_level_node_length); + ++begin_level_node) { + int64_t begin_level_node_begin = begin_level_node * begin_level_node_length; + int64_t begin_level_node_end = + std::min(length_, begin_level_node_begin + begin_level_node_length); + + int64_t num_input_morsels = + bit_util::CeilDiv(begin_level_node_end - begin_level_node_begin, morsel_length); + + std::vector slice_node_begins[2]; + for (int64_t input_morsel = 0; input_morsel < num_input_morsels; ++input_morsel) { + slice_node_begins[0].push_back(begin_level_node_begin + + input_morsel * morsel_length); + } + slice_node_begins[0].push_back(begin_level_node_end); + + for (int level = level_begin - 1; level >= level_end; --level) { + std::vector* parent_node_begins; + std::vector* child_node_begins; + if ((level_begin - level) % 2 == 1) { + parent_node_begins = &slice_node_begins[0]; + child_node_begins = &slice_node_begins[1]; + } else { + parent_node_begins = &slice_node_begins[1]; + child_node_begins = &slice_node_begins[0]; + } + child_node_begins->resize((parent_node_begins->size() - 1) * 2 + 1); + + GetChildrenBoundaries(bit_matrix_upper_slices_.GetRow(level + 1), + static_cast(parent_node_begins->size()) - 1, + parent_node_begins->data(), child_node_begins->data()); + + // Scan all output nodes and all input nodes for each of them. + // + // Filter to the subset of input-output node pairs that cross the output + // morsel boundary. + // + int64_t num_output_nodes = (1LL << (level_begin - level)); + for (int64_t output_node = 0; output_node < num_output_nodes; ++output_node) { + int64_t output_node_length = 1LL << level; + int64_t output_begin = begin_level_node_begin + output_node * output_node_length; + for (int64_t input_morsel = 0; input_morsel < num_input_morsels; ++input_morsel) { + // Boundaries of the input node for a given input morsel and a given + // output node. + // + int64_t input_begin = + (*child_node_begins)[input_morsel * num_output_nodes + output_node]; + int64_t input_end = + (*child_node_begins)[input_morsel * num_output_nodes + output_node + 1]; + int64_t input_length = input_end - input_begin; + if (output_morsel_end > output_begin && + output_morsel_begin < output_begin + input_length) { + // Clamp the copy request to have the output range within the output + // morsel. + // + int64_t target_begin = std::max(output_morsel_begin, output_begin); + int64_t target_end = std::min(output_morsel_end, output_begin + input_length); + + if (level == level_end) { + // Reorder chunks of vector of X for level_end. + // + memcpy(output_permutation_of_X + target_begin, + input_permutation_of_X + input_begin + (target_begin - output_begin), + (target_end - target_begin) * sizeof(input_permutation_of_X[0])); + } else { + // Reorder bits in the split bit vector for all levels above + // level_end. 
+ // + BitMemcpy(bit_matrix_upper_slices_.GetRow(level), + bit_matrix_.GetMutableRow(level), + input_begin + (target_begin - output_begin), + input_begin + (target_end - output_begin), target_begin); + } + } + + // Advance write cursor + // + output_begin += input_length; + } + } + } + } + + // Fill the mid level population counters for bit vectors. + // + // Top level population counters will get initialized in a single-threaded + // section at the end of the build process. + // + ARROW_DCHECK(output_morsel_begin % (BitVectorWithCounts::kBitsPerBlock) == 0); + int64_t block_index_begin = (output_morsel_begin / BitVectorWithCounts::kBitsPerBlock); + int64_t block_index_end = + ((output_morsel_end - 1) / BitVectorWithCounts::kBitsPerBlock) + 1; + + for (int level = level_begin; level > level_end; --level) { + for (int64_t block_index = block_index_begin; block_index < block_index_end; + ++block_index) { + bit_matrix_.GetMutableRow(level).BuildMidCounts(block_index); + } + } +} + +void MergeTree::BuildLower(int level_begin, int64_t morsel_index, + int64_t* begin_permutation_of_X, + int64_t* temp_permutation_of_X, ThreadContext& thread_ctx) { + int64_t morsel_length = 1LL << morsel_loglen_; + int64_t morsel_begin = morsel_index * morsel_length; + int64_t morsel_end = std::min(length_, morsel_begin + morsel_length); + int64_t begin_level_node_length = 1LL << level_begin; + ARROW_DCHECK(morsel_begin % begin_level_node_length == 0 && + (morsel_end % begin_level_node_length == 0 || morsel_end == length_)); + + int64_t* permutation_of_X[2]; + permutation_of_X[0] = begin_permutation_of_X; + permutation_of_X[1] = temp_permutation_of_X; + + for (int level = level_begin; level > 0; --level) { + int selector = (level_begin - level) % 2; + const int64_t* input_X = permutation_of_X[selector]; + int64_t* output_X = permutation_of_X[1 - selector]; + + // Populate bit vector for current level based on (level - 1) bits of X in + // the input vector. + // + ARROW_DCHECK(morsel_begin % 64 == 0); + uint64_t word = 0ULL; + for (int64_t i = morsel_begin; i < morsel_end; ++i) { + word |= ((input_X[i] >> (level - 1)) & 1ULL) << (i % 64); + if (i % 64 == 63) { + bit_matrix_.GetMutableRow(level).SetWord(i / 64, word); + word = 0ULL; + } + } + if (morsel_end % 64 > 0) { + bit_matrix_.GetMutableRow(level).SetWord(morsel_end / 64, word); + } + + // Fille population counters for bit vector. + // + constexpr int64_t block_size = BitVectorWithCounts::kBitsPerBlock; + int64_t block_index_begin = morsel_begin / block_size; + int64_t block_index_end = (morsel_end - 1) / block_size + 1; + for (int64_t block_index = block_index_begin; block_index < block_index_end; + ++block_index) { + bit_matrix_.GetMutableRow(level).BuildMidCounts(block_index); + } + bit_matrix_.GetMutableRow(level).BuildTopCounts(block_index_begin, block_index_end, + morsel_begin / 2); + + // Split X based on the generated bit vector. + // + SplitSubset(level, input_X, output_X, morsel_begin, morsel_end, thread_ctx); + } +} + +Status MergeTree::Build(int64_t length, int level_begin, int64_t* permutation_of_X, + ParallelForStream& parallel_fors) { + morsel_loglen_ = kMinMorselLoglen; + length_ = length; + temp_permutation_of_X_.resize(length); + + // Allocate matrix bits. 
+ // + int upper_slices_level_end = morsel_loglen_; + int num_upper_levels = std::max(0, level_begin - upper_slices_level_end); + bit_matrix_.Init(kBitMatrixBandSize, length); + for (int level = 1; level <= level_begin; ++level) { + bit_matrix_.AddRow(level); + } + bit_matrix_upper_slices_.Init(kBitMatrixBandSize, length); + for (int level = upper_slices_level_end + 1; level <= level_begin; ++level) { + bit_matrix_upper_slices_.AddRow(level); + } + + int64_t num_morsels = bit_util::CeilDiv(length_, 1LL << morsel_loglen_); + + // Upper slices of merge tree are generated for levels for which the size of + // each node is greater than a single morsel. + // + // If there are such level, then add parallel for loops that create upper + // slices and then combine them. + // + if (num_upper_levels > 0) { + parallel_fors.InsertParallelFor( + num_morsels, + [this, level_begin, permutation_of_X](int64_t morsel_index, + ThreadContext& thread_context) -> Status { + BuildUpperSliceMorsel(level_begin, permutation_of_X, + temp_permutation_of_X_.data(), morsel_index, + thread_context); + return Status::OK(); + }); + parallel_fors.InsertParallelFor( + num_morsels, + [this, level_begin, num_upper_levels, permutation_of_X]( + int64_t morsel_index, ThreadContext& thread_context) -> Status { + CombineUpperSlicesMorsel( + level_begin, morsel_index, + (num_upper_levels % 2 == 0) ? permutation_of_X + : temp_permutation_of_X_.data(), + (num_upper_levels % 2 == 0) ? temp_permutation_of_X_.data() + : permutation_of_X, + thread_context); + return Status::OK(); + }); + } + parallel_fors.InsertParallelFor( + num_morsels, + [this, level_begin, num_upper_levels, upper_slices_level_end, permutation_of_X]( + int64_t morsel_index, ThreadContext& thread_context) -> Status { + BuildLower(std::min(level_begin, upper_slices_level_end), morsel_index, + (num_upper_levels > 0 && (num_upper_levels % 2 == 0)) + ? temp_permutation_of_X_.data() + : permutation_of_X, + (num_upper_levels > 0 && (num_upper_levels % 2 == 0)) + ? permutation_of_X + : temp_permutation_of_X_.data(), + thread_context); + return Status::OK(); + }); + parallel_fors.InsertTaskSingle( + [this, level_begin](int64_t morsel_index, ThreadContext& thread_context) -> Status { + // Fill the top level population counters for upper level bit vectors. + // + int level_end = morsel_loglen_; + int64_t num_blocks = + bit_util::CeilDiv(length_, BitVectorWithCountsBase::kBitsPerBlock); + for (int level = level_begin; level > level_end; --level) { + bit_matrix_.GetMutableRow(level).BuildTopCounts(0, num_blocks, 0); + } + + // Release the pair of temporary vectors representing permutation of + // X. + // + std::vector().swap(temp_permutation_of_X_); + + return Status::OK(); + }); + + return Status::OK(); +} + +void MergeTree::BoxQuery(const BoxQueryRequest& queries, ThreadContext& thread_ctx) { + auto temp_vector_stack = thread_ctx.temp_vector_stack; // For TEMP_VECTOR + TEMP_VECTOR(int64_t, partial_results0); + TEMP_VECTOR(int64_t, partial_results1); + TEMP_VECTOR(int64_t, y_ends_copy); + + int64_t child_cursors[5]; + child_cursors[4] = kEmptyRange; + + // Split processing into mini batches, in order to use small buffers on + // the stack (and in CPU cache) for intermediate vectors. + // + BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, queries.num_queries) + + // Preserve initial state, that is the upper bound on y coordinate. + // It will be overwritten for each range of the frame, during tree traversal. 
+ // + if (queries.num_x_ranges > 1) { + for (int64_t i = 0; i < batch_length; ++i) { + y_ends_copy[i] = queries.states[batch_begin + i].ends[0]; + } + } + + for (int x_range_index = 0; x_range_index < queries.num_x_ranges; ++x_range_index) { + const int64_t* xbegins = queries.xbegins[x_range_index]; + const int64_t* xends = queries.xends[x_range_index]; + + // Restore the initial state for ranges after the first one. + // Every range during its processing overwrites it. + // + if (x_range_index > 0) { + for (int64_t i = 0; i < batch_length; ++i) { + queries.states[batch_begin + i].ends[0] = y_ends_copy[i]; + queries.states[batch_begin + i].ends[1] = MergeTree::kEmptyRange; + } + } + + if (queries.level_begin == num_levels() - 1 && num_levels() == 1) { + // Check if the entire top level node is in X range + // + for (int i = 0; i < batch_length; ++i) { + partial_results0[i] = partial_results1[i] = kEmptyRange; + } + for (int64_t query_index = batch_begin; query_index < batch_begin + batch_length; + ++query_index) { + auto& state = queries.states[query_index]; + ARROW_DCHECK(state.ends[1] == kEmptyRange); + int64_t xbegin = xbegins[query_index]; + int64_t xend = xends[query_index]; + if (state.ends[0] != kEmptyRange) { + if (NodeIntersect(num_levels() - 1, state.ends[0] - 1, xbegin, xend) == + NodeSubsetType::FULL) { + partial_results0[query_index - batch_begin] = state.ends[0]; + } + } + } + queries.report_results_callback_(num_levels() - 1, batch_begin, + batch_begin + batch_length, partial_results0, + partial_results1, thread_ctx); + } + + for (int level = queries.level_begin; level > queries.level_end; --level) { + for (int64_t query_index = batch_begin; query_index < batch_begin + batch_length; + ++query_index) { + auto& state = queries.states[query_index]; + int64_t xbegin = xbegins[query_index]; + int64_t xend = xends[query_index]; + + // Predication: kEmptyRange is replaced with special constants, + // which are always a valid input, in order to avoid conditional + // branches. + // + // We will later correct values returned by called functions for + // kEmptyRange inputs. + // + constexpr int64_t kCascadeReplacement = static_cast(1); + constexpr int64_t kIntersectReplacement = static_cast(0); + + // Use fractional cascading to traverse one level down the tree + // + for (int i = 0; i < 2; ++i) { + CascadeEnd(level, + state.ends[i] == kEmptyRange ? kCascadeReplacement : state.ends[i], + &child_cursors[2 * i + 0], &child_cursors[2 * i + 1]); + } + + // For each child node check: + // a) if it should be rejected (outside of specified range of X), + // b) if it should be included in the reported results (fully inside + // of specified range of X). + // + int node_intersects_flags = 0; + int node_inside_flags = 0; + for (int i = 0; i < 4; ++i) { + child_cursors[i] = + state.ends[i / 2] == kEmptyRange ? kEmptyRange : child_cursors[i]; + auto intersection = + NodeIntersect(level - 1, + child_cursors[i] == kEmptyRange ? kIntersectReplacement + : child_cursors[i] - 1, + xbegin, xend); + intersection = + child_cursors[i] == kEmptyRange ? NodeSubsetType::EMPTY : intersection; + node_intersects_flags |= (intersection == NodeSubsetType::PARTIAL ? 1 : 0) << i; + node_inside_flags |= (intersection == NodeSubsetType::FULL ? 1 : 0) << i; + } + + // We shouldn't have more than two bits set in each intersection bit + // masks. 
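+          // (Example: node_intersects_flags == 0b0101 means child cursors 0
+          // and 2 partially overlap the X range; the shuffle table below maps
+          // that mask to source indices {0, 2}, while kNil == 4 selects
+          // child_cursors[4], which always holds the kEmptyRange sentinel.)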
+ // + ARROW_DCHECK(ARROW_POPCOUNT64(node_intersects_flags) <= 2); + ARROW_DCHECK(ARROW_POPCOUNT64(node_inside_flags) <= 2); + + // Shuffle generated child node cursors based on X range + // intersection results. + // + static constexpr uint8_t kNil = 4; + uint8_t source_shuffle_index[16][2] = { + {kNil, kNil}, {0, kNil}, {1, kNil}, {0, 1}, + {2, kNil}, {0, 2}, {1, 2}, {kNil, kNil}, + {3, kNil}, {0, 3}, {1, 3}, {kNil, kNil}, + {2, 3}, {kNil, kNil}, {kNil, kNil}, {kNil, kNil}}; + state.ends[0] = child_cursors[source_shuffle_index[node_intersects_flags][0]]; + state.ends[1] = child_cursors[source_shuffle_index[node_intersects_flags][1]]; + partial_results0[query_index - batch_begin] = + child_cursors[source_shuffle_index[node_inside_flags][0]]; + partial_results1[query_index - batch_begin] = + child_cursors[source_shuffle_index[node_inside_flags][1]]; + } + + // Report partial query results. + // + queries.report_results_callback_(level - 1, batch_begin, batch_begin + batch_length, + partial_results0, partial_results1, thread_ctx); + } + } + + END_MINI_BATCH_FOR +} + +void MergeTree::BoxCountQuery(int64_t num_queries, int num_x_ranges_per_query, + const int64_t** x_begins, const int64_t** x_ends, + const int64_t* y_ends, int64_t* results, + ThreadContext& thread_context) { + // Callback function that updates the final count based on node prefixes + // representing subsets that satisfy query constraints. + // + // There is one callback call per batch per level. + // + auto callback = [results](int level, int64_t batch_begin, int64_t batch_end, + const int64_t* partial_results0, + const int64_t* partial_results1, + ThreadContext& thread_context) { + // Mask used to separate node offset from the offset within the node at + // the current level. + // + int64_t mask = (1LL << level) - 1LL; + for (int64_t query_index = batch_begin; query_index < batch_end; ++query_index) { + int64_t partial_result0 = partial_results0[query_index]; + int64_t partial_result1 = partial_results1[query_index]; + + // We may have between 0 and 2 node prefixes that satisfy query + // constraints for each query. + // + // To find out their number we need to check if each of the two reported + // indices is equal to kEmptyRange. + // + // For a valid node prefix, the index reported represents the position + // one after the last element in the node prefix. + // + if (partial_result0 != kEmptyRange) { + results[query_index] += ((partial_result0 - 1) & mask) + 1; + } + if (partial_result1 != kEmptyRange) { + results[query_index] += ((partial_result1 - 1) & mask) + 1; + } + } + }; + + auto temp_vector_stack = thread_context.temp_vector_stack; + TEMP_VECTOR(MergeTree::BoxQueryState, states); + + BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, num_queries) + + // Populate BoxQueryRequest structure. + // + MergeTree::BoxQueryRequest request; + request.report_results_callback_ = callback; + request.num_queries = batch_length; + request.num_x_ranges = num_x_ranges_per_query; + for (int range_index = 0; range_index < num_x_ranges_per_query; ++range_index) { + request.xbegins[range_index] = x_begins[range_index] + batch_begin; + request.xends[range_index] = x_ends[range_index] + batch_begin; + } + request.level_begin = num_levels() - 1; + request.level_end = 0; + request.states = states; + for (int64_t i = 0; i < num_queries; ++i) { + int64_t y_end = y_ends[batch_begin + i]; + states[i].ends[0] = (y_end == 0) ? 
MergeTree::kEmptyRange : y_end; + states[i].ends[1] = MergeTree::kEmptyRange; + } + + BoxQuery(request, thread_context); + + END_MINI_BATCH_FOR +} + +bool MergeTree::NOutOfBounds(const NthQueryRequest& queries, int64_t query_index) { + int64_t num_elements = 0; + for (int y_range_index = 0; y_range_index < queries.num_y_ranges; ++y_range_index) { + int64_t ybegin = queries.ybegins[y_range_index][query_index]; + int64_t yend = queries.yends[y_range_index][query_index]; + num_elements += yend - ybegin; + } + int64_t N = queries.states[query_index].pos; + return N < 0 || N >= num_elements; +} + +void MergeTree::NthQuery(const NthQueryRequest& queries, ThreadContext& thread_ctx) { + ARROW_DCHECK(queries.num_y_ranges >= 1 && queries.num_y_ranges <= 3); + + auto temp_vector_stack = thread_ctx.temp_vector_stack; // For TEMP_VECTOR + TEMP_VECTOR(int64_t, pos); + TEMP_VECTOR(int64_t, ybegins0); + TEMP_VECTOR(int64_t, yends0); + TEMP_VECTOR(int64_t, ybegins1); + TEMP_VECTOR(int64_t, yends1); + TEMP_VECTOR(int64_t, ybegins2); + TEMP_VECTOR(int64_t, yends2); + int64_t* ybegins[3]; + int64_t* yends[3]; + ybegins[0] = ybegins0; + ybegins[1] = ybegins1; + ybegins[2] = ybegins2; + yends[0] = yends0; + yends[1] = yends1; + yends[2] = yends2; + + // Split processing into mini batches, in order to use small buffers on + // the stack (and in CPU cache) for intermediate vectors. + // + BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, queries.num_queries) + + // Filter out queries with N out of bounds. + // + int64_t num_batch_queries = 0; + for (int64_t batch_query_index = 0; batch_query_index < batch_length; + ++batch_query_index) { + int64_t query_index = batch_begin + batch_query_index; + for (int y_range_index = 0; y_range_index < queries.num_y_ranges; ++y_range_index) { + int64_t ybegin = queries.ybegins[y_range_index][query_index]; + int64_t yend = queries.yends[y_range_index][query_index]; + // Set range boundaries to kEmptyRange for all empty ranges in + // queries. + // + ybegin = (yend == ybegin) ? kEmptyRange : ybegin; + yend = (yend == ybegin) ? kEmptyRange : yend; + ybegins[y_range_index][num_batch_queries] = ybegin; + yends[y_range_index][num_batch_queries] = yend; + } + pos[num_batch_queries] = queries.states[query_index].pos; + num_batch_queries += NOutOfBounds(queries, query_index) ? 0 : 1; + } + + for (int level = num_levels() - 1; level > 0; --level) { + // For all batch queries post filtering + // + for (int64_t batch_query_index = 0; batch_query_index < num_batch_queries; + ++batch_query_index) { + // Predication: kEmptyRange is replaced with special constants, which + // are always a valid input, in order to avoid conditional branches. + // + // We will later correct values returned by called functions for + // kEmptyRange inputs. + // + constexpr int64_t kBeginReplacement = static_cast(0); + constexpr int64_t kEndReplacement = static_cast(1); + int64_t ybegin[6]; + int64_t yend[6]; + int64_t num_elements_in_left = 0; + for (int y_range_index = 0; y_range_index < queries.num_y_ranges; ++y_range_index) { + int64_t ybegin_parent = ybegins[y_range_index][batch_query_index]; + int64_t yend_parent = yends[y_range_index][batch_query_index]; + + // Use fractional cascading to map range of elements in parent node + // to corresponding ranges of elements in two child nodes. + // + CascadeBegin(level, + ybegin_parent == kEmptyRange ? kBeginReplacement : ybegin_parent, + &ybegin[y_range_index], &ybegin[3 + y_range_index]); + CascadeEnd(level, yend_parent == kEmptyRange ? 
kEndReplacement : yend_parent,
+                   &yend[y_range_index], &yend[y_range_index + 3]);
+
+        // Check if any of the resulting ranges in child nodes is empty and
+        // update boundaries accordingly.
+        //
+        bool empty_parent = ybegin_parent == kEmptyRange || yend_parent == kEmptyRange;
+        for (int i = 0; i < 2; ++i) {
+          int child_range_index = y_range_index + 3 * i;
+          bool empty_range = ybegin[child_range_index] == kEmptyRange ||
+                             yend[child_range_index] == kEmptyRange ||
+                             ybegin[child_range_index] == yend[child_range_index];
+          ybegin[child_range_index] =
+              (empty_parent || empty_range) ? kEmptyRange : ybegin[child_range_index];
+          yend[child_range_index] =
+              (empty_parent || empty_range) ? kEmptyRange : yend[child_range_index];
+        }
+
+        // Update the number of elements in all ranges in left child.
+        //
+        num_elements_in_left += yend[y_range_index] - ybegin[y_range_index];
+      }
+
+      // Decide whether to traverse down to the left or to the right child.
+      //
+      int64_t N = pos[batch_query_index] - NodeBegin(level, pos[batch_query_index]);
+      int child_index = N < num_elements_in_left ? 0 : 1;
+
+      // Update range boundaries for the selected child node.
+      //
+      for (int y_range_index = 0; y_range_index < queries.num_y_ranges; ++y_range_index) {
+        ybegins[y_range_index][batch_query_index] =
+            ybegin[y_range_index + 3 * child_index];
+        yends[y_range_index][batch_query_index] = yend[y_range_index + 3 * child_index];
+      }
+
+      // Update node index and N for the selected child node.
+      //
+      int64_t child_node_length = 1LL << (level - 1);
+      pos[batch_query_index] +=
+          child_node_length * child_index - num_elements_in_left * child_index;
+    }
+  }
+
+  // Expand results of filtered batch queries to update the array of all
+  // query results and fill the remaining query results in this batch with
+  // the kOutOfBounds constant.
+  //
+  num_batch_queries = 0;
+  for (int64_t batch_query_index = 0; batch_query_index < batch_length;
+       ++batch_query_index) {
+    int64_t query_index = batch_begin + batch_query_index;
+    int valid_query = NOutOfBounds(queries, query_index) ? 0 : 1;
+    queries.states[query_index].pos = valid_query ? pos[num_batch_queries] : kOutOfBounds;
+    num_batch_queries += valid_query;
+  }
+
+  END_MINI_BATCH_FOR
+}
+
+void MergeTree::DebugPrintToFile(const char* filename) const {
+  FILE* fout;
+#if defined(_MSC_VER) && _MSC_VER >= 1400
+  fopen_s(&fout, filename, "wt");
+#else
+  fout = fopen(filename, "wt");
+#endif
+  if (!fout) {
+    return;
+  }
+
+  for (int level = num_levels() - 1; level > 0; --level) {
+    for (int64_t i = 0; i < length_; ++i) {
+      fprintf(fout, "%s", bit_matrix_.GetRow(level).GetBit(i) ? "1" : "0");
+    }
+    fprintf(fout, "\n");
+  }
+
+  fprintf(fout, "\n");
+
+  for (int level = num_levels() - 1; level > 0; --level) {
+    auto bits = bit_matrix_.GetRow(level);
+    bits.DebugPrintCountersToFile(fout);
+  }
+
+  fclose(fout);
+}
+
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/window_functions/merge_tree.h b/cpp/src/arrow/compute/exec/window_functions/merge_tree.h
new file mode 100644
index 00000000000..3cec34fca5c
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/window_functions/merge_tree.h
@@ -0,0 +1,306 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "arrow/compute/exec/util.h" +#include "arrow/compute/exec/window_functions/bit_vector_navigator.h" +#include "arrow/compute/exec/window_functions/window_frame.h" +#include "arrow/util/bit_util.h" + +namespace arrow { +namespace compute { + +// Represents a fixed set of 2D points with attributes X and Y. +// Values of each attribute across points are unique integers in the range +// [0, N - 1] for N points. +// Supports two kinds of queries: +// a) Nth element +// b) Box count / box filter +// +// Nth element query: filter points using range predicate on Y, return the nth +// smallest X within the remaining points. +// +// Box count query: filter points using range predicate on X and less than +// predicate on Y, count and return the number of remaining points. +// +class MergeTree { + public: + // Constant used in description of boundaries of the ranges of node elements + // to indicate an empty range. + // + static constexpr int64_t kEmptyRange = -1; + + // Constant returned from nth element query when the result is outside of the + // input range of elements. + // + static constexpr int64_t kOutOfBounds = -1; + + int num_levels() const { return bit_util::Log2(length_) + 1; } + + Status Build(int64_t length, int level_begin, int64_t* permutation_of_X, + ParallelForStream& parallel_fors); + + // Internal state of a single box count / box filter query preserved between + // visiting different levels of the merge tree. + // + struct BoxQueryState { + // End positions for ranges of elements sorted on Y belonging to up + // to two nodes from a single level that are active for this box query. + // + // There may be between 0 and 2 nodes represented in this state. + // If it is less than 2 we mark the remaining elements in the ends array + // with the kEmptyRange constant. + // + int64_t ends[2]; + }; + + // Input and mutable state for a series of box queries + // + struct BoxQueryRequest { + // Callback for reporting partial query results for a batch of queries and a + // single level. + // + // The arguments are: + // - tree level, + // - range of query indices (begin and end), + // - two arrays with one element per query in a batch containing two + // cursors. Each cursor represents a prefix of elements (sorted on Y) inside + // a single node from the specified level that satisfy the query. Each + // cursor can be set to kEmptyRange constant, which indicates empty result + // set. + // + using BoxQueryCallback = std::function; + BoxQueryCallback report_results_callback_; + // Number of queries + // + int64_t num_queries; + // The predicate on X can represent a union of multiple ranges, + // but all queries need to use exactly the same number of ranges. + // + int num_x_ranges; + // Range predicates on X. + // + // Since every query can use multiple ranges it is an array of arrays. 
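+    // (For example, xbegins[r][q] and xends[r][q], declared below, give the
+    // r-th X range of query q.)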
+ // + // Beginnings and ends of corresponding ranges are stored in separate arrays + // of arrays. + // + const int64_t* xbegins[WindowFrames::kMaxRangesInFrame]; + const int64_t* xends[WindowFrames::kMaxRangesInFrame]; + // Range of tree levels to traverse. + // + // If the range does not represent the entire tree, then only part of + // the tree will be processed, starting from the query states provided in + // the array below. The array of query states will be updated afterwards, + // allowing subsequent call to continue processing for the remaining tree + // levels. + // + int level_begin; + int level_end; + // Query state is a pair of cursors pointing to two locations in two nodes + // in a single (level_begin) level of the tree. A cursor can be seen as a + // prefix of elements (sorted on Y) that belongs to a single node. The + // number of cursors may be less than 2, in which case one or two cursors + // are set to the kEmptyRange constant. + // + // Initially the first cursor should be set to exclusive upper bound on Y + // (kEmptyRange if 0) and the second cursor to kEmptyRange. + // + // If we split query processing into multiple steps (level_end > 0), then + // the state will be updated. + // + BoxQueryState* states; + }; + + void BoxQuery(const BoxQueryRequest& queries, ThreadContext& thread_ctx); + + void BoxCountQuery(int64_t num_queries, int num_x_ranges_per_query, + const int64_t** x_begins, const int64_t** x_ends, + const int64_t* y_ends, int64_t* results, + ThreadContext& thread_context); + + // Internal state of a single nth element query preserved between visiting + // different levels of the merge tree. + struct NthQueryState { + // Position within a single node from a single level that encodes: + // - the node from which the search will continue, + // - the relative position of the output X within the sorted sequence of X + // of points associated with this node. + int64_t pos; + }; + + // Input and mutable state for a series of nth element queries + // + struct NthQueryRequest { + int64_t num_queries; + // Range predicates on Y. + // + // Since every query can use multiple ranges it is an array of arrays. + // + // Beginnings and ends of corresponding ranges are stored in separate arrays + // of arrays. + // + int num_y_ranges; + const int64_t** ybegins; + const int64_t** yends; + // State encodes a node (all states will point to nodes from the same level) + // and the N for the Nth element we are looking for. + // + // When the query starts it is set directly to N in the query (N part is the + // input and node part is zero). + // + // When the query finishes it is set to the query result - a value of X that + // is Nth in the given range of Y (node part is the result and N part is + // zero). + // + NthQueryState* states; + }; + + void NthQuery(const NthQueryRequest& queries, ThreadContext& thread_ctx); + + private: + // Return true if the given array of N elements contains a permutation of + // integers from [0, N - 1] range. + // + bool IsPermutation(int64_t length, const int64_t* values); + + // Find the beginning (index in the split bit vector) of the merge tree node + // for a given position within the range of bits for that node. + // + inline int64_t NodeBegin(int level, int64_t pos) const; + + // Find the end (index one after the last) of the merge tree node given a + // position within its range. + // + // All nodes of the level have (1 << level) elements except for the last that + // can be truncated. 
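+  //
+  // For example, with length_ == 10 the nodes at level 2 span [0, 4), [4, 8)
+  // and [8, 10), so NodeBegin(2, 9) == 8 and NodeEnd(2, 9) == 10, while
+  // NodeEnd(2, 5) == 8.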
+ // + inline int64_t NodeEnd(int level, int64_t pos) const; + + // Use split bit vector and bit vector navigator to map beginning of a + // range of Y from a parent node to both child nodes. + // + // If the child range is empty return kEmptyRange for it. + // + inline void CascadeBegin(int from_level, int64_t begin, int64_t* lbegin, + int64_t* rbegin) const; + + // Same as CascadeBegin but for the end (one after the last element) of the + // range. + // + // The difference is that end offset within the node can have values in + // [1; S] range, where S is the size of the node, while the beginning offset + // is in [0; S - 1]. + // + inline void CascadeEnd(int from_level, int64_t end, int64_t* lend, int64_t* rend) const; + + // Fractional cascading for a single element of a parent node. + // + inline int64_t CascadePos(int from_level, int64_t pos) const; + + enum class NodeSubsetType { EMPTY, PARTIAL, FULL }; + + // Check whether the intersection with respect to X axis of the range + // represented by the node and a given range is: a) empty, b) full node, c) + // partial node. + // + inline NodeSubsetType NodeIntersect(int level, int64_t pos, int64_t begin, int64_t end); + + // Split a subset of elements from the source level. + // + // When MULTIPLE_SOURCE_NODES == false, + // then the subset must be contained in a single source node (it can also + // represent the entire source node). + // + template + void SplitSubsetImp(const BitWeaverNavigator& split_bits, int source_level, + const T* source_level_vector, T* target_level_vector, + int64_t read_begin, int64_t read_end, int64_t write_begin_bit0, + int64_t write_begin_bit1, ThreadContext& thread_ctx); + + // Split a subset of elements from the source level. + // + template + void SplitSubset(int source_level, const T* source_level_vector, T* target_level_vector, + int64_t read_begin, int64_t read_end, ThreadContext& thread_ctx); + + void SetMorselLoglen(int morsel_loglen); + + // Load up to 64 bits from interleaved bit vector starting at an arbitrary bit + // index. + // + inline uint64_t GetWordUnaligned(const BitWeaverNavigator& source, int64_t bit_index, + int num_bits = 64); + + // Set a subsequence of bits within a single word inside an interleaved bit + // vector. + // + inline void UpdateWord(BitWeaverNavigator& target, int64_t bit_index, int num_bits, + uint64_t bits); + + // Copy bits while reading and writing aligned 64-bit words only. + // + // Input and output bit vectors may be logical bit vectors inside a + // collection of interleaved bit vectors of the same length (accessed + // using BitWeaverNavigator). 
+ // + void BitMemcpy(const BitWeaverNavigator& source, BitWeaverNavigator& target, + int64_t source_begin, int64_t source_end, int64_t target_begin); + + void GetChildrenBoundaries(const BitWeaverNavigator& split_bits, + int64_t num_source_nodes, int64_t* source_node_begins, + int64_t* target_node_begins); + + void BuildUpperSliceMorsel(int level_begin, int64_t* permutation_of_X, + int64_t* temp_permutation_of_X, int64_t morsel_index, + ThreadContext& thread_ctx); + + void CombineUpperSlicesMorsel(int level_begin, int64_t output_morsel, + int64_t* input_permutation_of_X, + int64_t* output_permutation_of_X, + ThreadContext& thread_ctx); + + void BuildLower(int level_begin, int64_t morsel_index, int64_t* begin_permutation_of_X, + int64_t* temp_permutation_of_X, ThreadContext& thread_ctx); + + bool NOutOfBounds(const NthQueryRequest& queries, int64_t query_index); + + void DebugPrintToFile(const char* filename) const; + + static constexpr int kBitMatrixBandSize = 4; + static constexpr int kMinMorselLoglen = BitVectorWithCounts::kLogBitsPerBlock; + + int morsel_loglen_; + int64_t length_; + + BitMatrixWithCounts bit_matrix_; + BitMatrixWithCounts bit_matrix_upper_slices_; + + // Temp buffer used while building the tree for double buffering of the + // permutation of X (buffer for upper level is used to generate buffer for + // lower level, then we traverse down and swap the buffers). + // The other buffer is provided by the caller of the build method. + // + std::vector temp_permutation_of_X_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/window_functions/window_frame.h b/cpp/src/arrow/compute/exec/window_functions/window_frame.h new file mode 100644 index 00000000000..6752c005787 --- /dev/null +++ b/cpp/src/arrow/compute/exec/window_functions/window_frame.h @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "arrow/compute/exec/util.h" + +namespace arrow { +namespace compute { + +// A collection of window frames for a sequence of rows in the window frame sort +// order. +// +struct WindowFrames { + // Every frame is associated with a single row. + // + // This is the index of the first row (in the window frame sort order) for the + // first frame. + // + int64_t first_row_index; + + // Number of frames in this collection + // + int64_t num_frames; + + // Maximum number of ranges that make up each single frame. + // + static constexpr int kMaxRangesInFrame = 3; + + // Number of ranges that make up each single frame. + // Every frame will have exactly that many ranges, but any number of these + // ranges can be empty. + // + int num_ranges_per_frame; + + // Range can be empty, in that case begin == end. Otherwise begin < end. 
+  //
+  // Ranges in a single frame must be disjoint. The beginning of each next
+  // range must be greater than or equal to the end of the previous range
+  // (adjacent ranges may touch, i.e. a range may begin exactly at the end of
+  // the previous one).
+  //
+  const int64_t* begins[kMaxRangesInFrame];
+  const int64_t* ends[kMaxRangesInFrame];
+
+  // Check if a collection of frames represents sliding frames,
+  // that is for every boundary (left and right) of every range, the values
+  // across all frames are non-decreasing.
+  //
+  bool IsSliding() const {
+    for (int range = 0; range < num_ranges_per_frame; ++range) {
+      for (int64_t i = 1; i < num_frames; ++i) {
+        if (!(begins[range][i] >= begins[range][i - 1] &&
+              ends[range][i] >= ends[range][i - 1])) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // Check if a collection of frames represents cumulative frames,
+  // that is for every range, two adjacent frames either share the same
+  // beginning with the end of the later one being no less than the end of the
+  // previous one, or the later one begins at or after the end of the previous
+  // one.
+  //
+  bool IsCummulative() const {
+    for (int range = 0; range < num_ranges_per_frame; ++range) {
+      for (int64_t i = 1; i < num_frames; ++i) {
+        if (!((begins[range][i] >= ends[range][i - 1] ||
+               begins[range][i] == begins[range][i - 1]) &&
+              (ends[range][i] >= ends[range][i - 1]))) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // Check if the row for which the frame is defined is included in any of the
+  // ranges defining that frame.
+  //
+  bool IsRowInsideItsFrame(int64_t frame_index) const {
+    bool is_inside = false;
+    int64_t row_index = first_row_index + frame_index;
+    for (int64_t range_index = 0; range_index < num_ranges_per_frame; ++range_index) {
+      int64_t range_begin = begins[range_index][frame_index];
+      int64_t range_end = ends[range_index][frame_index];
+      is_inside = is_inside || (row_index >= range_begin && row_index < range_end);
+    }
+    return is_inside;
+  }
+};
+
+enum class WindowFrameSequenceType { CUMMULATIVE, SLIDING, GENERIC };
+
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/window_functions/window_rank.cc b/cpp/src/arrow/compute/exec/window_functions/window_rank.cc
new file mode 100644
index 00000000000..a8601f5644f
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/window_functions/window_rank.cc
@@ -0,0 +1,456 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
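+
+// The rank evaluation below consumes WindowFrames (window_frame.h), where
+// frame boundaries are stored column-wise, one array per range. For
+// illustration, consider three frames with a single range each:
+//
+//   frame index:   0  1  2
+//   begins[0][i]:  0  1  3
+//   ends[0][i]:    2  4  5
+//
+// Both boundaries are non-decreasing across frames, so IsSliding() returns
+// true; IsCummulative() returns false because frame 1 neither shares its
+// begin with frame 0 nor starts at or after frame 0's end.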
+ +#include "arrow/compute/exec/window_functions/window_rank.h" + +namespace arrow { +namespace compute { + +void WindowRank_Global::Eval(RankType rank_type, const BitVectorNavigator& tie_begins, + int64_t batch_begin, int64_t batch_end, int64_t* results) { + int64_t num_rows = tie_begins.bit_count(); + + if (rank_type == RankType::ROW_NUMBER) { + std::iota(results, results + batch_end - batch_begin, 1LL + batch_begin); + return; + } + + if (rank_type == RankType::DENSE_RANK) { + for (int64_t i = batch_begin; i < batch_end; ++i) { + results[i - batch_begin] = tie_begins.RankNext(i); + } + return; + } + + if (rank_type == RankType::RANK_TIES_LOW) { + int64_t rank = tie_begins.Select(tie_begins.RankNext(batch_begin) - 1); + ARROW_DCHECK( + tie_begins.RankNext(rank) == tie_begins.RankNext(batch_begin) && + (rank == 0 || tie_begins.RankNext(rank - 1) < tie_begins.RankNext(rank))); + rank += 1; + for (int64_t i = batch_begin; i < batch_end; ++i) { + rank = (tie_begins.GetBit(i) != 0) ? i + 1 : rank; + results[i - batch_begin] = rank; + } + return; + } + + if (rank_type == RankType::RANK_TIES_HIGH) { + int64_t rank_max = tie_begins.pop_count(); + int64_t rank_last = tie_begins.RankNext(batch_end - 1); + int64_t rank = (rank_last == rank_max) ? num_rows : tie_begins.Select(rank_last); + for (int64_t i = batch_end - 1; i >= batch_begin; --i) { + results[i - batch_begin] = rank; + rank = (tie_begins.GetBit(i) != 0) ? i : rank; + } + return; + } +} + +void WindowRank_Framed1D::Eval(RankType rank_type, const BitVectorNavigator& tie_begins, + const WindowFrames& frames, int64_t* results) { + if (rank_type == RankType::RANK_TIES_LOW) { + // We will compute global rank into the same array as the one provided for + // the output (to avoid allocating another array). + // + // When computing rank for a given row we will only read the result + // computed for that row (no access to other rows) and update the same + // result array entry. + // + int64_t* global_ranks = results; + WindowRank_Global::Eval(RankType::RANK_TIES_LOW, tie_begins, frames.first_row_index, + frames.first_row_index + frames.num_frames, global_ranks); + + // The rank is 1 + the number of rows with key strictly lower than the + // current row's key. + // + for (int64_t frame_index = 0; frame_index < frames.num_frames; ++frame_index) { + // If the frame does not contain current row it is still logically + // considered as included in the frame (e.g. empty frame will yield rank + // 1 since the set we look at consists of a single row, the current + // row). + // + int64_t rank = 1; + for (int range_index = 0; range_index < frames.num_ranges_per_frame; + ++range_index) { + int64_t global_rank = global_ranks[frame_index]; + int64_t range_begin = frames.begins[range_index][frame_index]; + int64_t range_end = frames.ends[range_index][frame_index]; + + // The formula below takes care of the cases: + // a) current row outside of the range to the left, + // b) current row in the range and ties with the first row in the + // range, + // c) current row in the range and no tie with the first row in + // the range, + // d) current row outside of the range to the right and + // ties with the last row in the range. + // e) current row outside of the range to the right and does no tie + // with the last row in the range. 
+ // f) empty frame range, + // + rank += std::max(static_cast(0), + std::min(global_rank, range_end + 1) - range_begin - 1); + } + + results[frame_index] = rank; + } + } + + if (rank_type == RankType::RANK_TIES_HIGH) { + // To compute TIES_HIGH variant, we can reverse boundaries, + // global ranks by substracting their values from num_rows + // and num_rows + 1 respectively, and we will get the same problem as + // TIES_LOW, the result of which we can convert back using the same + // method but this time using number of rows inside the frame instead of + // global number of rows. + // + // That is how the formula used below was derived. + // + // Note that the number of rows considered to be in the frame depends + // whether the current row is inside or outside of the ranges defining its + // frame, because in the second case we need to add 1 to the total size of + // ranges. + // + int64_t* global_ranks = results; + WindowRank_Global::Eval(RankType::RANK_TIES_HIGH, tie_begins, frames.first_row_index, + frames.first_row_index + frames.num_frames, global_ranks); + + for (int64_t frame_index = 0; frame_index < frames.num_frames; ++frame_index) { + int64_t rank = 0; + + for (int range_index = 0; range_index < frames.num_ranges_per_frame; + ++range_index) { + int64_t global_rank = global_ranks[frame_index]; + int64_t range_begin = frames.begins[range_index][frame_index]; + int64_t range_end = frames.ends[range_index][frame_index]; + + rank += std::min(range_end, std::max(global_rank, range_begin)) - range_begin; + } + + rank += frames.IsRowInsideItsFrame(frame_index) ? 0 : 1; + + results[frame_index] = rank; + } + } + + if (rank_type == RankType::ROW_NUMBER) { + // Count rows inside the frame coming before the current row and add 1. + // + for (int64_t frame_index = 0; frame_index < frames.num_frames; ++frame_index) { + int64_t row_index = frames.first_row_index + frame_index; + int64_t rank = 1; + for (int range_index = 0; range_index < frames.num_ranges_per_frame; + ++range_index) { + int64_t range_begin = frames.begins[range_index][frame_index]; + int64_t range_end = frames.ends[range_index][frame_index]; + + rank += std::max(static_cast(0), + std::min(row_index, range_end) - range_begin); + } + + results[frame_index] = rank; + } + } + + if (rank_type == RankType::DENSE_RANK) { + for (int64_t frame_index = 0; frame_index < frames.num_frames; ++frame_index) { + int64_t row_index = frames.first_row_index + frame_index; + int64_t rank = 1; + + // gdr = global dense rank + // + // Note that computing global dense rank corresponds to calling + // tie_begin.RankNext(). + // + int64_t highest_gdr_seen = 0; + int64_t gdr = tie_begins.RankNext(row_index); + + for (int range_index = 0; range_index < frames.num_ranges_per_frame; + ++range_index) { + int64_t range_begin = frames.begins[range_index][frame_index]; + int64_t range_end = frames.ends[range_index][frame_index]; + + if (row_index < range_begin || range_end == range_begin) { + // Empty frame and frame starting after the current row - nothing to + // do. + // + } else { + // Count how many NEW peer groups before the current row's peer + // group are introduced by each range. + // + // Take into account when the last row of the previous range is in + // the same peer group as the first row of the next range. 
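+          //
+          // For example, with per-row global dense ranks [1, 1, 2, 2, 3, 4],
+          // current row 5 (gdr == 4) and frame ranges [0, 3) and [3, 5):
+          // the first range contributes peer groups 1 and 2, the second range
+          // contributes only group 3 (group 2 at its start was already seen),
+          // so the resulting rank is 1 + 3 == 4.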
+ // + int64_t gdr_first = tie_begins.RankNext(range_begin); + int64_t gdr_last = tie_begins.RankNext(range_end - 1); + int64_t new_peer_groups = std::max( + static_cast(0), std::min(gdr_last, gdr - 1) - + std::max(highest_gdr_seen + 1, gdr_first) + 1); + rank += new_peer_groups; + highest_gdr_seen = gdr_last; + } + } + + results[frame_index] = rank; + } + } +} + +Status WindowRank_Framed2D::Eval(RankType rank_type, + const BitVectorNavigator& rank_key_tie_begins, + const int64_t* order_by_rank_key, + const WindowFrames& frames, int64_t* results, + ThreadContext& thread_context) { + int64_t num_rows = rank_key_tie_begins.bit_count(); + + if (rank_type == RankType::DENSE_RANK) { + if (frames.IsSliding()) { + return DenseRankWithSplayTree(); + } else { + return DenseRankWithRangeTree(); + } + } + + ARROW_DCHECK(rank_type == RankType::ROW_NUMBER || + rank_type == RankType::RANK_TIES_LOW || + rank_type == RankType::RANK_TIES_HIGH); + + ParallelForStream exec_plan; + + // Build merge tree + // + MergeTree merge_tree; + std::vector order_by_rank_key_copy(num_rows); + memcpy(order_by_rank_key_copy.data(), order_by_rank_key, + num_rows * sizeof(order_by_rank_key[0])); + RETURN_NOT_OK(merge_tree.Build(num_rows, + /*level_begin=*/bit_util::Log2(num_rows), + order_by_rank_key_copy.data(), exec_plan)); + RETURN_NOT_OK(exec_plan.RunOnSingleThread(thread_context)); + + // For each row compute the number of rows with the lower rank (lower or + // equal in case of ties high). + // + // This will be used as an upper bound on rank attribute when querying + // merge tree. + // + std::vector y_ends; + std::swap(order_by_rank_key_copy, y_ends); + auto temp_vector_stack = thread_context.temp_vector_stack; + { + TEMP_VECTOR(int64_t, global_ranks); + BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, num_rows) + WindowRank_Global::Eval(rank_type, rank_key_tie_begins, batch_begin, + batch_begin + batch_length, global_ranks); + if (rank_type == RankType::RANK_TIES_LOW || rank_type == RankType::ROW_NUMBER) { + for (int64_t i = 0; i < batch_length; ++i) { + --global_ranks[i]; + } + } + for (int64_t i = 0; i < batch_length; ++i) { + int64_t row_index = order_by_rank_key[batch_begin + i]; + y_ends[row_index] = global_ranks[i]; + } + END_MINI_BATCH_FOR + } + + BEGIN_MINI_BATCH_FOR(batch_begin, batch_length, frames.num_frames) + + // Execute box count queries one batch of frames at a time. + // + const int64_t* x_begins_batch[WindowFrames::kMaxRangesInFrame]; + const int64_t* x_ends_batch[WindowFrames::kMaxRangesInFrame]; + for (int64_t range_index = 0; range_index < frames.num_ranges_per_frame; + ++range_index) { + x_begins_batch[range_index] = frames.begins[range_index] + batch_begin; + x_ends_batch[range_index] = frames.ends[range_index] + batch_begin; + } + const int64_t* y_ends_batch = y_ends.data() + frames.first_row_index + batch_begin; + int64_t* results_batch = results + batch_begin; + merge_tree.BoxCountQuery(batch_length, frames.num_ranges_per_frame, x_begins_batch, + x_ends_batch, y_ends_batch, results_batch, thread_context); + + if (rank_type == RankType::RANK_TIES_LOW || rank_type == RankType::ROW_NUMBER) { + // For TIES_LOW and ROW_NUMBER we need to add 1 to the output of box count + // query to get the rank. + // + for (int64_t i = 0; i < batch_length; ++i) { + ++results_batch[i]; + } + } else { + // For TIES_HIGH we need to add 1 to the output only + // when the current row is outside of all the ranges defining its frame. 
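+    //
+    // (For example, a frame whose ranges are all empty yields a box count of
+    // 0; the current row is then the only row logically in the frame and the
+    // reported rank is 1, matching the framed 1D and reference paths.)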
+ // + for (int64_t i = 0; i < batch_length; ++i) { + results_batch[i] += frames.IsRowInsideItsFrame(batch_begin + i) ? 0 : 1; + } + } + + END_MINI_BATCH_FOR + + return Status::OK(); +} + +void WindowRank_Global_Ref::Eval(RankType rank_type, const BitVectorNavigator& tie_begins, + int64_t* results) { + int64_t num_rows = tie_begins.bit_count(); + const uint8_t* bit_vector = tie_begins.GetBytes(); + + std::vector peer_group_offsets; + for (int64_t i = 0; i < num_rows; ++i) { + if (bit_util::GetBit(bit_vector, i)) { + peer_group_offsets.push_back(i); + } + } + int64_t num_peer_groups = static_cast(peer_group_offsets.size()); + peer_group_offsets.push_back(num_rows); + + for (int64_t peer_group = 0; peer_group < num_peer_groups; ++peer_group) { + int64_t peer_group_begin = peer_group_offsets[peer_group]; + int64_t peer_group_end = peer_group_offsets[peer_group + 1]; + for (int64_t i = peer_group_begin; i < peer_group_end; ++i) { + int64_t row_index = i; + int64_t rank; + switch (rank_type) { + case RankType::ROW_NUMBER: + rank = row_index + 1; + break; + case RankType::RANK_TIES_LOW: + rank = peer_group_begin + 1; + break; + case RankType::RANK_TIES_HIGH: + rank = peer_group_end; + break; + case RankType::DENSE_RANK: + rank = peer_group + 1; + break; + } + results[row_index] = rank; + } + } +} + +void WindowRank_Framed_Ref::Eval(RankType rank_type, + const BitVectorNavigator& rank_key_tie_begins, + const int64_t* order_by_rank_key, + const WindowFrames& frames, int64_t* results) { + int64_t num_rows = rank_key_tie_begins.bit_count(); + + std::vector global_ranks_order_by_rank_key(num_rows); + WindowRank_Global_Ref::Eval(rank_type, rank_key_tie_begins, + global_ranks_order_by_rank_key.data()); + + std::vector global_ranks(num_rows); + if (!order_by_rank_key) { + for (int64_t i = 0; i < num_rows; ++i) { + global_ranks[i] = global_ranks_order_by_rank_key[i]; + } + } else { + for (int64_t i = 0; i < num_rows; ++i) { + global_ranks[order_by_rank_key[i]] = global_ranks_order_by_rank_key[i]; + } + } + + for (int64_t frame_index = 0; frame_index < frames.num_frames; ++frame_index) { + int64_t current_row_index = frames.first_row_index + frame_index; + + // Compute list of global ranks for all rows within the frame. + // + // Make sure to include the current row in the frame, even if it lies + // outside of the ranges defining its. + // + std::vector global_ranks_within_frame; + bool current_row_included = false; + for (int64_t range_index = 0; range_index < frames.num_ranges_per_frame; + ++range_index) { + int64_t begin = frames.begins[range_index][frame_index]; + int64_t end = frames.ends[range_index][frame_index]; + if (!current_row_included && current_row_index < begin) { + global_ranks_within_frame.push_back(global_ranks[current_row_index]); + current_row_included = true; + } + for (int64_t row_index = begin; row_index < end; ++row_index) { + if (row_index == current_row_index) { + current_row_included = true; + } + global_ranks_within_frame.push_back(global_ranks[row_index]); + } + } + if (!current_row_included) { + global_ranks_within_frame.push_back(global_ranks[current_row_index]); + current_row_included = true; + } + + int64_t rank = 0; + for (int64_t frame_row_index = 0; + frame_row_index < static_cast(global_ranks_within_frame.size()); + ++frame_row_index) { + switch (rank_type) { + case RankType::ROW_NUMBER: + // Count the number of rows in the frame with lower global rank. 
+ // + if (global_ranks_within_frame[frame_row_index] < + global_ranks[current_row_index]) { + ++rank; + } + break; + case RankType::RANK_TIES_LOW: + // Count the number of rows in the frame with lower global rank. + // + if (global_ranks_within_frame[frame_row_index] < + global_ranks[current_row_index]) { + ++rank; + } + break; + case RankType::RANK_TIES_HIGH: + // Count the number of rows in the frame with lower or equal global + // rank. + // + if (global_ranks_within_frame[frame_row_index] <= + global_ranks[current_row_index]) { + ++rank; + } + break; + case RankType::DENSE_RANK: + // Count the number of rows in the frame with lower global rank that + // have global rank different than the previous row. + // + bool global_rank_changed = + (frame_row_index == 0) || (global_ranks_within_frame[frame_row_index] != + global_ranks_within_frame[frame_row_index - 1]); + if (global_ranks_within_frame[frame_row_index] < + global_ranks[current_row_index] && + global_rank_changed) { + ++rank; + } + break; + } + } + // For all rank types except for RANK_TIES_HIGH increment obtained rank + // value by 1. + // + if (rank_type != RankType::RANK_TIES_HIGH) { + ++rank; + } + + results[frame_index] = rank; + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/window_functions/window_rank.h b/cpp/src/arrow/compute/exec/window_functions/window_rank.h new file mode 100644 index 00000000000..a640dc3f516 --- /dev/null +++ b/cpp/src/arrow/compute/exec/window_functions/window_rank.h @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
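+
+// Rank semantics implemented below, illustrated for the global (unframed)
+// case on the sorted key sequence [10, 20, 20, 30]:
+//
+//   ROW_NUMBER:      1, 2, 3, 4
+//   RANK_TIES_LOW:   1, 2, 2, 4
+//   RANK_TIES_HIGH:  1, 3, 3, 4
+//   DENSE_RANK:      1, 2, 2, 3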
+ +#pragma once + +#include +#include // for std::iota +#include "arrow/compute/exec/util.h" +#include "arrow/compute/exec/window_functions/bit_vector_navigator.h" +#include "arrow/compute/exec/window_functions/merge_tree.h" +#include "arrow/compute/exec/window_functions/window_frame.h" + +// TODO: Add support for CUME_DIST and NTILE +// TODO: Add support for rank with limit +// TODO: Add support for rank with global filter + +namespace arrow { +namespace compute { + +enum class RankType : int { + ROW_NUMBER = 0, + RANK_TIES_LOW = 1, + RANK_TIES_HIGH = 2, + DENSE_RANK = 3 +}; + +class ARROW_EXPORT WindowRank_Global { + public: + static void Eval(RankType rank_type, const BitVectorNavigator& tie_begins, + int64_t batch_begin, int64_t batch_end, int64_t* results); +}; + +class ARROW_EXPORT WindowRank_Framed1D { + public: + static void Eval(RankType rank_type, const BitVectorNavigator& tie_begins, + const WindowFrames& frames, int64_t* results); +}; + +class ARROW_EXPORT WindowRank_Framed2D { + public: + static Status Eval(RankType rank_type, const BitVectorNavigator& rank_key_tie_begins, + const int64_t* order_by_rank_key, const WindowFrames& frames, + int64_t* results, ThreadContext& thread_context); + + private: + static Status DenseRankWithRangeTree() { + // TODO: Implement + ARROW_DCHECK(false); + return Status::OK(); + } + static Status DenseRankWithSplayTree() { + // TODO: Implement + ARROW_DCHECK(false); + return Status::OK(); + } +}; + +// Reference implementations used for testing. +// +// May also be useful for understanding the expected behaviour of the actual +// implementations, which trade simplicity for efficiency. +// +class ARROW_EXPORT WindowRank_Global_Ref { + public: + static void Eval(RankType rank_type, const BitVectorNavigator& tie_begins, + int64_t* results); +}; + +class ARROW_EXPORT WindowRank_Framed_Ref { + public: + // For 1D variant use null pointer for the permutation of rows ordered by + // ranking key. That will assume that the permutation is an identity mapping. + // + static void Eval(RankType rank_type, const BitVectorNavigator& rank_key_tie_begins, + const int64_t* order_by_rank_key, const WindowFrames& frames, + int64_t* results); +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/window_functions/window_test.cc b/cpp/src/arrow/compute/exec/window_functions/window_test.cc new file mode 100644 index 00000000000..4398ca03a91 --- /dev/null +++ b/cpp/src/arrow/compute/exec/window_functions/window_test.cc @@ -0,0 +1,453 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
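+
+// A minimal usage sketch of the global rank path exercised by this test
+// (num_rows and sorted_keys stand in for the caller's data, with rows already
+// sorted on the ranking key; see TestWindowRankVariant below for the full
+// setup):
+//
+//   BitVectorWithCounts tie_begins;
+//   tie_begins.Resize(num_rows);
+//   tie_begins.GetNavigator().MarkTieBegins(num_rows, sorted_keys);
+//   std::vector<int64_t> ranks(num_rows);
+//   WindowRank_Global::Eval(RankType::DENSE_RANK, tie_begins.GetNavigator(),
+//                           /*batch_begin=*/0, /*batch_end=*/num_rows,
+//                           ranks.data());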
+ +#include + +#include +#include "arrow/compute/exec/test_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/compute/exec/window_functions/window_rank.h" + +namespace arrow { +namespace compute { + +class WindowFramesRandom { + public: + static void Generate(Random64Bit& rand, WindowFrameSequenceType frame_sequence_type, + int64_t num_rows, int num_ranges_per_frame, + std::vector>* range_boundaries); + + static void GenerateSliding(Random64Bit& rand, int64_t num_rows, + int num_ranges_per_frame, + std::vector>* range_boundaries, + int64_t suggested_frame_span, int64_t suggested_gap_length); + + static void GenerateCummulative(Random64Bit& rand, int64_t num_rows, + int num_ranges_per_frame, + std::vector>* range_boundaries, + int num_restarts); + + static void GenerateGeneric(Random64Bit& rand, int64_t num_rows, + int num_ranges_per_frame, + std::vector>* range_boundaries, + int64_t max_frame_span, int64_t max_gap_length); + + private: + static void CutHoles(Random64Bit& rand, int64_t frame_span, int64_t num_holes, + int64_t sum_hole_size, std::vector& result_boundaries); +}; + +void WindowFramesRandom::Generate(Random64Bit& rand, + WindowFrameSequenceType frame_sequence_type, + int64_t num_rows, int num_ranges_per_frame, + std::vector>* range_boundaries) { + switch (frame_sequence_type) { + case WindowFrameSequenceType::CUMMULATIVE: + GenerateCummulative(rand, num_rows, num_ranges_per_frame, range_boundaries, + /*num_restarts=*/rand.from_range(0, 2)); + break; + case WindowFrameSequenceType::SLIDING: { + int64_t suggested_frame_span = + rand.from_range(static_cast(0), num_rows / 4); + int64_t suggested_gap_length = + rand.from_range(static_cast(0), suggested_frame_span / 2); + GenerateSliding(rand, num_rows, num_ranges_per_frame, range_boundaries, + suggested_frame_span, suggested_gap_length); + } break; + case WindowFrameSequenceType::GENERIC: { + int64_t max_frame_span = rand.from_range(static_cast(0), num_rows / 4); + int64_t max_gap_length = + rand.from_range(static_cast(0), max_frame_span / 2); + GenerateGeneric(rand, num_rows, num_ranges_per_frame, range_boundaries, + max_frame_span, max_gap_length); + } break; + } +} + +void WindowFramesRandom::GenerateSliding( + Random64Bit& rand, int64_t num_rows, int num_ranges_per_frame, + std::vector>* range_boundaries, int64_t suggested_frame_span, + int64_t suggested_gap_length) { + if (num_rows == 0) { + return; + } + + // Generate a sorted list of points that will serve as frame boundaries (for + // all ranges in all frames). + // + std::vector boundaries(num_rows + suggested_frame_span); + for (size_t i = 0; i < boundaries.size(); ++i) { + boundaries[i] = rand.from_range(static_cast(0), num_rows); + } + std::sort(boundaries.begin(), boundaries.end()); + + // Generate desired first frame (relative positions and sizes of ranges in + // it). + // + // This will serve as a basis for distances between range boundary points. + // + std::vector desired_boundaries; + CutHoles(rand, suggested_frame_span, num_ranges_per_frame - 1, suggested_gap_length, + desired_boundaries); + + // Assign boundary points from the sorted random vector at predetermined + // distances from each other to consecutive frames. 
+ // + range_boundaries->resize(num_ranges_per_frame * 2); + for (size_t i = 0; i < range_boundaries->size(); ++i) { + (*range_boundaries)[i].clear(); + } + for (int64_t i = 0; i < num_rows; ++i) { + for (int boundary_index = 0; boundary_index < 2 * num_ranges_per_frame; + ++boundary_index) { + (*range_boundaries)[boundary_index].push_back( + boundaries[i + desired_boundaries[boundary_index]]); + } + } +} + +void WindowFramesRandom::GenerateCummulative( + Random64Bit& rand, int64_t num_rows, int num_ranges_per_frame, + std::vector>* range_boundaries, int num_restarts) { + int num_boundaries_per_frame = 2 * num_ranges_per_frame; + range_boundaries->resize(num_boundaries_per_frame); + for (int64_t i = 0; i < num_boundaries_per_frame; ++i) { + (*range_boundaries)[i].clear(); + } + + // Divide rows into sections, each dedicated to a different range. + // + std::vector sections; + sections.push_back(0); + sections.push_back(num_rows); + for (int i = 0; i < num_ranges_per_frame - 1; ++i) { + sections.push_back(rand.from_range(static_cast(0), num_rows)); + } + std::sort(sections.begin(), sections.end()); + + // Process each section (range) separately. + // + for (int range_index = 0; range_index < num_ranges_per_frame; ++range_index) { + std::vector boundaries(num_rows + num_restarts + 1); + for (int64_t i = 0; i < num_rows + num_restarts + 1; ++i) { + boundaries[i] = rand.from_range(sections[range_index], sections[range_index + 1]); + } + std::sort(boundaries.begin(), boundaries.end()); + + // Mark restart points in the boundaries vector. + // + std::vector boundary_is_restart_point(boundaries.size()); + for (int64_t i = 0; i < num_rows + num_restarts + 1; ++i) { + boundary_is_restart_point[i] = false; + } + boundary_is_restart_point[0] = true; + for (int i = 0; i < num_restarts; ++i) { + for (;;) { + int64_t pos = + rand.from_range(static_cast(0), num_rows + num_restarts - 1); + if (!boundary_is_restart_point[pos]) { + boundary_is_restart_point[pos] = true; + break; + } + } + } + + // Output results for next range. 
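+    //
+    // For example, with sorted boundaries {0, 2, 3, 5, 6} and restart points
+    // at indices 0 and 3, this emits the frames (0, 2), (0, 3) and (5, 6) for
+    // the current range: each frame either shares its begin with the previous
+    // one or starts at or after the previous end, keeping the sequence
+    // cumulative.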
+ // + int64_t current_begin = 0; + for (int64_t i = 0; i < num_rows + num_restarts + 1; ++i) { + if (boundary_is_restart_point[i]) { + current_begin = boundaries[i]; + } else { + (*range_boundaries)[2 * range_index + 0].push_back(current_begin); + (*range_boundaries)[2 * range_index + 1].push_back(boundaries[i]); + } + } + } +} + +void WindowFramesRandom::GenerateGeneric( + Random64Bit& rand, int64_t num_rows, int num_ranges_per_frame, + std::vector>* range_boundaries, int64_t max_frame_span, + int64_t max_gap_length) { + int num_boundaries_per_frame = 2 * num_ranges_per_frame; + range_boundaries->resize(num_boundaries_per_frame); + for (int64_t i = 0; i < num_boundaries_per_frame; ++i) { + (*range_boundaries)[i].clear(); + } + + for (int64_t row_index = 0; row_index < num_rows; ++row_index) { + int64_t frame_span = + rand.from_range(static_cast(0), std::min(num_rows, max_frame_span)); + int64_t gap_length = + rand.from_range(static_cast(0), std::min(frame_span, max_gap_length)); + int64_t frame_pos = rand.from_range(static_cast(0), num_rows - frame_span); + std::vector frame_boundaries; + CutHoles(rand, frame_span, num_ranges_per_frame - 1, gap_length, frame_boundaries); + for (size_t i = 0; i < frame_boundaries.size(); ++i) { + (*range_boundaries)[i].push_back(frame_boundaries[i] + frame_pos); + } + } +} + +void WindowFramesRandom::CutHoles(Random64Bit& rand, int64_t frame_span, + int64_t num_holes, int64_t sum_hole_size, + std::vector& result_boundaries) { + // Randomly pick size of each hole so that the sum is equal to the requested + // total. + // + ARROW_DCHECK(sum_hole_size <= frame_span); + std::vector cummulative_hole_sizes(num_holes + 1); + cummulative_hole_sizes[0] = 0; + for (int64_t i = 1; i < num_holes; ++i) { + cummulative_hole_sizes[i] = rand.from_range(static_cast(0), sum_hole_size); + } + cummulative_hole_sizes[num_holes] = sum_hole_size; + std::sort(cummulative_hole_sizes.begin(), cummulative_hole_sizes.end()); + + // Randomly pick starting position for each hole. + // + std::vector hole_pos(num_holes); + for (int64_t i = 0; i < num_holes; ++i) { + hole_pos[i] = rand.from_range(static_cast(0), frame_span - sum_hole_size); + } + std::sort(hole_pos.begin(), hole_pos.end()); + for (int64_t i = 0; i < num_holes; ++i) { + hole_pos[i] += cummulative_hole_sizes[i]; + } + + // Output result. + // + int64_t num_boundaries = (num_holes + 1) * 2; + result_boundaries.resize(num_boundaries); + result_boundaries[0] = 0; + result_boundaries[num_boundaries - 1] = frame_span; + for (int64_t i = 0; i < num_holes; ++i) { + result_boundaries[1 + 2 * i] = hole_pos[i]; + result_boundaries[2 + 2 * i] = + hole_pos[i] + cummulative_hole_sizes[i + 1] - cummulative_hole_sizes[i]; + } +} + +void TestWindowRankVariant(RankType rank_type, bool use_frames, bool use_2D) { + // TODO: Framed dense rank is not implemented yet: + // + ARROW_DCHECK(!(rank_type == RankType::DENSE_RANK && use_2D)); + + Random64Bit rand(/*seed=*/0); + + // Preparing thread execution context + // + MemoryPool* pool = default_memory_pool(); + util::TempVectorStack temp_vector_stack; + Status status = temp_vector_stack.Init(pool, 128 * util::MiniBatch::kMiniBatchLength); + ARROW_DCHECK(status.ok()); + ThreadContext thread_context; + thread_context.thread_index = 0; + thread_context.temp_vector_stack = &temp_vector_stack; + thread_context.hardware_flags = 0LL; + + // There will be: 24 small tests, 12 medium tests and 3 large tests. 
+ // + constexpr int num_tests = 24 + 12 + 3; + + // When debugging a failed test case, setting this value allows to skip + // execution of the first couple of test cases to go directly into the + // interesting one, while at the same time making sure that the generated + // random numbers are not affected. + // + const int num_tests_to_skip = 2; + + for (int test = 0; test < num_tests; ++test) { + // Generate random values. + // + // There will be: 24 small tests, 12 medium tests and 3 large tests. + // + int64_t max_rows = (test < 24) ? 100 : (test < 36) ? 256 : 2500; + int64_t num_rows = rand.from_range(static_cast(1), max_rows); + std::vector vals(num_rows); + int64_t max_val = num_rows; + int tie_probability = rand.from_range(0, 256); + for (int64_t i = 0; i < num_rows; ++i) { + bool tie = rand.from_range(0, 255) < tie_probability; + if (tie && i > 0) { + vals[i] = vals[rand.from_range(static_cast(0), i - 1)]; + } else { + vals[i] = rand.from_range(static_cast(0), max_val); + } + } + + // Generate random frames + // + int num_ranges_per_frame = rand.from_range(1, 3); + std::vector> range_boundaries; + int frame_sequence_type_index = rand.from_range(0, 2); + WindowFrameSequenceType frame_sequence_type = + (frame_sequence_type_index == 0) ? WindowFrameSequenceType::GENERIC + : (frame_sequence_type_index == 1) ? WindowFrameSequenceType::SLIDING + : WindowFrameSequenceType::CUMMULATIVE; + WindowFramesRandom::Generate(rand, frame_sequence_type, num_rows, + num_ranges_per_frame, &range_boundaries); + WindowFrames frames; + frames.first_row_index = 0; + frames.num_frames = num_rows; + frames.num_ranges_per_frame = num_ranges_per_frame; + for (int range_index = 0; range_index < num_ranges_per_frame; ++range_index) { + frames.begins[range_index] = range_boundaries[2 * range_index + 0].data(); + frames.ends[range_index] = range_boundaries[2 * range_index + 1].data(); + } + + // Random number generator is not used after this point in the test case, + // so we can skip the rest of the test case if we try to fast forward to a + // specific one. + // + if (test < num_tests_to_skip) { + continue; + } + + // Sort values and output permutation and bit vector of ties + // + BitVectorWithCounts tie_begins; + tie_begins.Resize(num_rows); + std::vector permutation(num_rows); + std::vector vals_sorted(num_rows); + { + std::vector> val_row_pairs(num_rows); + for (int64_t i = 0; i < num_rows; ++i) { + val_row_pairs[i] = std::make_pair(vals[i], i); + } + std::sort(val_row_pairs.begin(), val_row_pairs.end()); + for (int64_t i = 0; i < num_rows; ++i) { + permutation[i] = val_row_pairs[i].second; + vals_sorted[i] = val_row_pairs[i].first; + } + tie_begins.GetNavigator().MarkTieBegins(num_rows, vals_sorted.data()); + } + + ARROW_SCOPED_TRACE( + "num_rows = ", static_cast(num_rows), + "num_ranges_per_frame = ", num_ranges_per_frame, "window_frame_type = ", + use_frames + ? (frame_sequence_type == WindowFrameSequenceType::CUMMULATIVE ? "CUMMULATIVE" + : frame_sequence_type == WindowFrameSequenceType::SLIDING ? "SLIDING" + : "GENERIC") + : "NONE", + "rank_type = ", + rank_type == RankType::ROW_NUMBER ? "ROW_NUMBER" + : rank_type == RankType::RANK_TIES_LOW ? "RANK_TIES_LOW" + : rank_type == RankType::RANK_TIES_HIGH ? "RANK_TIES_HIGH" + : "DENSE_RANK", + "use_2D = ", use_2D); + + // At index 0 - reference results. + // At index 1 - actual results from implementation we wish to verify. + // + std::vector output[2]; + output[0].resize(num_rows); + output[1].resize(num_rows); + + // Execute reference implementation. 
+ // + if (!use_frames) { + WindowRank_Global_Ref::Eval(rank_type, tie_begins.GetNavigator(), output[0].data()); + } else if (!use_2D) { + WindowRank_Framed_Ref::Eval(rank_type, tie_begins.GetNavigator(), nullptr, frames, + output[0].data()); + } else { + WindowRank_Framed_Ref::Eval(rank_type, tie_begins.GetNavigator(), + permutation.data(), frames, output[0].data()); + } + + // Execute actual implementation. + // + if (!use_frames) { + WindowRank_Global::Eval(rank_type, tie_begins.GetNavigator(), 0, num_rows, + output[1].data()); + } else if (!use_2D) { + WindowRank_Framed1D::Eval(rank_type, tie_begins.GetNavigator(), frames, + output[1].data()); + } else { + ASSERT_OK(WindowRank_Framed2D::Eval(rank_type, tie_begins.GetNavigator(), + permutation.data(), frames, output[1].data(), + thread_context)); + } + + bool ok = true; + for (int64_t i = 0; i < num_rows; ++i) { + if (output[0][i] != output[1][i]) { + ARROW_DCHECK(false); + ok = false; + } + } + ASSERT_TRUE(ok); + } +} + +TEST(WindowFunctions, Rank) { + // These flags are useful during debugging, to quickly restrict the set of + // executed tests to just the failing one. + // + bool use_filter_framed = false; + bool use_filter_rank_type = false; + bool use_filter_2D = false; + + bool filter_framed_value = true; + RankType filter_rank_type_value = RankType::RANK_TIES_HIGH; + bool filter_2D_value = true; + + // Global rank + // + for (auto rank_type : {RankType::ROW_NUMBER, RankType::RANK_TIES_LOW, + RankType::RANK_TIES_HIGH, RankType::DENSE_RANK}) { + if (use_filter_2D && filter_2D_value) { + continue; + } + if (use_filter_framed && filter_framed_value) { + continue; + } + if (use_filter_rank_type && filter_rank_type_value != rank_type) { + continue; + } + TestWindowRankVariant(rank_type, + /*use_frames=*/false, + /*ignored*/ false); + } + + // Framed rank + // + for (auto use_2D : {false, true}) { + for (auto rank_type : {RankType::ROW_NUMBER, RankType::RANK_TIES_LOW, + RankType::RANK_TIES_HIGH, RankType::DENSE_RANK}) { + if (use_filter_framed && !filter_framed_value) { + continue; + } + if (use_filter_rank_type && filter_rank_type_value != rank_type) { + continue; + } + if (use_filter_2D && filter_2D_value != use_2D) { + continue; + } + if (rank_type == RankType::DENSE_RANK && use_2D) { + continue; + } + TestWindowRankVariant(rank_type, /*use_frames=*/true, use_2D); + } + } +} + +} // namespace compute +} // namespace arrow