Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions cpp/src/arrow/array/array_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "arrow/type_traits.h"
#include "arrow/util/atomic_shared_ptr.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_generate.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
Expand Down Expand Up @@ -650,6 +651,43 @@ SparseUnionArray::SparseUnionArray(std::shared_ptr<DataType> type, int64_t lengt
SetData(std::move(internal_data));
}

Result<std::shared_ptr<Array>> SparseUnionArray::GetFlattenedField(
int index, MemoryPool* pool) const {
if (index < 0 || index >= num_fields()) {
return Status::Invalid("Index out of range: ", index);
}
auto child_data = data_->child_data[index]->Copy();
// Adjust the result offset/length to be absolute.
if (data_->offset != 0 || data_->length != child_data->length) {
child_data = child_data->Slice(data_->offset, data_->length);
}
std::shared_ptr<Buffer> child_null_bitmap = child_data->buffers[0];
const int64_t child_offset = child_data->offset;

// Synthesize a null bitmap based on the union discriminant.
// Make sure the bitmap has extra bits corresponding to the child offset.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> flattened_null_bitmap,
AllocateEmptyBitmap(child_data->length + child_offset, pool));
const int8_t type_code = union_type()->type_codes()[index];
const int8_t* type_codes = raw_type_codes();
int64_t offset = 0;
internal::GenerateBitsUnrolled(flattened_null_bitmap->mutable_data(), child_offset,
data_->length,
[&] { return type_codes[offset++] == type_code; });

// The validity of a flattened datum is the logical AND of the synthesized
// null bitmap buffer and the individual field element's validity.
if (child_null_bitmap) {
BitmapAnd(flattened_null_bitmap->data(), child_offset, child_null_bitmap->data(),
child_offset, child_data->length, child_offset,
flattened_null_bitmap->mutable_data());
}

child_data->buffers[0] = std::move(flattened_null_bitmap);
child_data->null_count = kUnknownNullCount;
return MakeArray(child_data);
}

DenseUnionArray::DenseUnionArray(const std::shared_ptr<ArrayData>& data) {
SetData(data);
}
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/array/array_nested.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,14 @@ class ARROW_EXPORT SparseUnionArray : public UnionArray {
return internal::checked_cast<const SparseUnionType*>(union_type_);
}

/// \brief Get one of the child arrays, adjusting its null bitmap
/// where the union array type code does not match.
///
/// \param[in] index Which child array to get (i.e. the physical index, not the type
/// code) \param[in] pool The pool to allocate null bitmaps from, if necessary
Result<std::shared_ptr<Array>> GetFlattenedField(
int index, MemoryPool* pool = default_memory_pool()) const;

protected:
void SetData(std::shared_ptr<ArrayData> data);
};
Expand Down
65 changes: 65 additions & 0 deletions cpp/src/arrow/array/array_union_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
namespace arrow {

using internal::checked_cast;
using internal::checked_pointer_cast;

TEST(TestUnionArray, TestSliceEquals) {
std::shared_ptr<RecordBatch> batch;
Expand Down Expand Up @@ -68,6 +69,70 @@ TEST(TestUnionArray, TestSliceEquals) {
CheckUnion(batch->column(1));
}

TEST(TestSparseUnionArray, GetFlattenedField) {
auto ty = sparse_union({field("ints", int64()), field("strs", utf8())}, {2, 7});
auto ints = ArrayFromJSON(int64(), "[0, 1, 2, 3]");
auto strs = ArrayFromJSON(utf8(), R"(["a", null, "c", "d"])");
auto ids = ArrayFromJSON(int8(), "[2, 7, 2, 7]")->data()->buffers[1];
const int length = 4;

{
SparseUnionArray arr(ty, length, {ints, strs}, ids);
ASSERT_OK(arr.ValidateFull());

ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
AssertArraysEqual(*ArrayFromJSON(int64(), "[0, null, 2, null]"), *flattened,
/*verbose=*/true);

ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, null, null, "d"])"), *flattened,
/*verbose=*/true);

const auto sliced = checked_pointer_cast<SparseUnionArray>(arr.Slice(1, 2));

ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(0));
AssertArraysEqual(*ArrayFromJSON(int64(), "[null, 2]"), *flattened, /*verbose=*/true);

ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(1));
AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, null])"), *flattened,
/*verbose=*/true);

ASSERT_RAISES(Invalid, arr.GetFlattenedField(-1));
ASSERT_RAISES(Invalid, arr.GetFlattenedField(2));
}
{
SparseUnionArray arr(ty, length - 2, {ints->Slice(1, 2), strs->Slice(1, 2)}, ids);
ASSERT_OK(arr.ValidateFull());

ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
AssertArraysEqual(*ArrayFromJSON(int64(), "[1, null]"), *flattened, /*verbose=*/true);

ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, "c"])"), *flattened,
/*verbose=*/true);

const auto sliced = checked_pointer_cast<SparseUnionArray>(arr.Slice(1, 1));

ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(0));
AssertArraysEqual(*ArrayFromJSON(int64(), "[null]"), *flattened, /*verbose=*/true);

ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(1));
AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["c"])"), *flattened, /*verbose=*/true);
}
{
SparseUnionArray arr(ty, /*length=*/0, {ints->Slice(length), strs->Slice(length)},
ids);
ASSERT_OK(arr.ValidateFull());

ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
AssertArraysEqual(*ArrayFromJSON(int64(), "[]"), *flattened, /*verbose=*/true);

ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
AssertArraysEqual(*ArrayFromJSON(utf8(), "[]"), *flattened,
/*verbose=*/true);
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test with an empty union array?


TEST(TestSparseUnionArray, Validate) {
auto a = ArrayFromJSON(int32(), "[4, 5]");
auto type = sparse_union({field("a", int32())});
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ static auto kExtractRegexOptionsType = GetFunctionOptionsType<ExtractRegexOption
static auto kSetLookupOptionsType = GetFunctionOptionsType<SetLookupOptions>(
DataMember("value_set", &SetLookupOptions::value_set),
DataMember("skip_nulls", &SetLookupOptions::skip_nulls));
static auto kStructFieldOptionsType = GetFunctionOptionsType<StructFieldOptions>(
DataMember("indices", &StructFieldOptions::indices));
static auto kStrptimeOptionsType = GetFunctionOptionsType<StrptimeOptions>(
DataMember("format", &StrptimeOptions::format),
DataMember("unit", &StrptimeOptions::unit));
Expand Down Expand Up @@ -351,6 +353,11 @@ SetLookupOptions::SetLookupOptions(Datum value_set, bool skip_nulls)
SetLookupOptions::SetLookupOptions() : SetLookupOptions({}, false) {}
constexpr char SetLookupOptions::kTypeName[];

StructFieldOptions::StructFieldOptions(std::vector<int> indices)
: FunctionOptions(internal::kStructFieldOptionsType), indices(std::move(indices)) {}
StructFieldOptions::StructFieldOptions() : StructFieldOptions(std::vector<int>()) {}
constexpr char StructFieldOptions::kTypeName[];

StrptimeOptions::StrptimeOptions(std::string format, TimeUnit::type unit)
: FunctionOptions(internal::kStrptimeOptionsType),
format(std::move(format)),
Expand Down Expand Up @@ -444,6 +451,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSubstringOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSetLookupOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kStructFieldOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kStrptimeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kStrftimeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kAssumeTimezoneOptionsType));
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
bool skip_nulls;
};

/// Options for struct_field function
class ARROW_EXPORT StructFieldOptions : public FunctionOptions {
public:
explicit StructFieldOptions(std::vector<int> indices);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether this should also accept a FieldRef or field resolution should be left to the caller.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FieldRef is relative to a schema so we'd want/need a variant of this function that operates on a RecordBatch.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But a plain string name could be useful for a StructArray?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd still need a type to resolve it to an index, right? Unless you mean storing std::vector<std::string> or std::string directly? (Which might be reasonable.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, storing it as strings on the options, I meant. Because when actually executing the kernel, the struct array itself can perfectly resolve the name I think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair. Want to file a followup? I think we can support having a FieldRef internally, basically. (Though the interpretation will be a little different - it'll be relative to an array, not a schema.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I think of it, it's probably better to resolve the field up front (using the schema) than pay the cost for every kernel invocation with the same schema.

StructFieldOptions();
constexpr static char const kTypeName[] = "StructFieldOptions";

/// The child indices to extract. For instance, to get the 2nd child
/// of the 1st child of a struct or union, this would be {0, 1}.
std::vector<int> indices;
};

class ARROW_EXPORT StrptimeOptions : public FunctionOptions {
public:
explicit StrptimeOptions(std::string format, TimeUnit::type unit);
Expand Down
28 changes: 5 additions & 23 deletions cpp/src/arrow/compute/exec/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,29 +510,11 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& i
}

Datum field = input[param->indices[0]];
for (auto it = param->indices.begin() + 1; it != param->indices.end(); ++it) {
if (field.type()->id() != Type::STRUCT) {
return Status::Invalid("Nested field reference into a non-struct: ",
*field.type());
}
const int index = *it;
if (index < 0 || index >= field.type()->num_fields()) {
return Status::Invalid("Out of bounds field reference: ", index, " but type has ",
field.type()->num_fields(), " fields");
}
if (field.is_scalar()) {
const auto& struct_scalar = field.scalar_as<StructScalar>();
if (!struct_scalar.is_valid) {
return MakeNullScalar(param->descr.type);
}
field = struct_scalar.value[index];
} else if (field.is_array()) {
const auto& struct_array = field.array_as<StructArray>();
ARROW_ASSIGN_OR_RAISE(
field, struct_array->GetFlattenedField(index, exec_context->memory_pool()));
} else {
return Status::NotImplemented("Nested field reference into a ", field.ToString());
}
if (param->indices.size() > 1) {
std::vector<int> indices(param->indices.begin() + 1, param->indices.end());
compute::StructFieldOptions options(std::move(indices));
ARROW_ASSIGN_OR_RAISE(
field, compute::CallFunction("struct_field", {std::move(field)}, &options));
}
if (!field.type()->Equals(param->descr.type)) {
return Status::Invalid("Referenced field ", expr.ToString(), " was ",
Expand Down
Loading