From 58a648b4e500750ffa1fe97abd672c3d4280c954 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 9 Jul 2023 08:10:41 -0400 Subject: [PATCH] GH-36482: [C++][CI] Fix sporadic test failures in AsofJoinBasicTest --- cpp/src/arrow/acero/asof_join_node.cc | 64 ++++++++++++++++++--------- 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index 98e5918ebbf..ed41e15b8a4 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -456,12 +456,18 @@ struct MemoStore { } }; +struct NumberedRecordBatch { + int index; + std::shared_ptr batch; +}; + // a specialized higher-performance variation of Hashing64 logic from hash_join_node // the code here avoids recreating objects that are independent of each batch processed class KeyHasher { friend class AsofJoinNode; static constexpr int kMiniBatchLength = arrow::util::MiniBatch::kMiniBatchLength; + static constexpr int kNoCachingKey = -1; public: // the key hasher is not thread-safe and is only used in sequential batch processing @@ -470,7 +476,7 @@ class KeyHasher { : index_(index), indices_(indices), metadata_(indices.size()), - batch_(NULLPTR), + batch_index_(kNoCachingKey), hashes_(), ctx_(), column_arrays_(), @@ -490,16 +496,19 @@ class KeyHasher { 4 * kMiniBatchLength * sizeof(uint32_t)); } - // invalidate cached hashes for batch - required when it changes - // only this method can be called concurrently with HashesFor - void Invalidate() { batch_ = NULLPTR; } + // get the batch index - the current caching key - or a negative if the cache is invalid + int GetBatchIndex() { return batch_index_; } + + // invalidate the cache + void Invalidate() { batch_index_ = kNoCachingKey; } // compute and cache a hash for each row of the given batch - const std::vector& HashesFor(const RecordBatch* batch) { - if (batch_ == batch) { + const std::vector& HashesFor(const NumberedRecordBatch& numbered_batch) { + if (batch_index_ == numbered_batch.index) { return hashes_; // cache hit - return cached hashes } - Invalidate(); + batch_index_ = kNoCachingKey; + const auto& batch = numbered_batch.batch; size_t batch_length = batch->num_rows(); hashes_.resize(batch_length); for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { @@ -515,7 +524,7 @@ class KeyHasher { } DEBUG_SYNC(node_, "key hasher ", index_, " got hashes ", compute::internal::GenericToString(hashes_), DEBUG_MANIP(std::endl)); - batch_ = batch; // associate cache with current batch + batch_index_ = numbered_batch.index; // associate cache with current batch index return hashes_; } @@ -524,7 +533,7 @@ class KeyHasher { size_t index_; std::vector indices_; std::vector metadata_; - const RecordBatch* batch_; + int batch_index_; std::vector hashes_; LightContext ctx_; std::vector column_arrays_; @@ -722,9 +731,14 @@ class InputState { int total_batches() const { return total_batches_; } + // Gets latest numbered batch (precondition: must not be empty) + const NumberedRecordBatch& GetLatestNumberedBatch() const { + return queue_.UnsyncFront(); + } + // Gets latest batch (precondition: must not be empty) const std::shared_ptr& GetLatestBatch() const { - return queue_.UnsyncFront(); + return GetLatestNumberedBatch().batch; } #define LATEST_VAL_CASE(id, val) \ @@ -735,20 +749,20 @@ class InputState { } inline ByType GetLatestKey() const { - return GetKey(GetLatestBatch().get(), latest_ref_row_); + return GetKey(GetLatestNumberedBatch(), latest_ref_row_); } - inline ByType GetKey(const RecordBatch* batch, row_index_t row) const { + inline ByType GetKey(const NumberedRecordBatch& numbered_batch, row_index_t row) const { if (must_hash_) { // Query the key hasher. This may hit cache, which must be valid for the batch. // Therefore, the key hasher is invalidated when a new batch is pushed - see // `InputState::Push`. - return key_hasher_->HashesFor(batch)[row]; + return key_hasher_->HashesFor(numbered_batch)[row]; } if (key_col_index_.size() == 0) { return 0; } - auto data = batch->column_data(key_col_index_[0]); + auto data = numbered_batch.batch->column_data(key_col_index_[0]); switch (key_type_id_[0]) { LATEST_VAL_CASE(INT8, key_value) LATEST_VAL_CASE(INT16, key_value) @@ -812,15 +826,15 @@ class InputState { } latest_time_ = next_time; // If we have an active batch - if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) { + if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront().batch->num_rows()) { // hit the end of the batch, need to get the next batch if possible. ++batches_processed_; latest_ref_row_ = 0; have_active_batch &= !queue_.TryPop(); if (have_active_batch) { - DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed - key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache - memo_.UpdateTime(GetTime(queue_.UnsyncFront().get(), 0)); // time changed + DCHECK_GT(queue_.UnsyncFront().batch->num_rows(), + 0); // empty batches disallowed + memo_.UpdateTime(GetTime(queue_.UnsyncFront().batch.get(), 0)); // time changed } } } @@ -845,6 +859,10 @@ class InputState { bool advanced, updated = false; OnType latest_time; do { + // must check for cache invalidation before invoking GetLatestKey + if (key_hasher_->GetBatchIndex() != queue_.UnsyncFront().index) { + key_hasher_->Invalidate(); + } latest_time = GetLatestTime(); // if Advance() returns true, then the latest_ts must also be valid // Keep advancing right table until we hit the latest row that has @@ -879,11 +897,12 @@ class InputState { void Rehash() { DEBUG_SYNC(node_, "rehashing for input ", index_, ":", DEBUG_MANIP(std::endl)); + int batch_index = key_hasher_->GetBatchIndex(); MemoStore new_memo(DEBUG_ADD(memo_.no_future_, node_, index_)); new_memo.current_time_ = (OnType)memo_.current_time_; for (auto e = memo_.entries_.begin(); e != memo_.entries_.end(); ++e) { auto& entry = e->second; - auto new_key = GetKey(entry.batch.get(), entry.row); + auto new_key = GetKey({batch_index, entry.batch}, entry.row); DEBUG_SYNC(node_, " ", e->first, " to ", new_key, DEBUG_MANIP(std::endl)); new_memo.entries_[new_key].swap(entry); auto fe = memo_.future_entries_.find(e->first); @@ -897,10 +916,11 @@ class InputState { Status Push(const std::shared_ptr& rb) { if (rb->num_rows() > 0) { - queue_.Push(rb); // only after above updates - push batch for processing + queue_.Push({batch_index_, std::move(rb)}); } else { ++batches_processed_; // don't enqueue empty batches, just record as processed } + ++batch_index_; return Status::OK(); } @@ -931,7 +951,9 @@ class InputState { private: // Pending record batches. The latest is the front. Batches cannot be empty. - BackpressureConcurrentQueue> queue_; + BackpressureConcurrentQueue queue_; + // Number to assign to next batch + int batch_index_ = 0; // Schema associated with the input std::shared_ptr schema_; // Total number of batches (only int because InputFinished uses int)