Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
finished_callback_ = std::move(finished_callback);
local_states_.resize(num_threads + 1); // +1 for calling thread + worker threads
local_states_.resize(num_threads);
for (size_t i = 0; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
local_states_[i].is_has_match_initialized = false;
Expand Down Expand Up @@ -152,7 +152,7 @@ class HashJoinBasicImpl : public HashJoinImpl {

void InitLocalStateIfNeeded(size_t thread_index) {
DCHECK_LT(thread_index, local_states_.size());
ThreadLocalState& local_state = local_states_[thread_index];
ThreadLocalState& local_state = GetLocalState(thread_index);
if (!local_state.is_initialized) {
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
bool has_payload =
Expand Down Expand Up @@ -429,7 +429,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
has_right &&
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);

ThreadLocalState& local_state = local_states_[thread_index];
ThreadLocalState& local_state = GetLocalState(thread_index);
InitLocalStateIfNeeded(thread_index);

ExecBatch left_key;
Expand Down Expand Up @@ -541,7 +541,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

Status ProbeBatch(size_t thread_index, const ExecBatch& batch) {
ThreadLocalState& local_state = local_states_[thread_index];
ThreadLocalState& local_state = GetLocalState(thread_index);
InitLocalStateIfNeeded(thread_index);

local_state.exec_batch_keys.Clear();
Expand Down Expand Up @@ -748,7 +748,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
static_cast<int32_t>(std::min(static_cast<int64_t>(hash_table_keys_.num_rows()),
hash_table_scan_unit_ * (task_id + 1)));

ThreadLocalState& local_state = local_states_[thread_index];
ThreadLocalState& local_state = GetLocalState(thread_index);
InitLocalStateIfNeeded(thread_index);

std::vector<int32_t>& id_left = local_state.no_match;
Expand Down Expand Up @@ -850,13 +850,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
memset(has_match_.data(), 0, bit_util::BytesForBits(num_rows));

for (size_t tid = 0; tid < local_states_.size(); ++tid) {
if (!local_states_[tid].is_initialized) {
if (!GetLocalState(tid).is_initialized) {
continue;
}
if (!local_states_[tid].is_has_match_initialized) {
if (!GetLocalState(tid).is_has_match_initialized) {
continue;
}
arrow::internal::BitmapOr(has_match_.data(), 0, local_states_[tid].has_match.data(),
arrow::internal::BitmapOr(has_match_.data(), 0, GetLocalState(tid).has_match.data(),
0, num_rows, 0, has_match_.data());
}
}
Expand Down Expand Up @@ -896,6 +896,17 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::vector<uint8_t> has_match;
};
std::vector<ThreadLocalState> local_states_;
// Returns the thread-local state slot for `thread_index`, growing
// `local_states_` on demand when an index beyond the current size appears
// (e.g. a task running on the calling thread or on an IO-pool thread that
// was not counted in the static sizing done at Init time).
// FIXME(review): the lazy resize below is NOT thread-safe — per the PR
// discussion, std::vector::resize can reallocate and invalidate references
// held concurrently by other threads; this still needs synchronization.
ThreadLocalState& GetLocalState(size_t thread_index) {
if (ARROW_PREDICT_FALSE(thread_index >= local_states_.size())) {
// Grow to cover thread_index and flag the newly created slots as
// uninitialized so InitLocalStateIfNeeded() sets them up on first use.
size_t old_size = local_states_.size();
local_states_.resize(thread_index + 1);
for (size_t i = old_size; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
local_states_[i].is_has_match_initialized = false;
}
}
return local_states_[thread_index];
}
Comment on lines +899 to +909
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here's my suggestion for an alternative approach: Instead of trusting that the exact correct number of threads has been created (since that seems hard), gracefully resize the local state vectors as needed. I think there's a few more places I'd need to add this logic (we might even need this in the indexer; see my earlier comment about the occasional failure).

What do you think @westonpace?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this thread-unsafe? You're resizing a vector while it could be accessed by other threads concurrently?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, that needs to be fixed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious: given that during Init the vector is sized to the number of threads, when does it happen that GetLocalState is invoked with a thread index outside of the already allocated ones? I thought we were using thread pools and thus the number of threads was stable. Are we recycling them or something like that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thread pools don't include the main thread, for instance. Also it's sized to the CPU thread pool, but something might 'leak' from the IO thread pool.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IO threads and CPU threads are separate pools, and if we don't pass an executor the source node runs the downstream nodes on the IO thread:

outputs_[0]->InputReceived(this, std::move(batch));

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And by the way, the user is allowed to change thread pool capacity at runtime, so static sizing will never be correct even in the simple case of a single thread pool.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sizing is recomputed for each new exec plan so the failure would only occur on plans that were running when the thread pool was resized. I am planning on taking a look at a better implementation for use_threads=FALSE on Friday using a serial executor which will ensure that an exec plan always has an executor and exec plan steps are always run on a thread belonging to that executor. This will solve all but the issue Antoine mentioned.

That being said, I think your solution is reasonable. I'll have to ping @bkietz and @michalursa as they were the original proponents of the statically sized thread states. I don't know if that was based on speculation, existing literature, or actual benchmark measurements however.


// Shared runtime state
//
Expand Down
18 changes: 15 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ Status HashJoinDictBuildMulti::PostDecode(
}

void HashJoinDictProbeMulti::Init(size_t num_threads) {
local_states_.resize(num_threads + 1); // +1 for calling thread + worker threads
local_states_.resize(num_threads);
for (size_t i = 0; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
}
Expand All @@ -576,13 +576,13 @@ bool HashJoinDictProbeMulti::BatchRemapNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
return local_states_[thread_index].any_needs_remap;
return GetLocalState(thread_index).any_needs_remap;
}

void HashJoinDictProbeMulti::InitLocalStateIfNeeded(
size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
ThreadLocalState& local_state = local_states_[thread_index];
ThreadLocalState& local_state = GetLocalState(thread_index);

// Check if we need to remap any of the input keys because of dictionary encoding
// on either side of the join
Expand Down Expand Up @@ -661,5 +661,17 @@ Status HashJoinDictProbeMulti::EncodeBatch(
return Status::OK();
}

// Returns the thread-local probe state for `thread_index`, growing
// `local_states_` on demand when an index beyond the size chosen at Init
// time appears (e.g. the calling thread or an IO-pool thread).
// FIXME(review): the lazy resize below is NOT thread-safe — per the PR
// discussion, std::vector::resize can reallocate and invalidate references
// held concurrently by other threads; this still needs synchronization.
HashJoinDictProbeMulti::ThreadLocalState& HashJoinDictProbeMulti::GetLocalState(
size_t thread_index) {
if (ARROW_PREDICT_FALSE(thread_index >= local_states_.size())) {
// Grow to cover thread_index and flag the new slots as uninitialized so
// InitLocalStateIfNeeded() sets them up on first use.
size_t old_size = local_states_.size();
local_states_.resize(thread_index + 1);
for (size_t i = old_size; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
}
}
return local_states_[thread_index];
}

} // namespace compute
} // namespace arrow
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/exec/hash_join_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ class HashJoinDictProbeMulti {
RowEncoder post_remap_encoder;
};
std::vector<ThreadLocalState> local_states_;
ThreadLocalState& GetLocalState(size_t thread_index);
};

} // namespace compute
Expand Down
26 changes: 26 additions & 0 deletions r/tests/testthat/test-dplyr-join.R
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,29 @@ test_that("arrow dplyr query can join with tibble", {
}
)
})


test_that("arrow dplyr query can join on partition column", {
  # ARROW-14908: joining two datasets on a partition column with
  # arrow.use_threads = FALSE used to segfault (per-thread join state was
  # statically sized and could be indexed out of bounds).
  #
  # Use a fresh unique directory instead of the shared session tempdir():
  # tempdir() is reused across tests in the same session, so a rerun would
  # otherwise pick up the ds1/ds2 datasets written by a previous run and the
  # row-count assertion below could spuriously fail. Clean up on exit.
  dir_out <- tempfile()
  dir.create(dir_out)
  on.exit(unlink(dir_out, recursive = TRUE))

  quakes %>%
    select(stations, lat, long) %>%
    group_by(stations) %>%
    write_dataset(file.path(dir_out, "ds1"))

  quakes %>%
    select(stations, mag, depth) %>%
    group_by(stations) %>%
    write_dataset(file.path(dir_out, "ds2"))

  withr::with_options(
    list(arrow.use_threads = FALSE),
    {
      res <- open_dataset(file.path(dir_out, "ds1")) %>%
        left_join(open_dataset(file.path(dir_out, "ds2")), by = "stations") %>%
        collect() # We should not segfault here.
      expect_equal(nrow(res), 21872)
    }
  )
})