Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,23 @@ namespace {

template <typename Type>
struct SetLookupState : public KernelState {
explicit SetLookupState(MemoryPool* pool) : lookup_table(pool, 0) {}
explicit SetLookupState(MemoryPool* pool) : memory_pool(pool) {}

Status Init(const SetLookupOptions& options) {
if (options.value_set.is_array()) {
const ArrayData& value_set = *options.value_set.array();
memo_index_to_value_index.reserve(value_set.length);
lookup_table =
MemoTable(memory_pool,
::arrow::internal::HashTable<char>::kLoadFactor * value_set.length);
RETURN_NOT_OK(AddArrayValueSet(options, *options.value_set.array()));
} else if (options.value_set.kind() == Datum::CHUNKED_ARRAY) {
const ChunkedArray& value_set = *options.value_set.chunked_array();
memo_index_to_value_index.reserve(value_set.length());
lookup_table =
MemoTable(memory_pool,
::arrow::internal::HashTable<char>::kLoadFactor * value_set.length());

int64_t offset = 0;
for (const std::shared_ptr<Array>& chunk : value_set.chunks()) {
RETURN_NOT_OK(AddArrayValueSet(options, *chunk->data(), offset));
Expand All @@ -54,8 +61,8 @@ struct SetLookupState : public KernelState {
} else {
return Status::Invalid("value_set should be an array or chunked array");
}
if (!options.skip_nulls && lookup_table.GetNull() >= 0) {
null_index = memo_index_to_value_index[lookup_table.GetNull()];
if (!options.skip_nulls && lookup_table->GetNull() >= 0) {
null_index = memo_index_to_value_index[lookup_table->GetNull()];
}
return Status::OK();
}
Expand All @@ -75,7 +82,7 @@ struct SetLookupState : public KernelState {
DCHECK_EQ(memo_index, memo_size);
memo_index_to_value_index.push_back(index);
};
RETURN_NOT_OK(lookup_table.GetOrInsert(
RETURN_NOT_OK(lookup_table->GetOrInsert(
v, std::move(on_found), std::move(on_not_found), &unused_memo_index));
++index;
return Status::OK();
Expand All @@ -89,7 +96,7 @@ struct SetLookupState : public KernelState {
DCHECK_EQ(memo_index, memo_size);
memo_index_to_value_index.push_back(index);
};
lookup_table.GetOrInsertNull(std::move(on_found), std::move(on_not_found));
lookup_table->GetOrInsertNull(std::move(on_found), std::move(on_not_found));
++index;
return Status::OK();
};
Expand All @@ -98,7 +105,8 @@ struct SetLookupState : public KernelState {
}

using MemoTable = typename HashTraits<Type>::MemoTableType;
MemoTable lookup_table;
std::optional<MemoTable> lookup_table; // use optional for delayed initialization
MemoryPool* memory_pool;
// When there are duplicates in value_set, the MemoTable indices must
// be mapped back to indices in the value_set.
std::vector<int32_t> memo_index_to_value_index;
Expand Down Expand Up @@ -264,7 +272,7 @@ struct IndexInVisitor {
VisitArraySpanInline<Type>(
data,
[&](T v) {
int32_t index = state.lookup_table.Get(v);
int32_t index = state.lookup_table->Get(v);
if (index != -1) {
bitmap_writer.Set();

Expand Down Expand Up @@ -358,7 +366,7 @@ struct IsInVisitor {
VisitArraySpanInline<Type>(
this->data,
[&](T v) {
if (state.lookup_table.Get(v) != -1) {
if (state.lookup_table->Get(v) != -1) {
writer.Set();
} else {
writer.Clear();
Expand Down
45 changes: 35 additions & 10 deletions cpp/src/arrow/compute/kernels/scalar_set_lookup_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@ static void SetLookupBenchmarkString(benchmark::State& state,
}
state.SetItemsProcessed(state.iterations() * array_length);
state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size());
state.counters["value_set_length"] = static_cast<double>(value_set_length);
}

template <typename Type>
static void SetLookupBenchmarkNumeric(benchmark::State& state,
const std::string& func_name,
const int64_t value_set_length) {
const int64_t array_length = 1 << 18;
const int64_t value_set_length,
const int64_t array_length) {
const int64_t value_min = 0;
const int64_t value_max = std::numeric_limits<typename Type::c_type>::max();
const double null_probability = 0.1 / value_set_length;
Expand All @@ -72,6 +73,7 @@ static void SetLookupBenchmarkNumeric(benchmark::State& state,
}
state.SetItemsProcessed(state.iterations() * array_length);
state.SetBytesProcessed(state.iterations() * values->data()->buffers[1]->size());
state.counters["value_set_length"] = static_cast<double>(value_set_length);
}

static void IndexInStringSmallSet(benchmark::State& state) {
Expand All @@ -90,36 +92,57 @@ static void IsInStringLargeSet(benchmark::State& state) {
SetLookupBenchmarkString(state, "is_in_meta_binary", 1 << 10);
}

// Array length handed to SetLookupBenchmarkNumeric by the *SmallSet numeric
// benchmarks (2^18 elements).
static constexpr int64_t kArrayLengthWithSmallSet = int64_t{1} << 18;
// Array length handed to SetLookupBenchmarkNumeric by the *LargeSet numeric
// benchmarks.
// NOTE(review): presumably kept short so that building the large value set
// dominates the measured time -- confirm.
static constexpr int64_t kArrayLengthWithLargeSet{1000};

static void IndexInInt8SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int8Type>(state, "index_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int8Type>(state, "index_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

static void IndexInInt16SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int16Type>(state, "index_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int16Type>(state, "index_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

static void IndexInInt32SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int32Type>(state, "index_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int32Type>(state, "index_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

static void IndexInInt64SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int64Type>(state, "index_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int64Type>(state, "index_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

// Runs the "index_in_meta_binary" lookup benchmark on Int32 values. The
// value-set length is taken from the benchmark's first range argument, and the
// array length is the one reserved for large-set runs.
static void IndexInInt32LargeSet(benchmark::State& state) {
  const int64_t value_set_length = state.range(0);
  SetLookupBenchmarkNumeric<Int32Type>(state, "index_in_meta_binary", value_set_length,
                                       kArrayLengthWithLargeSet);
}

static void IsInInt8SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int8Type>(state, "is_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int8Type>(state, "is_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

static void IsInInt16SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int16Type>(state, "is_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int16Type>(state, "is_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

static void IsInInt32SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int32Type>(state, "is_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int32Type>(state, "is_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

static void IsInInt64SmallSet(benchmark::State& state) {
SetLookupBenchmarkNumeric<Int64Type>(state, "is_in_meta_binary", state.range(0));
SetLookupBenchmarkNumeric<Int64Type>(state, "is_in_meta_binary", state.range(0),
kArrayLengthWithSmallSet);
}

// Runs the "is_in_meta_binary" lookup benchmark on Int32 values. The value-set
// length is taken from the benchmark's first range argument, and the array
// length is the one reserved for large-set runs.
static void IsInInt32LargeSet(benchmark::State& state) {
  const int64_t value_set_length = state.range(0);
  SetLookupBenchmarkNumeric<Int32Type>(state, "is_in_meta_binary", value_set_length,
                                       kArrayLengthWithLargeSet);
}

BENCHMARK(IndexInStringSmallSet)->RangeMultiplier(4)->Range(2, 64);
Expand All @@ -134,10 +157,12 @@ BENCHMARK(IndexInInt8SmallSet)->RangeMultiplier(4)->Range(2, 8);
// Numeric set-lookup benchmark registrations. Each benchmark forwards its
// range argument (state.range(0)) as the value-set length.
// SmallSet variants sweep set lengths between 2 and 64 (2 and 8 for Int8).
BENCHMARK(IndexInInt16SmallSet)->RangeMultiplier(4)->Range(2, 64);
BENCHMARK(IndexInInt32SmallSet)->RangeMultiplier(4)->Range(2, 64);
BENCHMARK(IndexInInt64SmallSet)->RangeMultiplier(4)->Range(2, 64);
// LargeSet variants sweep set lengths between 1000 and 1000000.
BENCHMARK(IndexInInt32LargeSet)->RangeMultiplier(100)->Range(1000, 1000000);
BENCHMARK(IsInInt8SmallSet)->RangeMultiplier(4)->Range(2, 8);
BENCHMARK(IsInInt16SmallSet)->RangeMultiplier(4)->Range(2, 64);
BENCHMARK(IsInInt32SmallSet)->RangeMultiplier(4)->Range(2, 64);
BENCHMARK(IsInInt64SmallSet)->RangeMultiplier(4)->Range(2, 64);
// LargeSet variant of is_in: same 1000 to 1000000 sweep as above.
BENCHMARK(IsInInt32LargeSet)->RangeMultiplier(100)->Range(1000, 1000000);

} // namespace compute
} // namespace arrow