6 changes: 4 additions & 2 deletions cpp/src/arrow/compute/exec/hash_join.cc
@@ -103,12 +103,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
     filter_ = std::move(filter);
     output_batch_callback_ = std::move(output_batch_callback);
     finished_callback_ = std::move(finished_callback);
-    local_states_.resize(num_threads + 1);  // +1 for calling thread + worker threads
+    // TODO(ARROW-15732)
+    // Each side of the join may be called from its own IO thread.
+    local_states_.resize(GetCpuThreadPoolCapacity() + io::GetIOThreadPoolCapacity() + 1);
     for (size_t i = 0; i < local_states_.size(); ++i) {
       local_states_[i].is_initialized = false;
       local_states_[i].is_has_match_initialized = false;
     }
-    dict_probe_.Init(num_threads);
+    dict_probe_.Init(GetCpuThreadPoolCapacity() + io::GetIOThreadPoolCapacity() + 1);
 
     has_hash_table_ = false;
     num_batches_produced_.store(0);
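The resize above is the heart of the fix: the per-thread scratch state needs one slot for every thread that can execute join code, which means every CPU pool thread, every IO pool thread, and the calling thread itself (the trailing + 1). A minimal sketch of that sizing rule, assuming Arrow's GetCpuThreadPoolCapacity() and io::GetIOThreadPoolCapacity(); LocalState and the helper names here are hypothetical, not the PR's actual code:

#include <cstddef>
#include <vector>

#include "arrow/io/interfaces.h"     // arrow::io::GetIOThreadPoolCapacity()
#include "arrow/util/thread_pool.h"  // arrow::GetCpuThreadPoolCapacity()

// Hypothetical per-thread scratch state, mirroring the pattern in the diff.
struct LocalState {
  bool is_initialized = false;
};

// One slot per thread that may run join code: every CPU pool thread,
// every IO pool thread, and the calling thread itself (the trailing + 1).
size_t MaxThreadsTouchingJoin() {
  return static_cast<size_t>(arrow::GetCpuThreadPoolCapacity()) +
         static_cast<size_t>(arrow::io::GetIOThreadPoolCapacity()) + 1;
}

std::vector<LocalState> MakeLocalStates() {
  return std::vector<LocalState>(MaxThreadsTouchingJoin());
}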
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/exec/hash_join_dict.cc
@@ -566,7 +566,7 @@ Status HashJoinDictBuildMulti::PostDecode(
 }
 
 void HashJoinDictProbeMulti::Init(size_t num_threads) {
-  local_states_.resize(num_threads + 1);  // +1 for calling thread + worker threads
+  local_states_.resize(num_threads);
   for (size_t i = 0; i < local_states_.size(); ++i) {
     local_states_[i].is_initialized = false;
   }
@@ -576,6 +576,7 @@ bool HashJoinDictProbeMulti::BatchRemapNeeded(
     size_t thread_index, const SchemaProjectionMaps<HashJoinProjection>& proj_map_probe,
     const SchemaProjectionMaps<HashJoinProjection>& proj_map_build, ExecContext* ctx) {
   InitLocalStateIfNeeded(thread_index, proj_map_probe, proj_map_build, ctx);
+  DCHECK_LT(thread_index, local_states_.size());
   return local_states_[thread_index].any_needs_remap;
 }

2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/util.cc
@@ -321,7 +321,7 @@ size_t ThreadIndexer::operator()() {
 }
 
 size_t ThreadIndexer::Capacity() {
-  static size_t max_size = arrow::internal::ThreadPool::DefaultCapacity() + 1;
+  static size_t max_size = GetCpuThreadPoolCapacity() + io::GetIOThreadPoolCapacity() + 1;
   return max_size;
 }

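ThreadIndexer::Capacity() must report the same upper bound, because operator() hands out dense per-thread indices that the join uses to address those local-state vectors; if Capacity() undercounts, an IO thread receives an index past the end of the vector. A rough sketch of the dense-indexing pattern (hypothetical, not Arrow's exact implementation):

#include <mutex>
#include <thread>
#include <unordered_map>

// Maps arbitrary thread ids to dense indices in [0, Capacity()).
// Sketch of the ThreadIndexer pattern, under the assumption that the
// caller never runs more distinct threads than Capacity() allows.
class DenseThreadIndexer {
 public:
  size_t operator()() {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = ids_.find(std::this_thread::get_id());
    if (it == ids_.end()) {
      // First call from this thread: assign the next free dense index.
      it = ids_.emplace(std::this_thread::get_id(), ids_.size()).first;
    }
    return it->second;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::thread::id, size_t> ids_;
};

Note the static in the diff: the bound is computed once on first use, so it reflects the pool capacities configured at that point.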
28 changes: 17 additions & 11 deletions r/tests/testthat/test-dplyr-join.R
@@ -310,27 +310,33 @@ test_that("summarize and join", {
   expect_equal(expected_col_names, res_col_names)
 })
 
-test_that("arrow dplyr query can join with tibble", {
-  # ARROW-14908
+test_that("arrow dplyr query can join two datasets", {
+  # ARROW-14908 and ARROW-15718
   skip_if_not_available("dataset")
 
   # By default, snappy encoding will be used, and
   # Snappy has a UBSan issue: https://github.com/google/snappy/pull/148
   skip_on_linux_devel()
-  skip_if_not_available("dataset")
 
   dir_out <- tempdir()
-  write_dataset(iris, file.path(dir_out, "iris"))
-  species_codes <- data.frame(
-    Species = c("setosa", "versicolor", "virginica"),
-    code = c("SET", "VER", "VIR")
-  )
+
+  quakes %>%
+    select(stations, lat, long) %>%
+    group_by(stations) %>%
+    write_dataset(file.path(dir_out, "ds1"))
+
+  quakes %>%
+    select(stations, mag, depth) %>%
+    group_by(stations) %>%
+    write_dataset(file.path(dir_out, "ds2"))
+
   withr::with_options(
     list(arrow.use_threads = FALSE),
     {
-      iris <- open_dataset(file.path(dir_out, "iris"))
-      res <- left_join(iris, species_codes) %>% collect() # We should not segfault here.
-      expect_equal(nrow(res), 150)
+      res <- open_dataset(file.path(dir_out, "ds1")) %>%
+        left_join(open_dataset(file.path(dir_out, "ds2")), by = "stations") %>%
+        collect() # We should not segfault here.
+      expect_equal(nrow(res), 21872)
     }
   )
 })