diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc
index 22ad728a4ec..2b4006961c7 100644
--- a/cpp/src/arrow/array/array_nested.cc
+++ b/cpp/src/arrow/array/array_nested.cc
@@ -541,56 +541,63 @@ std::shared_ptr<Array> StructArray::GetFieldByName(const std::string& name) const {
 Result<ArrayVector> StructArray::Flatten(MemoryPool* pool) const {
   ArrayVector flattened;
-  flattened.reserve(data_->child_data.size());
+  flattened.resize(data_->child_data.size());
   std::shared_ptr<Buffer> null_bitmap = data_->buffers[0];

-  for (const auto& child_data_ptr : data_->child_data) {
-    auto child_data = child_data_ptr->Copy();
+  for (int i = 0; static_cast<size_t>(i) < data_->child_data.size(); i++) {
+    ARROW_ASSIGN_OR_RAISE(flattened[i], GetFlattenedField(i, pool));
+  }

-    std::shared_ptr<Buffer> flattened_null_bitmap;
-    int64_t flattened_null_count = kUnknownNullCount;
+  return flattened;
+}

-    // Need to adjust for parent offset
-    if (data_->offset != 0 || data_->length != child_data->length) {
-      child_data = child_data->Slice(data_->offset, data_->length);
-    }
-    std::shared_ptr<Buffer> child_null_bitmap = child_data->buffers[0];
-    const int64_t child_offset = child_data->offset;
-
-    // The validity of a flattened datum is the logical AND of the struct
-    // element's validity and the individual field element's validity.
-    if (null_bitmap && child_null_bitmap) {
-      ARROW_ASSIGN_OR_RAISE(
-          flattened_null_bitmap,
-          BitmapAnd(pool, child_null_bitmap->data(), child_offset, null_bitmap_data_,
-                    data_->offset, data_->length, child_offset));
-    } else if (child_null_bitmap) {
-      flattened_null_bitmap = child_null_bitmap;
-      flattened_null_count = child_data->null_count;
-    } else if (null_bitmap) {
-      if (child_offset == data_->offset) {
-        flattened_null_bitmap = null_bitmap;
-      } else {
-        // If the child has an offset, need to synthesize a validity
-        // buffer with an offset too
-        ARROW_ASSIGN_OR_RAISE(flattened_null_bitmap,
-                              AllocateEmptyBitmap(child_offset + data_->length, pool));
-        CopyBitmap(null_bitmap_data_, data_->offset, data_->length,
-                   flattened_null_bitmap->mutable_data(), child_offset);
-      }
-      flattened_null_count = data_->null_count;
-    } else {
-      flattened_null_count = 0;
-    }
+Result<std::shared_ptr<Array>> StructArray::GetFlattenedField(int index,
+                                                              MemoryPool* pool) const {
+  std::shared_ptr<Buffer> null_bitmap = data_->buffers[0];
+
+  auto child_data = data_->child_data[index]->Copy();

-    auto flattened_data = child_data->Copy();
-    flattened_data->buffers[0] = flattened_null_bitmap;
-    flattened_data->null_count = flattened_null_count;
+  std::shared_ptr<Buffer> flattened_null_bitmap;
+  int64_t flattened_null_count = kUnknownNullCount;

-    flattened.push_back(MakeArray(flattened_data));
+  // Need to adjust for parent offset
+  if (data_->offset != 0 || data_->length != child_data->length) {
+    child_data = child_data->Slice(data_->offset, data_->length);
   }
+  std::shared_ptr<Buffer> child_null_bitmap = child_data->buffers[0];
+  const int64_t child_offset = child_data->offset;

-  return flattened;
+  // The validity of a flattened datum is the logical AND of the struct
+  // element's validity and the individual field element's validity.
+  if (null_bitmap && child_null_bitmap) {
+    ARROW_ASSIGN_OR_RAISE(
+        flattened_null_bitmap,
+        BitmapAnd(pool, child_null_bitmap->data(), child_offset, null_bitmap_data_,
+                  data_->offset, data_->length, child_offset));
+  } else if (child_null_bitmap) {
+    flattened_null_bitmap = child_null_bitmap;
+    flattened_null_count = child_data->null_count;
+  } else if (null_bitmap) {
+    if (child_offset == data_->offset) {
+      flattened_null_bitmap = null_bitmap;
+    } else {
+      // If the child has an offset, need to synthesize a validity
+      // buffer with an offset too
+      ARROW_ASSIGN_OR_RAISE(flattened_null_bitmap,
+                            AllocateEmptyBitmap(child_offset + data_->length, pool));
+      CopyBitmap(null_bitmap_data_, data_->offset, data_->length,
+                 flattened_null_bitmap->mutable_data(), child_offset);
+    }
+    flattened_null_count = data_->null_count;
+  } else {
+    flattened_null_count = 0;
+  }
+
+  auto flattened_data = child_data->Copy();
+  flattened_data->buffers[0] = flattened_null_bitmap;
+  flattened_data->null_count = flattened_null_count;
+
+  return MakeArray(flattened_data);
 }

 // ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h
index 97e470f550a..178a0589d5a 100644
--- a/cpp/src/arrow/array/array_nested.h
+++ b/cpp/src/arrow/array/array_nested.h
@@ -370,6 +370,14 @@ class ARROW_EXPORT StructArray : public Array {
   /// \param[in] pool The pool to allocate null bitmaps from, if necessary
   Result<ArrayVector> Flatten(MemoryPool* pool = default_memory_pool()) const;

+  /// \brief Get one of the child arrays, combining its null bitmap
+  /// with the parent struct array's bitmap.
+  ///
+  /// \param[in] index Which child array to get
+  /// \param[in] pool The pool to allocate null bitmaps from, if necessary
+  Result<std::shared_ptr<Array>> GetFlattenedField(
+      int index, MemoryPool* pool = default_memory_pool()) const;
+
  private:
   // For caching boxed child data
   // XXX This is not handled in a thread-safe manner.
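
To make the new API concrete, here is a minimal sketch (not part of the patch) of GetFlattenedField's null-merging behavior; it relies on the ArrayFromJSON test helper already used elsewhere in this diff, and the function and variable names are invented for illustration:

    #include <iostream>

    #include "arrow/api.h"
    #include "arrow/testing/gtest_util.h"

    arrow::Status Demo() {
      // One int64 child; parent bitmap 0b101 marks the middle struct null.
      auto x = ArrayFromJSON(arrow::int64(), "[1, 2, 3]");
      ARROW_ASSIGN_OR_RAISE(auto bitmap, arrow::AllocateBitmap(3));
      bitmap->mutable_data()[0] = 0b101;
      ARROW_ASSIGN_OR_RAISE(
          auto array, arrow::StructArray::Make({x}, {"x"}, bitmap, /*null_count=*/1));

      // field(0) returns the raw child, which knows nothing of parent nulls.
      std::cout << array->field(0)->ToString() << std::endl;  // [1, 2, 3]

      // GetFlattenedField ANDs the parent bitmap in, like Flatten does.
      ARROW_ASSIGN_OR_RAISE(auto flat, array->GetFlattenedField(0));
      std::cout << flat->ToString() << std::endl;  // [1, null, 3]
      return arrow::Status::OK();
    }
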
diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index 64e3305825d..03db24b5413 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -63,7 +63,7 @@ Expression::Expression(Parameter parameter)
 Expression literal(Datum lit) { return Expression(std::move(lit)); }

 Expression field_ref(FieldRef ref) {
-  return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, -1});
+  return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, {-1}});
 }

 Expression call(std::string function, std::vector<Expression> arguments,
@@ -394,14 +394,11 @@ Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in,
   if (expr.literal()) return expr;

   if (auto ref = expr.field_ref()) {
-    if (ref->IsNested()) {
-      return Status::NotImplemented("nested field references");
-    }
-
     ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(in));

     auto bound = *expr.parameter();
-    bound.index = path[0];
+    bound.indices.resize(path.indices().size());
+    std::copy(path.indices().begin(), path.indices().end(), bound.indices.begin());
     ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in));
     bound.descr.type = field->type();
     bound.descr.shape = shape;
@@ -512,7 +509,31 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& input,
       return MakeNullScalar(null());
     }

-    const Datum& field = input[param->index];
+    Datum field = input[param->indices[0]];
+    for (auto it = param->indices.begin() + 1; it != param->indices.end(); ++it) {
+      if (field.type()->id() != Type::STRUCT) {
+        return Status::Invalid("Nested field reference into a non-struct: ",
+                               *field.type());
+      }
+      const int index = *it;
+      if (index < 0 || index >= field.type()->num_fields()) {
+        return Status::Invalid("Out of bounds field reference: ", index, " but type has ",
+                               field.type()->num_fields(), " fields");
+      }
+      if (field.is_scalar()) {
+        const auto& struct_scalar = field.scalar_as<StructScalar>();
+        if (!struct_scalar.is_valid) {
+          return MakeNullScalar(param->descr.type);
+        }
+        field = struct_scalar.value[index];
+      } else if (field.is_array()) {
+        const auto& struct_array = field.array_as<StructArray>();
+        ARROW_ASSIGN_OR_RAISE(
+            field, struct_array->GetFlattenedField(index, exec_context->memory_pool()));
+      } else {
+        return Status::NotImplemented("Nested field reference into a ", field.ToString());
+      }
+    }
     if (!field.type()->Equals(param->descr.type)) {
       return Status::Invalid("Referenced field ", expr.ToString(), " was ",
                              field.type()->ToString(), " but should have been ",
diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h
index dac5728ab46..7c567cc8fc6 100644
--- a/cpp/src/arrow/compute/exec/expression.h
+++ b/cpp/src/arrow/compute/exec/expression.h
@@ -27,6 +27,7 @@
 #include "arrow/compute/type_fwd.h"
 #include "arrow/datum.h"
 #include "arrow/type_fwd.h"
+#include "arrow/util/small_vector.h"
 #include "arrow/util/variant.h"

 namespace arrow {
@@ -112,7 +113,7 @@ class ARROW_EXPORT Expression {

     // post-bind properties
     ValueDescr descr;
-    int index;
+    internal::SmallVector<int, 2> indices;
   };

   const Parameter* parameter() const;
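
For orientation before the benchmark and test changes, this is roughly what the newly supported binding path enables from user code. A hedged sketch, assuming only the public expression API shown above (the schema, data, and names are invented):

    #include "arrow/api.h"
    #include "arrow/compute/exec/expression.h"
    #include "arrow/testing/gtest_util.h"

    arrow::Status DemoNestedRef() {
      namespace cp = arrow::compute;
      auto struct_type = arrow::struct_({arrow::field("b", arrow::float64())});
      auto schm = arrow::schema({arrow::field("a", struct_type)});

      // Bind used to fail with NotImplemented for nested refs; it now
      // resolves FieldRef("a", "b") to the index path {0, 0}.
      ARROW_ASSIGN_OR_RAISE(auto bound,
                            cp::field_ref(arrow::FieldRef("a", "b")).Bind(*schm));

      cp::ExecBatch input(
          {arrow::Datum(ArrayFromJSON(struct_type, R"([{"b": 1.5}, null])"))},
          /*length=*/2);
      ARROW_ASSIGN_OR_RAISE(auto out, cp::ExecuteScalarExpression(bound, input));
      // out is a float64 array [1.5, null]: the null struct propagates
      // through GetFlattenedField.
      return arrow::Status::OK();
    }
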
"arrow/testing/gtest_util.h" #include "arrow/type.h" @@ -29,6 +30,34 @@ namespace compute { std::shared_ptr ninety_nine_dict = DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(int64(), "[99]")); +static void BindAndEvaluate(benchmark::State& state, Expression expr) { + ExecContext ctx; + auto struct_type = struct_({ + field("int", int64()), + field("float", float64()), + }); + auto dataset_schema = schema({ + field("int_arr", int64()), + field("struct_arr", struct_type), + field("int_scalar", int64()), + field("struct_scalar", struct_type), + }); + ExecBatch input( + { + Datum(ArrayFromJSON(int64(), "[0, 2, 4, 8]")), + Datum(ArrayFromJSON(struct_type, + "[[0, 2.0], [4, 8.0], [16, 32.0], [64, 128.0]]")), + Datum(ScalarFromJSON(int64(), "16")), + Datum(ScalarFromJSON(struct_type, "[32, 64.0]")), + }, + /*length=*/4); + + for (auto _ : state) { + ASSIGN_OR_ABORT(auto bound, expr.Bind(*dataset_schema)); + ABORT_NOT_OK(ExecuteScalarExpression(bound, input, &ctx).status()); + } +} + // A benchmark of SimplifyWithGuarantee using expressions arising from partitioning. static void SimplifyFilterWithGuarantee(benchmark::State& state, Expression filter, Expression guarantee) { @@ -84,5 +113,12 @@ BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, BENCHMARK_CAPTURE(SimplifyFilterWithGuarantee, positive_filter_cast_guarantee_dictionary, filter_cast_positive, guarantee_dictionary); +BENCHMARK_CAPTURE(BindAndEvaluate, simple_array, field_ref("int_arr")); +BENCHMARK_CAPTURE(BindAndEvaluate, simple_scalar, field_ref("int_scalar")); +BENCHMARK_CAPTURE(BindAndEvaluate, nested_array, + field_ref(FieldRef("struct_arr", "float"))); +BENCHMARK_CAPTURE(BindAndEvaluate, nested_scalar, + field_ref(FieldRef("struct_scalar", "float"))); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 88b94e80434..94ca2074835 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -476,15 +476,16 @@ TEST(Expression, BindLiteral) { } void ExpectBindsTo(Expression expr, util::optional expected, - Expression* bound_out = nullptr) { + Expression* bound_out = nullptr, + const Schema& schema = *kBoringSchema) { if (!expected) { expected = expr; } - ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(schema)); EXPECT_TRUE(bound.IsBound()); - ASSERT_OK_AND_ASSIGN(expected, expected->Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(expected, expected->Bind(schema)); EXPECT_EQ(bound, *expected) << " unbound: " << expr.ToString(); if (bound_out) { @@ -508,11 +509,24 @@ TEST(Expression, BindFieldRef) { // in the input schema ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema( {field("alpha", int32()), field("alpha", float32())}))); +} + +TEST(Expression, BindNestedFieldRef) { + Expression expr; + auto schema = Schema({field("a", struct_({field("b", int32())}))}); + + ExpectBindsTo(field_ref(FieldRef("a", "b")), no_change, &expr, schema); + EXPECT_TRUE(expr.IsBound()); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - // referencing nested fields is not supported - ASSERT_RAISES(NotImplemented, - field_ref(FieldRef("a", "b")) - .Bind(Schema({field("a", struct_({field("b", int32())}))}))); + ExpectBindsTo(field_ref(FieldRef(FieldPath({0, 0}))), no_change, &expr, schema); + EXPECT_TRUE(expr.IsBound()); + EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + + ASSERT_RAISES(Invalid, field_ref(FieldPath({0, 
diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc
index 88b94e80434..94ca2074835 100644
--- a/cpp/src/arrow/compute/exec/expression_test.cc
+++ b/cpp/src/arrow/compute/exec/expression_test.cc
@@ -476,15 +476,16 @@ TEST(Expression, BindLiteral) {
 }

 void ExpectBindsTo(Expression expr, util::optional<Expression> expected,
-                   Expression* bound_out = nullptr) {
+                   Expression* bound_out = nullptr,
+                   const Schema& schema = *kBoringSchema) {
   if (!expected) {
     expected = expr;
   }

-  ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema));
+  ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(schema));
   EXPECT_TRUE(bound.IsBound());

-  ASSERT_OK_AND_ASSIGN(expected, expected->Bind(*kBoringSchema));
+  ASSERT_OK_AND_ASSIGN(expected, expected->Bind(schema));
   EXPECT_EQ(bound, *expected) << " unbound: " << expr.ToString();

   if (bound_out) {
@@ -508,11 +509,24 @@ TEST(Expression, BindFieldRef) {
   // in the input schema
   ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema(
                              {field("alpha", int32()), field("alpha", float32())})));
+}
+
+TEST(Expression, BindNestedFieldRef) {
+  Expression expr;
+  auto schema = Schema({field("a", struct_({field("b", int32())}))});
+
+  ExpectBindsTo(field_ref(FieldRef("a", "b")), no_change, &expr, schema);
+  EXPECT_TRUE(expr.IsBound());
+  EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));

-  // referencing nested fields is not supported
-  ASSERT_RAISES(NotImplemented,
-                field_ref(FieldRef("a", "b"))
-                    .Bind(Schema({field("a", struct_({field("b", int32())}))})));
+  ExpectBindsTo(field_ref(FieldRef(FieldPath({0, 0}))), no_change, &expr, schema);
+  EXPECT_TRUE(expr.IsBound());
+  EXPECT_EQ(expr.descr(), ValueDescr::Array(int32()));
+
+  ASSERT_RAISES(Invalid, field_ref(FieldPath({0, 1})).Bind(schema));
+  ASSERT_RAISES(Invalid, field_ref(FieldRef("a", "b"))
+                             .Bind(Schema({field("a", struct_({field("b", int32()),
+                                                               field("b", int64())}))})));
 }

 TEST(Expression, BindCall) {
@@ -614,6 +628,45 @@ TEST(Expression, ExecuteFieldRef) {
                 {"a": -1, "b": 4.0}
               ])"),
              ArrayFromJSON(float64(), R"([7.5, 2.125, 4.0])"));
+
+  ExpectRefIs(FieldRef(FieldPath({0, 0})),
+              ArrayFromJSON(struct_({field("a", struct_({field("b", float64())}))}), R"([
+                {"a": {"b": 6.125}},
+                {"a": {"b": 0.0}},
+                {"a": {"b": -1}}
+              ])"),
+              ArrayFromJSON(float64(), R"([6.125, 0.0, -1])"));
+
+  ExpectRefIs(FieldRef("a", "b"),
+              ArrayFromJSON(struct_({field("a", struct_({field("b", float64())}))}), R"([
+                {"a": {"b": 6.125}},
+                {"a": {"b": 0.0}},
+                {"a": {"b": -1}}
+              ])"),
+              ArrayFromJSON(float64(), R"([6.125, 0.0, -1])"));
+
+  ExpectRefIs(FieldRef("a", "b"),
+              ArrayFromJSON(struct_({field("a", struct_({field("b", float64())}))}), R"([
+                {"a": {"b": 6.125}},
+                {"a": null},
+                {"a": {"b": null}}
+              ])"),
+              ArrayFromJSON(float64(), R"([6.125, null, null])"));
+
+  ExpectRefIs(
+      FieldRef("a", "b"),
+      ScalarFromJSON(struct_({field("a", struct_({field("b", float64())}))}), "[[64.0]]"),
+      ScalarFromJSON(float64(), "64.0"));
+
+  ExpectRefIs(
+      FieldRef("a", "b"),
+      ScalarFromJSON(struct_({field("a", struct_({field("b", float64())}))}), "[[null]]"),
+      ScalarFromJSON(float64(), "null"));
+
+  ExpectRefIs(
+      FieldRef("a", "b"),
+      ScalarFromJSON(struct_({field("a", struct_({field("b", float64())}))}), "[null]"),
+      ScalarFromJSON(float64(), "null"));
 }

 Result<Datum> NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) {
@@ -697,6 +750,18 @@ TEST(Expression, ExecuteCall) {
                 {"a": 0.0},
                 {"a": -1}
               ])"));
+
+  ExpectExecute(
+      call("add", {field_ref(FieldRef("a", "a")), field_ref(FieldRef("a", "b"))}),
+      ArrayFromJSON(struct_({field("a", struct_({
+                                         field("a", float64()),
+                                         field("b", float64()),
+                                     }))}),
+                    R"([
+                {"a": {"a": 6.125, "b": 3.375}},
+                {"a": {"a": 0.0, "b": 1}},
+                {"a": {"a": -1, "b": 4.75}}
+              ])"));
 }

 TEST(Expression, ExecuteDictionaryTransparent) {
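
Finally, since bound nested refs are ordinary expressions, they compose with calls; the new ExecuteCall case above boils down to something like this sketch (names are invented, assuming the "add" scalar function from the compute registry):

    #include "arrow/api.h"
    #include "arrow/compute/exec/expression.h"

    namespace cp = arrow::compute;

    arrow::Result<arrow::Datum> AddNestedFields(const cp::ExecBatch& input,
                                                const arrow::Schema& schm) {
      auto expr = cp::call("add", {cp::field_ref(arrow::FieldRef("a", "a")),
                                   cp::field_ref(arrow::FieldRef("a", "b"))});
      ARROW_ASSIGN_OR_RAISE(auto bound, expr.Bind(schm));
      // Each nested ref is flattened (or, for scalars, drilled into) before
      // the "add" kernel runs, so parent nulls surface as null sums.
      return cp::ExecuteScalarExpression(bound, input);
    }
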