From ad77e21fff9004dc3667ed17625e2e53e98cc34f Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 15 Feb 2022 12:51:13 -0800 Subject: [PATCH 1/4] Add failing test --- r/tests/testthat/test-dplyr-join.R | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 319aa8cf98d..2f6d08e8499 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -269,3 +269,29 @@ test_that("arrow dplyr query can join with tibble", { } ) }) + + +test_that("arrow dplyr query can join on partition column", { + # ARROW-14908 + dir_out <- tempdir() + + quakes %>% + select(stations, lat, long) %>% + group_by(stations) %>% + write_dataset(file.path(dir_out, "ds1")) + + quakes |> + select(stations, mag, depth) %>% + group_by(stations) %>% + write_dataset(file.path(dir_out, "ds2")) + + withr::with_options( + list(arrow.use_threads = FALSE), + { + open_dataset(file.path(dir_out, "ds1")) |> + left_join(open_dataset(file.path(dir_out, "ds2")), by = "stations") |> + collect() # We should not segfault here. + expect_equal(nrow(res), 150) + } + ) +}) From 3a99d613577b8728cb6b2e9abe15521375e0478c Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 15 Feb 2022 13:42:35 -0800 Subject: [PATCH 2/4] Fix test --- r/tests/testthat/test-dplyr-join.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 2f6d08e8499..28e3a66d6a3 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -288,10 +288,10 @@ test_that("arrow dplyr query can join on partition column", { withr::with_options( list(arrow.use_threads = FALSE), { - open_dataset(file.path(dir_out, "ds1")) |> - left_join(open_dataset(file.path(dir_out, "ds2")), by = "stations") |> + res <- open_dataset(file.path(dir_out, "ds1")) %>% + left_join(open_dataset(file.path(dir_out, "ds2")), by = "stations") %>% collect() # We should not segfault here. - expect_equal(nrow(res), 150) + expect_equal(nrow(res), 21872) } ) }) From 904682a652c099061758d588b344204f85b28f0e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 15 Feb 2022 13:42:52 -0800 Subject: [PATCH 3/4] Make hash join resize local_state_ as needed --- cpp/src/arrow/compute/exec/hash_join.cc | 27 ++++++++++++++------ cpp/src/arrow/compute/exec/hash_join_dict.cc | 18 ++++++++++--- cpp/src/arrow/compute/exec/hash_join_dict.h | 1 + 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 5a9afaa5bdf..2e3c2da9a92 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -103,7 +103,7 @@ class HashJoinBasicImpl : public HashJoinImpl { filter_ = std::move(filter); output_batch_callback_ = std::move(output_batch_callback); finished_callback_ = std::move(finished_callback); - local_states_.resize(num_threads + 1); // +1 for calling thread + worker threads + local_states_.resize(num_threads); for (size_t i = 0; i < local_states_.size(); ++i) { local_states_[i].is_initialized = false; local_states_[i].is_has_match_initialized = false; @@ -152,7 +152,7 @@ class HashJoinBasicImpl : public HashJoinImpl { void InitLocalStateIfNeeded(size_t thread_index) { DCHECK_LT(thread_index, local_states_.size()); - ThreadLocalState& local_state = local_states_[thread_index]; + ThreadLocalState& local_state = GetLocalState(thread_index); if (!local_state.is_initialized) { InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys); bool has_payload = @@ -429,7 +429,7 @@ class HashJoinBasicImpl : public HashJoinImpl { has_right && (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0); - ThreadLocalState& local_state = local_states_[thread_index]; + ThreadLocalState& local_state = GetLocalState(thread_index); InitLocalStateIfNeeded(thread_index); ExecBatch left_key; @@ -541,7 +541,7 @@ class HashJoinBasicImpl : public HashJoinImpl { } Status ProbeBatch(size_t thread_index, const ExecBatch& batch) { - ThreadLocalState& local_state = local_states_[thread_index]; + ThreadLocalState& local_state = GetLocalState(thread_index); InitLocalStateIfNeeded(thread_index); local_state.exec_batch_keys.Clear(); @@ -748,7 +748,7 @@ class HashJoinBasicImpl : public HashJoinImpl { static_cast(std::min(static_cast(hash_table_keys_.num_rows()), hash_table_scan_unit_ * (task_id + 1))); - ThreadLocalState& local_state = local_states_[thread_index]; + ThreadLocalState& local_state = GetLocalState(thread_index); InitLocalStateIfNeeded(thread_index); std::vector& id_left = local_state.no_match; @@ -850,13 +850,13 @@ class HashJoinBasicImpl : public HashJoinImpl { memset(has_match_.data(), 0, bit_util::BytesForBits(num_rows)); for (size_t tid = 0; tid < local_states_.size(); ++tid) { - if (!local_states_[tid].is_initialized) { + if (!GetLocalState(tid).is_initialized) { continue; } - if (!local_states_[tid].is_has_match_initialized) { + if (!GetLocalState(tid).is_has_match_initialized) { continue; } - arrow::internal::BitmapOr(has_match_.data(), 0, local_states_[tid].has_match.data(), + arrow::internal::BitmapOr(has_match_.data(), 0, GetLocalState(tid).has_match.data(), 0, num_rows, 0, has_match_.data()); } } @@ -896,6 +896,17 @@ class HashJoinBasicImpl : public HashJoinImpl { std::vector has_match; }; std::vector local_states_; + ThreadLocalState& GetLocalState(size_t thread_index) { + if (ARROW_PREDICT_FALSE(thread_index >= local_states_.size())) { + size_t old_size = local_states_.size(); + local_states_.resize(thread_index + 1); + for (size_t i = old_size; i < local_states_.size(); ++i) { + local_states_[i].is_initialized = false; + local_states_[i].is_has_match_initialized = false; + } + } + return local_states_[thread_index]; + } // Shared runtime state // diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.cc b/cpp/src/arrow/compute/exec/hash_join_dict.cc index ac1fbbaa3df..8139b963ae3 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.cc +++ b/cpp/src/arrow/compute/exec/hash_join_dict.cc @@ -566,7 +566,7 @@ Status HashJoinDictBuildMulti::PostDecode( } void HashJoinDictProbeMulti::Init(size_t num_threads) { - local_states_.resize(num_threads + 1); // +1 for calling thread + worker threads + local_states_.resize(num_threads); for (size_t i = 0; i < local_states_.size(); ++i) { local_states_[i].is_initialized = false; } @@ -576,13 +576,13 @@ bool HashJoinDictProbeMulti::BatchRemapNeeded( size_t thread_index, const SchemaProjectionMaps& proj_map_probe, const SchemaProjectionMaps& proj_map_build, ExecContext* ctx) { InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx); - return local_states_[thread_index].any_needs_remap; + return GetLocalState(thread_index).any_needs_remap; } void HashJoinDictProbeMulti::InitLocalStateIfNeeded( size_t thread_index, const SchemaProjectionMaps& proj_map_probe, const SchemaProjectionMaps& proj_map_build, ExecContext* ctx) { - ThreadLocalState& local_state = local_states_[thread_index]; + ThreadLocalState& local_state = GetLocalState(thread_index); // Check if we need to remap any of the input keys because of dictionary encoding // on either side of the join @@ -661,5 +661,17 @@ Status HashJoinDictProbeMulti::EncodeBatch( return Status::OK(); } +HashJoinDictProbeMulti::ThreadLocalState& HashJoinDictProbeMulti::GetLocalState( + size_t thread_index) { + if (ARROW_PREDICT_FALSE(thread_index >= local_states_.size())) { + size_t old_size = local_states_.size(); + local_states_.resize(thread_index + 1); + for (size_t i = old_size; i < local_states_.size(); ++i) { + local_states_[i].is_initialized = false; + } + } + return local_states_[thread_index]; +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.h b/cpp/src/arrow/compute/exec/hash_join_dict.h index 26605cc449a..e2aed375226 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.h +++ b/cpp/src/arrow/compute/exec/hash_join_dict.h @@ -309,6 +309,7 @@ class HashJoinDictProbeMulti { RowEncoder post_remap_encoder; }; std::vector local_states_; + ThreadLocalState& GetLocalState(size_t thread_index); }; } // namespace compute From 1cea000ef79d262cbfe0a6fb608064dcb26c74a0 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 15 Feb 2022 14:06:01 -0800 Subject: [PATCH 4/4] Fix last pipe --- r/tests/testthat/test-dplyr-join.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 28e3a66d6a3..975d22f7bd9 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -280,7 +280,7 @@ test_that("arrow dplyr query can join on partition column", { group_by(stations) %>% write_dataset(file.path(dir_out, "ds1")) - quakes |> + quakes %>% select(stations, mag, depth) %>% group_by(stations) %>% write_dataset(file.path(dir_out, "ds2"))