-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-10663: [C++] Fix is_in and index_in behaviour #9164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,53 +36,59 @@ namespace { | |
|
|
||
| template <typename Type> | ||
| struct SetLookupState : public KernelState { | ||
| explicit SetLookupState(MemoryPool* pool) | ||
| : lookup_table(pool, 0), lookup_null_count(0) {} | ||
| explicit SetLookupState(const SetLookupOptions& options, MemoryPool* pool) | ||
| : options(options), lookup_table(pool, 0) {} | ||
|
|
||
| Status Init(const SetLookupOptions& options) { | ||
| Status Init() { | ||
| if (options.value_set.kind() == Datum::ARRAY) { | ||
| RETURN_NOT_OK(AddArrayValueSet(*options.value_set.array())); | ||
| } else if (options.value_set.kind() == Datum::CHUNKED_ARRAY) { | ||
| const ChunkedArray& value_set = *options.value_set.chunked_array(); | ||
| for (const std::shared_ptr<Array>& chunk : value_set.chunks()) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
We have a lot of code like this; maybe later we should add something like the following:
if (!options.value_set.is_arraylike()) {
return Status::Invalid("value_set should be an array or chunked array");
}
for (const std::shared_ptr<ArrayData>& chunk : options.value_set.chunks()) {
RETURN_NOT_OK(AddArrayValueSet(*chunk->data()));
}
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
If we're not too bothered by the cost of a temporary vector, then it may be nice indeed. |
||
| RETURN_NOT_OK(AddArrayValueSet(*chunk->data())); | ||
| } | ||
| } else { | ||
| return Status::Invalid("value_set should be an array or chunked array"); | ||
| } | ||
| if (lookup_table.size() != options.value_set.length()) { | ||
| return Status::NotImplemented("duplicate values in value_set"); | ||
| } | ||
| value_set_has_null = (lookup_table.GetNull() >= 0); | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| Status AddArrayValueSet(const ArrayData& data) { | ||
| using T = typename GetViewType<Type>::T; | ||
| auto visit_valid = [&](T v) { | ||
| int32_t unused_memo_index; | ||
| return lookup_table.GetOrInsert(v, &unused_memo_index); | ||
| }; | ||
| auto visit_null = [&]() { | ||
| if (!options.skip_nulls) { | ||
| lookup_table.GetOrInsertNull(); | ||
| } | ||
| lookup_table.GetOrInsertNull(); | ||
| return Status::OK(); | ||
| }; | ||
| if (options.value_set.kind() == Datum::ARRAY) { | ||
| const std::shared_ptr<ArrayData>& value_set = options.value_set.array(); | ||
| this->lookup_null_count += value_set->GetNullCount(); | ||
| return VisitArrayDataInline<Type>(*value_set, std::move(visit_valid), | ||
| std::move(visit_null)); | ||
| } else { | ||
| const ChunkedArray& value_set = *options.value_set.chunked_array(); | ||
| for (const std::shared_ptr<Array>& chunk : value_set.chunks()) { | ||
| this->lookup_null_count += chunk->null_count(); | ||
| RETURN_NOT_OK(VisitArrayDataInline<Type>(*chunk->data(), std::move(visit_valid), | ||
| std::move(visit_null))); | ||
| } | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| return VisitArrayDataInline<Type>(data, visit_valid, visit_null); | ||
| } | ||
|
|
||
| using MemoTable = typename HashTraits<Type>::MemoTableType; | ||
| const SetLookupOptions& options; | ||
| MemoTable lookup_table; | ||
| int64_t lookup_null_count; | ||
| int64_t null_index = -1; | ||
| bool value_set_has_null; | ||
| }; | ||
|
|
||
| template <> | ||
| struct SetLookupState<NullType> : public KernelState { | ||
| explicit SetLookupState(MemoryPool*) {} | ||
| explicit SetLookupState(const SetLookupOptions& options, MemoryPool*) | ||
| : options(options) {} | ||
|
|
||
| Status Init(const SetLookupOptions& options) { | ||
| this->lookup_null_count = options.value_set.null_count(); | ||
| Status Init() { | ||
| this->value_set_has_null = (options.value_set.length() > 0); | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| int64_t lookup_null_count; | ||
| const SetLookupOptions& options; | ||
| bool value_set_has_null; | ||
| }; | ||
|
|
||
| // TODO: Put this concept somewhere reusable | ||
|
|
@@ -125,8 +131,8 @@ struct InitStateVisitor { | |
| "Attempted to call a set lookup function without SetLookupOptions"); | ||
| } | ||
| using StateType = SetLookupState<Type>; | ||
| result.reset(new StateType(ctx->exec_context()->memory_pool())); | ||
| return static_cast<StateType*>(result.get())->Init(*options); | ||
| result.reset(new StateType(*options, ctx->exec_context()->memory_pool())); | ||
| return static_cast<StateType*>(result.get())->Init(); | ||
| } | ||
|
|
||
| Status Visit(const DataType&) { return Init<NullType>(); } | ||
|
|
@@ -174,16 +180,18 @@ struct IndexInVisitor { | |
| IndexInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) | ||
| : ctx(ctx), data(data), out(out), builder(ctx->exec_context()->memory_pool()) {} | ||
|
|
||
| Status Visit(const DataType&) { | ||
| Status Visit(const DataType& type) { | ||
| DCHECK_EQ(type.id(), Type::NA); | ||
| const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state()); | ||
| if (data.length != 0) { | ||
| if (state.lookup_null_count == 0) { | ||
| RETURN_NOT_OK(this->builder.AppendNulls(data.length)); | ||
| } else { | ||
| // skip_nulls is honored for consistency with other types | ||
| if (state.value_set_has_null && !state.options.skip_nulls) { | ||
| RETURN_NOT_OK(this->builder.Reserve(data.length)); | ||
| for (int64_t i = 0; i < data.length; ++i) { | ||
| this->builder.UnsafeAppend(0); | ||
| } | ||
| } else { | ||
| RETURN_NOT_OK(this->builder.AppendNulls(data.length)); | ||
| } | ||
| } | ||
| return Status::OK(); | ||
|
|
@@ -195,7 +203,7 @@ struct IndexInVisitor { | |
|
|
||
| const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state()); | ||
|
|
||
| int32_t null_index = state.lookup_table.GetNull(); | ||
| int32_t null_index = state.options.skip_nulls ? -1 : state.lookup_table.GetNull(); | ||
| RETURN_NOT_OK(this->builder.Reserve(data.length)); | ||
| VisitArrayDataInline<Type>( | ||
| data, | ||
|
|
@@ -261,7 +269,7 @@ void ExecIndexIn(KernelContext* ctx, const ExecBatch& batch, Datum* out) { | |
|
|
||
| // ---------------------------------------------------------------------- | ||
|
|
||
| // IsIn writes the results into a preallocated binary data bitmap | ||
| // IsIn writes the results into a preallocated boolean data bitmap | ||
| struct IsInVisitor { | ||
| KernelContext* ctx; | ||
| const ArrayData& data; | ||
|
|
@@ -270,12 +278,12 @@ struct IsInVisitor { | |
| IsInVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) | ||
| : ctx(ctx), data(data), out(out) {} | ||
|
|
||
| Status Visit(const DataType&) { | ||
| Status Visit(const DataType& type) { | ||
| DCHECK_EQ(type.id(), Type::NA); | ||
| const auto& state = checked_cast<const SetLookupState<NullType>&>(*ctx->state()); | ||
| ArrayData* output = out->mutable_array(); | ||
| if (state.lookup_null_count > 0) { | ||
| BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), output->offset, | ||
| output->length, true); | ||
| // skip_nulls is honored for consistency with other types | ||
| if (state.value_set_has_null && !state.options.skip_nulls) { | ||
| BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, | ||
| output->length, true); | ||
| } else { | ||
|
|
@@ -291,13 +299,6 @@ struct IsInVisitor { | |
| const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state()); | ||
| ArrayData* output = out->mutable_array(); | ||
|
|
||
| if (this->data.GetNullCount() > 0 && state.lookup_null_count > 0) { | ||
| // If there were nulls in the value set, set the whole validity bitmap to | ||
| // true | ||
| output->null_count = 0; | ||
| BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), output->offset, | ||
| output->length, true); | ||
| } | ||
| FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(), output->offset, | ||
| output->length); | ||
| VisitArrayDataInline<Type>( | ||
|
|
@@ -311,7 +312,11 @@ struct IsInVisitor { | |
| writer.Next(); | ||
| }, | ||
| [&]() { | ||
| writer.Set(); | ||
| if (!state.options.skip_nulls && state.lookup_table.GetNull() != -1) { | ||
| writer.Set(); | ||
| } else { | ||
| writer.Clear(); | ||
| } | ||
| writer.Next(); | ||
| }); | ||
| writer.Finish(); | ||
|
|
@@ -412,33 +417,37 @@ class IndexInMetaBinary : public MetaFunction { | |
| const FunctionDoc is_in_doc{ | ||
| "Find each element in a set of values", | ||
| ("For each element in `values`, return true if it is found in a given\n" | ||
| "set of values. The set of values to look for must be given in\n" | ||
| "SetLookupOptions."), | ||
| "set of values, false otherwise.\n" | ||
| "The set of values to look for must be given in SetLookupOptions.\n" | ||
| "By default, nulls are matched against the value set, this can be\n" | ||
| "changed in SetLookupOptions."), | ||
| {"values"}, | ||
| "SetLookupOptions"}; | ||
|
|
||
| const FunctionDoc index_in_doc{ | ||
| "Return index of each element in a set of values", | ||
| ("For each element in `values`, return its index in a given set of\n" | ||
| "values, or null if it is not found there.\n" | ||
| "The set of values to look for must be given in SetLookupOptions."), | ||
| "The set of values to look for must be given in SetLookupOptions.\n" | ||
| "By default, nulls are matched against the value set, this can be\n" | ||
| "changed in SetLookupOptions."), | ||
| {"values"}, | ||
| "SetLookupOptions"}; | ||
|
|
||
| } // namespace | ||
|
|
||
| void RegisterScalarSetLookup(FunctionRegistry* registry) { | ||
| // IsIn always writes into preallocated memory | ||
| // IsIn writes its boolean output into preallocated memory | ||
| { | ||
| ScalarKernel isin_base; | ||
| isin_base.init = InitSetLookup; | ||
| isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn); | ||
| isin_base.null_handling = NullHandling::OUTPUT_NOT_NULL; | ||
| auto is_in = std::make_shared<ScalarFunction>("is_in", Arity::Unary(), &is_in_doc); | ||
|
|
||
| AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); | ||
|
|
||
| isin_base.signature = KernelSignature::Make({null()}, boolean()); | ||
| isin_base.null_handling = NullHandling::COMPUTED_PREALLOCATE; | ||
| DCHECK_OK(is_in->AddKernel(isin_base)); | ||
| DCHECK_OK(registry->AddFunction(is_in)); | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could mention here that the default is `skip_nulls=False` (or in the docstring of SetLookupOptions)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, could maybe add a similar sentence, "Behaviour of nulls is governed by SetLookupOptions::skip_nulls", to the `is_in_doc`?