diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index c3d2bc5417a..0fbc5b62fe1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -36,16 +36,23 @@ namespace { template struct SetLookupState : public KernelState { - explicit SetLookupState(MemoryPool* pool) : lookup_table(pool, 0) {} + explicit SetLookupState(MemoryPool* pool) : memory_pool(pool) {} Status Init(const SetLookupOptions& options) { if (options.value_set.is_array()) { const ArrayData& value_set = *options.value_set.array(); memo_index_to_value_index.reserve(value_set.length); + lookup_table = + MemoTable(memory_pool, + ::arrow::internal::HashTable::kLoadFactor * value_set.length); RETURN_NOT_OK(AddArrayValueSet(options, *options.value_set.array())); } else if (options.value_set.kind() == Datum::CHUNKED_ARRAY) { const ChunkedArray& value_set = *options.value_set.chunked_array(); memo_index_to_value_index.reserve(value_set.length()); + lookup_table = + MemoTable(memory_pool, + ::arrow::internal::HashTable::kLoadFactor * value_set.length()); + int64_t offset = 0; for (const std::shared_ptr& chunk : value_set.chunks()) { RETURN_NOT_OK(AddArrayValueSet(options, *chunk->data(), offset)); @@ -54,8 +61,8 @@ struct SetLookupState : public KernelState { } else { return Status::Invalid("value_set should be an array or chunked array"); } - if (!options.skip_nulls && lookup_table.GetNull() >= 0) { - null_index = memo_index_to_value_index[lookup_table.GetNull()]; + if (!options.skip_nulls && lookup_table->GetNull() >= 0) { + null_index = memo_index_to_value_index[lookup_table->GetNull()]; } return Status::OK(); } @@ -75,7 +82,7 @@ struct SetLookupState : public KernelState { DCHECK_EQ(memo_index, memo_size); memo_index_to_value_index.push_back(index); }; - RETURN_NOT_OK(lookup_table.GetOrInsert( + RETURN_NOT_OK(lookup_table->GetOrInsert( v, std::move(on_found), std::move(on_not_found), &unused_memo_index)); ++index; return Status::OK(); @@ -89,7 +96,7 @@ struct SetLookupState : public KernelState { DCHECK_EQ(memo_index, memo_size); memo_index_to_value_index.push_back(index); }; - lookup_table.GetOrInsertNull(std::move(on_found), std::move(on_not_found)); + lookup_table->GetOrInsertNull(std::move(on_found), std::move(on_not_found)); ++index; return Status::OK(); }; @@ -98,7 +105,8 @@ struct SetLookupState : public KernelState { } using MemoTable = typename HashTraits::MemoTableType; - MemoTable lookup_table; + std::optional lookup_table; // use optional for delayed initialization + MemoryPool* memory_pool; // When there are duplicates in value_set, the MemoTable indices must // be mapped back to indices in the value_set. std::vector memo_index_to_value_index; @@ -264,7 +272,7 @@ struct IndexInVisitor { VisitArraySpanInline( data, [&](T v) { - int32_t index = state.lookup_table.Get(v); + int32_t index = state.lookup_table->Get(v); if (index != -1) { bitmap_writer.Set(); @@ -358,7 +366,7 @@ struct IsInVisitor { VisitArraySpanInline( this->data, [&](T v) { - if (state.lookup_table.Get(v) != -1) { + if (state.lookup_table->Get(v) != -1) { writer.Set(); } else { writer.Clear(); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc index c49dd740848..9158c518b41 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc @@ -51,13 +51,14 @@ static void SetLookupBenchmarkString(benchmark::State& state, } state.SetItemsProcessed(state.iterations() * array_length); state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size()); + state.counters["value_set_length"] = static_cast(value_set_length); } template static void SetLookupBenchmarkNumeric(benchmark::State& state, const std::string& func_name, - const int64_t value_set_length) { - const int64_t array_length = 1 << 18; + const int64_t value_set_length, + const int64_t array_length) { const int64_t value_min = 0; const int64_t value_max = std::numeric_limits::max(); const double null_probability = 0.1 / value_set_length; @@ -72,6 +73,7 @@ static void SetLookupBenchmarkNumeric(benchmark::State& state, } state.SetItemsProcessed(state.iterations() * array_length); state.SetBytesProcessed(state.iterations() * values->data()->buffers[1]->size()); + state.counters["value_set_length"] = static_cast(value_set_length); } static void IndexInStringSmallSet(benchmark::State& state) { @@ -90,36 +92,57 @@ static void IsInStringLargeSet(benchmark::State& state) { SetLookupBenchmarkString(state, "is_in_meta_binary", 1 << 10); } +static constexpr int64_t kArrayLengthWithSmallSet = 1 << 18; +static constexpr int64_t kArrayLengthWithLargeSet = 1000; + static void IndexInInt8SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); } static void IndexInInt16SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); } static void IndexInInt32SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); } static void IndexInInt64SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); +} + +static void IndexInInt32LargeSet(benchmark::State& state) { + SetLookupBenchmarkNumeric(state, "index_in_meta_binary", state.range(0), + kArrayLengthWithLargeSet); } static void IsInInt8SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); } static void IsInInt16SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); } static void IsInInt32SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); } static void IsInInt64SmallSet(benchmark::State& state) { - SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0)); + SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0), + kArrayLengthWithSmallSet); +} + +static void IsInInt32LargeSet(benchmark::State& state) { + SetLookupBenchmarkNumeric(state, "is_in_meta_binary", state.range(0), + kArrayLengthWithLargeSet); } BENCHMARK(IndexInStringSmallSet)->RangeMultiplier(4)->Range(2, 64); @@ -134,10 +157,12 @@ BENCHMARK(IndexInInt8SmallSet)->RangeMultiplier(4)->Range(2, 8); BENCHMARK(IndexInInt16SmallSet)->RangeMultiplier(4)->Range(2, 64); BENCHMARK(IndexInInt32SmallSet)->RangeMultiplier(4)->Range(2, 64); BENCHMARK(IndexInInt64SmallSet)->RangeMultiplier(4)->Range(2, 64); +BENCHMARK(IndexInInt32LargeSet)->RangeMultiplier(100)->Range(1000, 1000000); BENCHMARK(IsInInt8SmallSet)->RangeMultiplier(4)->Range(2, 8); BENCHMARK(IsInInt16SmallSet)->RangeMultiplier(4)->Range(2, 64); BENCHMARK(IsInInt32SmallSet)->RangeMultiplier(4)->Range(2, 64); BENCHMARK(IsInInt64SmallSet)->RangeMultiplier(4)->Range(2, 64); +BENCHMARK(IsInInt32LargeSet)->RangeMultiplier(100)->Range(1000, 1000000); } // namespace compute } // namespace arrow