diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 627477b3038..514c2c41ecc 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -1081,6 +1081,21 @@ Result> Serialize(const Expression& expr) { return std::to_string(ret); } + Status VisitFieldRef(const FieldRef& ref) { + if (ref.nested()) { + metadata_->Append("nested_field_ref", std::to_string(ref.nested()->size())); + for (const auto& child : *ref.nested()) { + RETURN_NOT_OK(VisitFieldRef(child)); + } + return Status::OK(); + } + if (!ref.name()) { + return Status::NotImplemented("Serialization of non-name field_refs"); + } + metadata_->Append("field_ref", *ref.name()); + return Status::OK(); + } + Status Visit(const Expression& expr) { if (auto lit = expr.literal()) { if (!lit->is_scalar()) { @@ -1092,11 +1107,7 @@ Result> Serialize(const Expression& expr) { } if (auto ref = expr.field_ref()) { - if (!ref->name()) { - return Status::NotImplemented("Serialization of non-name field_refs"); - } - metadata_->Append("field_ref", *ref->name()); - return Status::OK(); + return VisitFieldRef(*ref); } auto call = CallNotNull(expr); @@ -1154,9 +1165,13 @@ Result Deserialize(std::shared_ptr buffer) { const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); } + bool ParseInteger(const std::string& s, int32_t* value) { + return internal::ParseValue(s.data(), s.length(), value); + } + Result> GetScalar(const std::string& i) { int32_t column_index; - if (!internal::ParseValue(i.data(), i.length(), &column_index)) { + if (!ParseInteger(i, &column_index)) { return Status::Invalid("Couldn't parse column_index"); } if (column_index >= batch_.num_columns()) { @@ -1179,6 +1194,26 @@ Result Deserialize(std::shared_ptr buffer) { return literal(std::move(scalar)); } + if (key == "nested_field_ref") { + int32_t size; + if (!ParseInteger(value, &size)) { + return Status::Invalid("Couldn't parse nested field ref length"); + } + if (size <= 0) { + return Status::Invalid("nested field ref length must be > 0"); + } + std::vector nested; + nested.reserve(size); + while (size-- > 0) { + ARROW_ASSIGN_OR_RAISE(auto ref, GetOne()); + if (!ref.field_ref()) { + return Status::Invalid("invalid nested field ref"); + } + nested.push_back(*ref.field_ref()); + } + return field_ref(FieldRef(std::move(nested))); + } + if (key == "field_ref") { return field_ref(value); } diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 2ab796b052f..d519b1b40dd 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -1180,12 +1180,18 @@ TEST(Expression, SerializationRoundTrips) { ExpectRoundTrips(field_ref("field")); + ExpectRoundTrips(field_ref(FieldRef("foo", "bar", "baz"))); + ExpectRoundTrips(greater(field_ref("a"), literal(0.25))); ExpectRoundTrips( or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")), equal(field_ref("b"), literal("foo bar"))})); + ExpectRoundTrips(or_({equal(field_ref(FieldRef("a", "b")), literal(1)), + not_equal(field_ref("b"), literal("hello")), + equal(field_ref(FieldRef("c", "d")), literal("foo bar"))})); + ExpectRoundTrips(not_(field_ref("alpha"))); ExpectRoundTrips(call("is_in", {literal(1)}, diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 1d3d1e27f92..6e47e8e10ee 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1504,6 +1504,11 @@ class ARROW_EXPORT FieldRef { /// Equivalent to a single index string of indices. FieldRef(int index) : impl_(FieldPath({index})) {} // NOLINT runtime/explicit + /// Construct a nested FieldRef. + FieldRef(std::vector refs) { // NOLINT runtime/explicit + Flatten(std::move(refs)); + } + /// Convenience constructor for nested FieldRefs: each argument will be used to /// construct a FieldRef template @@ -1560,6 +1565,11 @@ class ARROW_EXPORT FieldRef { const std::string* name() const { return IsName() ? &util::get(impl_) : NULLPTR; } + const std::vector* nested() const { + return util::holds_alternative>(impl_) + ? &util::get>(impl_) + : NULLPTR; + } /// \brief Retrieve FieldPath of every child field which matches this FieldRef. std::vector FindAll(const Schema& schema) const; diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 668176a46ae..9d197df82db 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -233,7 +233,23 @@ cdef class Expression(_Weakrefable): @staticmethod def _field(str name not None): - return Expression.wrap(CMakeFieldExpression(tobytes(name))) + cdef: + CFieldRef c_field + + c_field = CFieldRef( tobytes(name)) + return Expression.wrap(CMakeFieldExpression(c_field)) + + @staticmethod + def _nested_field(tuple names not None): + cdef: + vector[CFieldRef] nested + + if len(names) == 0: + raise ValueError("nested field reference should be non-empty") + nested.reserve(len(names)) + for name in names: + nested.push_back(CFieldRef( tobytes(name))) + return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested)))) @staticmethod def _scalar(value): diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py index 195d414b047..84451757969 100644 --- a/python/pyarrow/dataset.py +++ b/python/pyarrow/dataset.py @@ -61,16 +61,26 @@ def field(name): Stores only the field's name. Type and other information is known only when the expression is bound to a dataset having an explicit scheme. + Nested references are allowed by passing a tuple of names. + For example ``('foo', 'bar')`` references the field named "bar" inside + the field named "foo". + Parameters ---------- - name : string - The name of the field the expression references to. + name : string or tuple + The name of the (possibly nested) field the expression references to. Returns ------- field_expr : Expression """ - return Expression._field(name) + if isinstance(name, str): + return Expression._field(name) + elif isinstance(name, tuple): + return Expression._nested_field(name) + else: + raise TypeError( + f"field reference should be str or tuple, got {type(name)}") def scalar(value): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 61deb658b0c..f7b0d329510 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -418,6 +418,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CFieldRef() CFieldRef(c_string name) CFieldRef(int index) + CFieldRef(vector[CFieldRef]) const c_string* name() const cdef cppclass CFieldRefHash" arrow::FieldRef::Hash": diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 309d3530eec..07325d2c438 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -43,7 +43,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: "arrow::dataset::literal"(shared_ptr[CScalar] value) cdef CExpression CMakeFieldExpression \ - "arrow::dataset::field_ref"(c_string name) + "arrow::dataset::field_ref"(CFieldRef) cdef CExpression CMakeCallExpression \ "arrow::dataset::call"(c_string function, diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 8e6bf9c3217..c5fe1be8cfe 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -92,13 +92,15 @@ def mockfs(): list(range(5)), list(map(float, range(5))), list(map(str, range(5))), - [i] * 5 + [i] * 5, + [{'a': j % 3, 'b': str(j % 3)} for j in range(5)], ] schema = pa.schema([ - pa.field('i64', pa.int64()), - pa.field('f64', pa.float64()), - pa.field('str', pa.string()), - pa.field('const', pa.int64()), + ('i64', pa.int64()), + ('f64', pa.float64()), + ('str', pa.string()), + ('const', pa.int64()), + ('struct', pa.struct({'a': pa.int64(), 'b': pa.string()})), ]) batch = pa.record_batch(data, schema=schema) table = pa.Table.from_batches([batch]) @@ -328,6 +330,11 @@ def test_dataset(dataset): assert sorted(result['group']) == [1, 2] assert sorted(result['key']) == ['xxx', 'yyy'] + condition = ds.field(('struct', 'b')) == '1' + with pytest.raises(NotImplementedError, + match="Nested field references in scans"): + dataset.to_table(use_threads=True, filter=condition) + def test_scanner(dataset): scanner = ds.Scanner.from_dataset(dataset, @@ -425,7 +432,8 @@ def test_expression_serialization(): d.is_valid(), a.cast(pa.int32(), safe=False), a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), ds.field('i64') > 5, ds.field('i64') == 5, - ds.field('i64') == 7, ds.field('i64').is_null()] + ds.field('i64') == 7, ds.field('i64').is_null(), + ds.field(('foo', 'bar')) == 'value'] for expr in all_exprs: assert isinstance(expr, ds.Expression) restored = pickle.loads(pickle.dumps(expr)) @@ -439,6 +447,7 @@ def test_expression_construction(): false = ds.scalar(False) string = ds.scalar("string") field = ds.field("field") + nested_field = ds.field(("nested", "field")) zero | one == string ~true == false @@ -446,6 +455,7 @@ def test_expression_construction(): field.cast(typ) == true field.isin([1, 2]) + nested_field.isin(["foo", "bar"]) with pytest.raises(TypeError): field.isin(1) @@ -592,6 +602,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer): pa.field('f64', pa.float64()), pa.field('str', pa.dictionary(pa.int32(), pa.string())), pa.field('const', pa.int64()), + pa.field('struct', pa.struct({'a': pa.int64(), + 'b': pa.string()})), pa.field('group', pa.int32()), pa.field('key', pa.string()), ]), check_metadata=False) @@ -612,23 +624,26 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer): pa.array([0, 1, 2, 3, 4], type=pa.int32()), pa.array("0 1 2 3 4".split(), type=pa.string()) ) + expected_struct = pa.array([{'a': i % 3, 'b': str(i % 3)} + for i in range(5)]) for task, group, key in zip(scanner.scan(), [1, 2], ['xxx', 'yyy']): expected_group = pa.array([group] * 5, type=pa.int32()) expected_key = pa.array([key] * 5, type=pa.string()) expected_const = pa.array([group - 1] * 5, type=pa.int64()) for batch in task.execute(): - assert batch.num_columns == 6 + assert batch.num_columns == 7 assert batch[0].equals(expected_i64) assert batch[1].equals(expected_f64) assert batch[2].equals(expected_str) assert batch[3].equals(expected_const) - assert batch[4].equals(expected_group) - assert batch[5].equals(expected_key) + assert batch[4].equals(expected_struct) + assert batch[5].equals(expected_group) + assert batch[6].equals(expected_key) table = dataset.to_table() assert isinstance(table, pa.Table) assert len(table) == 10 - assert table.num_columns == 6 + assert table.num_columns == 7 def test_make_fragment(multisourcefs): @@ -1231,6 +1246,7 @@ def test_partitioning_factory(mockfs): ("f64", pa.float64()), ("str", pa.string()), ("const", pa.int64()), + ("struct", pa.struct({'a': pa.int64(), 'b': pa.string()})), ("group", pa.int32()), ("key", pa.string()), ]) @@ -1572,7 +1588,7 @@ def test_construct_from_mixed_child_datasets(mockfs): table = dataset.to_table() assert len(table) == 20 - assert table.num_columns == 4 + assert table.num_columns == 5 assert len(dataset.children) == 2 for child in dataset.children: