64 changes: 43 additions & 21 deletions cpp/src/arrow/acero/asof_join_node.cc
@@ -456,12 +456,18 @@ struct MemoStore {
}
};

struct NumberedRecordBatch {
int index;
std::shared_ptr<RecordBatch> batch;
};

// a specialized higher-performance variation of Hashing64 logic from hash_join_node
// the code here avoids recreating objects that are independent of each batch processed
class KeyHasher {
friend class AsofJoinNode;

static constexpr int kMiniBatchLength = arrow::util::MiniBatch::kMiniBatchLength;
static constexpr int kNoCachingKey = -1;

public:
// the key hasher is not thread-safe and is only used in sequential batch processing
@@ -470,7 +476,7 @@ class KeyHasher {
: index_(index),
indices_(indices),
metadata_(indices.size()),
batch_(NULLPTR),
batch_index_(kNoCachingKey),
hashes_(),
ctx_(),
column_arrays_(),
@@ -490,16 +496,19 @@
4 * kMiniBatchLength * sizeof(uint32_t));
}

// invalidate cached hashes for batch - required when it changes
// only this method can be called concurrently with HashesFor
void Invalidate() { batch_ = NULLPTR; }
// get the batch index - the current caching key - or a negative if the cache is invalid
int GetBatchIndex() { return batch_index_; }

// invalidate the cache
void Invalidate() { batch_index_ = kNoCachingKey; }

// compute and cache a hash for each row of the given batch
const std::vector<HashType>& HashesFor(const RecordBatch* batch) {
if (batch_ == batch) {
const std::vector<HashType>& HashesFor(const NumberedRecordBatch& numbered_batch) {
if (batch_index_ == numbered_batch.index) {
return hashes_; // cache hit - return cached hashes
}
Invalidate();
batch_index_ = kNoCachingKey;
const auto& batch = numbered_batch.batch;
size_t batch_length = batch->num_rows();
hashes_.resize(batch_length);
for (int64_t i = 0; i < static_cast<int64_t>(batch_length); i += kMiniBatchLength) {
@@ -515,7 +524,7 @@
}
DEBUG_SYNC(node_, "key hasher ", index_, " got hashes ",
compute::internal::GenericToString(hashes_), DEBUG_MANIP(std::endl));
batch_ = batch; // associate cache with current batch
batch_index_ = numbered_batch.index; // associate cache with current batch index
return hashes_;
}

@@ -524,7 +533,7 @@
size_t index_;
std::vector<col_index_t> indices_;
std::vector<KeyColumnMetadata> metadata_;
const RecordBatch* batch_;
int batch_index_;
std::vector<HashType> hashes_;
LightContext ctx_;
std::vector<KeyColumnArray> column_arrays_;
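
Editorial note on the caching contract above: keying the cache by the monotonically increasing batch index, rather than by the raw RecordBatch pointer, means a recycled pointer value can never produce a stale cache hit. A simplified, hypothetical sketch of that contract (ComputeHashes stands in for the mini-batch hashing loop in HashesFor and is not a real helper in this file):

// Illustrative sketch only - not the actual KeyHasher class from this PR.
struct HashCacheSketch {
  int batch_index = -1;                // kNoCachingKey: cache starts invalid
  std::vector<HashType> hashes;

  const std::vector<HashType>& HashesFor(const NumberedRecordBatch& nb) {
    if (batch_index == nb.index) return hashes;  // hit: same batch index
    batch_index = -1;                            // invalid while recomputing
    hashes = ComputeHashes(*nb.batch);           // hypothetical hashing helper
    batch_index = nb.index;                      // re-key the cache to this batch
    return hashes;
  }
};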
@@ -722,9 +731,14 @@ class InputState {

int total_batches() const { return total_batches_; }

// Gets latest numbered batch (precondition: must not be empty)
const NumberedRecordBatch& GetLatestNumberedBatch() const {
return queue_.UnsyncFront();
}

// Gets latest batch (precondition: must not be empty)
const std::shared_ptr<arrow::RecordBatch>& GetLatestBatch() const {
return queue_.UnsyncFront();
return GetLatestNumberedBatch().batch;
}

#define LATEST_VAL_CASE(id, val) \
@@ -735,20 +749,20 @@
}

inline ByType GetLatestKey() const {
return GetKey(GetLatestBatch().get(), latest_ref_row_);
return GetKey(GetLatestNumberedBatch(), latest_ref_row_);
}

inline ByType GetKey(const RecordBatch* batch, row_index_t row) const {
inline ByType GetKey(const NumberedRecordBatch& numbered_batch, row_index_t row) const {
if (must_hash_) {
// Query the key hasher. This may hit cache, which must be valid for the batch.
// Therefore, the key hasher is invalidated when a new batch is pushed - see
// `InputState::Push`.
return key_hasher_->HashesFor(batch)[row];
return key_hasher_->HashesFor(numbered_batch)[row];
}
if (key_col_index_.size() == 0) {
return 0;
}
auto data = batch->column_data(key_col_index_[0]);
auto data = numbered_batch.batch->column_data(key_col_index_[0]);
switch (key_type_id_[0]) {
LATEST_VAL_CASE(INT8, key_value)
LATEST_VAL_CASE(INT16, key_value)
@@ -812,15 +826,15 @@ class InputState {
}
latest_time_ = next_time;
// If we have an active batch
if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) {
if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront().batch->num_rows()) {
// hit the end of the batch, need to get the next batch if possible.
++batches_processed_;
latest_ref_row_ = 0;
have_active_batch &= !queue_.TryPop();
if (have_active_batch) {
DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed
key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache
memo_.UpdateTime(GetTime(queue_.UnsyncFront().get(), 0)); // time changed
DCHECK_GT(queue_.UnsyncFront().batch->num_rows(),
0); // empty batches disallowed
memo_.UpdateTime(GetTime(queue_.UnsyncFront().batch.get(), 0)); // time changed
Member: Are you sure that queue_.UnsyncFront() here is the same as just above?

Contributor Author: Yes, because only the processing thread, which runs this code, is popping from the (front of the) queue.
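
To make that answer concrete, a minimal illustrative sketch (not part of this PR; queue_ and the types are taken from the surrounding code) of the single-consumer invariant being relied on:

// Illustrative only: the process thread is the sole consumer of queue_, so
// with no TryPop() in between, two UnsyncFront() calls on that thread observe
// the same front element even while producer threads push to the back.
const NumberedRecordBatch& first = queue_.UnsyncFront();
// ... work on this thread that does not pop from queue_ ...
const NumberedRecordBatch& again = queue_.UnsyncFront();
// `first` and `again` refer to the same queued batch, so values read through
// either call (time, key, batch index) all come from the same batch.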

}
}
}
@@ -845,6 +859,10 @@ class InputState {
bool advanced, updated = false;
OnType latest_time;
do {
// must check for cache invalidation before invoking GetLatestKey
if (key_hasher_->GetBatchIndex() != queue_.UnsyncFront().index) {
key_hasher_->Invalidate();
}
latest_time = GetLatestTime();
// if Advance() returns true, then the latest_ts must also be valid
// Keep advancing right table until we hit the latest row that has
@@ -879,11 +897,12 @@ class InputState {

void Rehash() {
DEBUG_SYNC(node_, "rehashing for input ", index_, ":", DEBUG_MANIP(std::endl));
int batch_index = key_hasher_->GetBatchIndex();
MemoStore new_memo(DEBUG_ADD(memo_.no_future_, node_, index_));
new_memo.current_time_ = (OnType)memo_.current_time_;
for (auto e = memo_.entries_.begin(); e != memo_.entries_.end(); ++e) {
auto& entry = e->second;
auto new_key = GetKey(entry.batch.get(), entry.row);
auto new_key = GetKey({batch_index, entry.batch}, entry.row);
DEBUG_SYNC(node_, " ", e->first, " to ", new_key, DEBUG_MANIP(std::endl));
new_memo.entries_[new_key].swap(entry);
auto fe = memo_.future_entries_.find(e->first);
@@ -897,10 +916,11 @@ class InputState {

Status Push(const std::shared_ptr<arrow::RecordBatch>& rb) {
if (rb->num_rows() > 0) {
queue_.Push(rb); // only after above updates - push batch for processing
queue_.Push({batch_index_, std::move(rb)});
} else {
++batches_processed_; // don't enqueue empty batches, just record as processed
}
++batch_index_;
return Status::OK();
}

@@ -931,7 +951,9 @@ class InputState {

private:
// Pending record batches. The latest is the front. Batches cannot be empty.
BackpressureConcurrentQueue<std::shared_ptr<RecordBatch>> queue_;
BackpressureConcurrentQueue<NumberedRecordBatch> queue_;
// Number to assign to next batch
int batch_index_ = 0;
// Schema associated with the input
std::shared_ptr<Schema> schema_;
// Total number of batches (only int because InputFinished uses int)