Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ void RowArray::DebugPrintToFile(const char* filename, bool print_sorted) const {

Status RowArrayMerge::PrepareForMerge(RowArray* target,
const std::vector<RowArray*>& sources,
bool is_key_data,
std::vector<int64_t>* first_target_row_id,
MemoryPool* pool) {
ARROW_DCHECK(!sources.empty());
Expand Down Expand Up @@ -473,7 +474,7 @@ Status RowArrayMerge::PrepareForMerge(RowArray* target,
(*first_target_row_id)[sources.size()] = num_rows;
}

if (num_bytes > std::numeric_limits<uint32_t>::max()) {
if (is_key_data && num_bytes > std::numeric_limits<uint32_t>::max()) {
Copy link
Contributor

@zanmato1984 zanmato1984 Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If non-key data larger than std::numeric_limits<uint32_t>::max() is allowed to bypass this check, num_bytes will be truncated (wrap around) to a much smaller value by the static_cast<uint32_t> at line 486. The target will then not allocate enough space, resulting in a segfault when data is copied into the target. The original check is necessary, and unfortunately there is nothing to loosen here.

return Status::Invalid(
"There are more than 2^32 bytes of key data. Acero cannot "
"process a join of this magnitude");
Expand Down Expand Up @@ -1330,7 +1331,8 @@ Status SwissTableForJoinBuild::PreparePrtnMerge() {
for (int i = 0; i < num_prtns_; ++i) {
partition_keys[i] = prtn_states_[i].keys.keys();
}
RETURN_NOT_OK(RowArrayMerge::PrepareForMerge(target_->map_.keys(), partition_keys,

RETURN_NOT_OK(RowArrayMerge::PrepareForMerge(target_->map_.keys(), partition_keys, true,
&partition_keys_first_row_id_, pool_));

// 2. SwissTable:
Expand All @@ -1353,7 +1355,7 @@ Status SwissTableForJoinBuild::PreparePrtnMerge() {
partition_payloads[i] = &prtn_states_[i].payloads;
}
RETURN_NOT_OK(RowArrayMerge::PrepareForMerge(&target_->payloads_, partition_payloads,
&partition_payloads_first_row_id_,
false, &partition_payloads_first_row_id_,
pool_));
}

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class RowArrayMerge {
// caller can pass in nullptr to indicate that it is not needed.
//
static Status PrepareForMerge(RowArray* target, const std::vector<RowArray*>& sources,
bool is_key_data,
std::vector<int64_t>* first_target_row_id,
MemoryPool* pool);

Expand Down