diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 4249179e1bf..1ef5c6e7b98 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -993,6 +993,21 @@ Result> Serialize(const Expression& expr) { return std::to_string(ret); } + Status VisitFieldRef(const FieldRef& ref) { + if (ref.nested_refs()) { + metadata_->Append("nested_field_ref", std::to_string(ref.nested_refs()->size())); + for (const auto& child : *ref.nested_refs()) { + RETURN_NOT_OK(VisitFieldRef(child)); + } + return Status::OK(); + } + if (!ref.name()) { + return Status::NotImplemented("Serialization of non-name field_refs"); + } + metadata_->Append("field_ref", *ref.name()); + return Status::OK(); + } + Status Visit(const Expression& expr) { if (auto lit = expr.literal()) { if (!lit->is_scalar()) { @@ -1004,11 +1019,7 @@ Result> Serialize(const Expression& expr) { } if (auto ref = expr.field_ref()) { - if (!ref->name()) { - return Status::NotImplemented("Serialization of non-name field_refs"); - } - metadata_->Append("field_ref", *ref->name()); - return Status::OK(); + return VisitFieldRef(*ref); } auto call = CallNotNull(expr); @@ -1067,10 +1078,13 @@ Result Deserialize(std::shared_ptr buffer) { const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); } + bool ParseInteger(const std::string& s, int32_t* value) { + return ::arrow::internal::ParseValue(s.data(), s.length(), value); + } + Result> GetScalar(const std::string& i) { int32_t column_index; - if (!::arrow::internal::ParseValue(i.data(), i.length(), - &column_index)) { + if (!ParseInteger(i, &column_index)) { return Status::Invalid("Couldn't parse column_index"); } if (column_index >= batch_.num_columns()) { @@ -1093,6 +1107,26 @@ Result Deserialize(std::shared_ptr buffer) { return literal(std::move(scalar)); } + if (key == "nested_field_ref") { + int32_t size; + if (!ParseInteger(value, &size)) { + return Status::Invalid("Couldn't parse nested field ref length"); + } + if (size <= 0) { + return Status::Invalid("nested field ref length must be > 0"); + } + std::vector nested; + nested.reserve(size); + while (size-- > 0) { + ARROW_ASSIGN_OR_RAISE(auto ref, GetOne()); + if (!ref.field_ref()) { + return Status::Invalid("invalid nested field ref"); + } + nested.push_back(*ref.field_ref()); + } + return field_ref(FieldRef(std::move(nested))); + } + if (key == "field_ref") { return field_ref(value); } diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 30ddef69010..f916bc2a1cf 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -1376,12 +1376,18 @@ TEST(Expression, SerializationRoundTrips) { ExpectRoundTrips(field_ref("field")); + ExpectRoundTrips(field_ref(FieldRef("foo", "bar", "baz"))); + ExpectRoundTrips(greater(field_ref("a"), literal(0.25))); ExpectRoundTrips( or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")), equal(field_ref("b"), literal("foo bar"))})); + ExpectRoundTrips(or_({equal(field_ref(FieldRef("a", "b")), literal(1)), + not_equal(field_ref("b"), literal("hello")), + equal(field_ref(FieldRef("c", "d")), literal("foo bar"))})); + ExpectRoundTrips(not_(field_ref("alpha"))); ExpectRoundTrips(call("is_in", {literal(1)}, diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 83dc6fa5695..440b95ce59c 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1615,6 +1615,11 @@ class ARROW_EXPORT FieldRef { /// Equivalent to a single index string of indices. FieldRef(int index) : impl_(FieldPath({index})) {} // NOLINT runtime/explicit + /// Construct a nested FieldRef. + FieldRef(std::vector refs) { // NOLINT runtime/explicit + Flatten(std::move(refs)); + } + /// Convenience constructor for nested FieldRefs: each argument will be used to /// construct a FieldRef template diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 2f18ab99866..73c188edbbf 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2219,10 +2219,26 @@ cdef class Expression(_Weakrefable): @staticmethod def _field(name_or_idx not None): - if isinstance(name_or_idx, str): - return Expression.wrap(CMakeFieldExpression(tobytes(name_or_idx))) - else: + cdef: + CFieldRef c_field + + if isinstance(name_or_idx, int): return Expression.wrap(CMakeFieldExpressionByIndex(name_or_idx)) + else: + c_field = CFieldRef( tobytes(name_or_idx)) + return Expression.wrap(CMakeFieldExpression(c_field)) + + @staticmethod + def _nested_field(tuple names not None): + cdef: + vector[CFieldRef] nested + + if len(names) == 0: + raise ValueError("nested field reference should be non-empty") + nested.reserve(len(names)) + for name in names: + nested.push_back(CFieldRef( tobytes(name))) + return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested)))) @staticmethod def _scalar(value): diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 6cd65123e88..40751eab26a 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -591,22 +591,52 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): return call_function("select_k_unstable", [values], options, memory_pool) -def field(name_or_index): +def field(*name_or_index): """Reference a column of the dataset. Stores only the field's name. Type and other information is known only when the expression is bound to a dataset having an explicit scheme. + Nested references are allowed by passing multiple names or a tuple of + names. For example ``('foo', 'bar')`` references the field named "bar" + inside the field named "foo". + Parameters ---------- - name_or_index : string or int - The name or index of the field the expression references to. + *name_or_index : string, multiple strings, tuple or int + The name or index of the (possibly nested) field the expression + references to. Returns ------- field_expr : Expression + + Examples + -------- + >>> import pyarrow.compute as pc + >>> pc.field("a") + + >>> pc.field(1) + + >>> pc.field(("a", "b")) + >> pc.field("a", "b") + 5, pc.field('i64') == 5, - pc.field('i64') == 7, pc.field('i64').is_null()] + pc.field('i64') == 7, pc.field('i64').is_null(), + pc.field(('foo', 'bar')) == 'value', + pc.field('foo', 'bar') == 'value'] for expr in all_exprs: assert isinstance(expr, pc.Expression) restored = pickle.loads(pickle.dumps(expr)) @@ -2666,6 +2668,8 @@ def test_expression_construction(): false = pc.scalar(False) string = pc.scalar("string") field = pc.field("field") + nested_field = pc.field(("nested", "field")) + nested_field2 = pc.field("nested", "field") zero | one == string ~true == false @@ -2673,6 +2677,8 @@ def test_expression_construction(): field.cast(typ) == true field.isin([1, 2]) + nested_field.isin(["foo", "bar"]) + nested_field2.isin(["foo", "bar"]) with pytest.raises(TypeError): field.isin(1) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index b4564abef5e..388a6f867e4 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -102,13 +102,15 @@ def mockfs(): list(range(5)), list(map(float, range(5))), list(map(str, range(5))), - [i] * 5 + [i] * 5, + [{'a': j % 3, 'b': str(j % 3)} for j in range(5)], ] schema = pa.schema([ - pa.field('i64', pa.int64()), - pa.field('f64', pa.float64()), - pa.field('str', pa.string()), - pa.field('const', pa.int64()), + ('i64', pa.int64()), + ('f64', pa.float64()), + ('str', pa.string()), + ('const', pa.int64()), + ('struct', pa.struct({'a': pa.int64(), 'b': pa.string()})), ]) batch = pa.record_batch(data, schema=schema) table = pa.Table.from_batches([batch]) @@ -383,14 +385,41 @@ def test_dataset(dataset, dataset_reader): assert len(table) == 10 condition = ds.field('i64') == 1 - result = dataset.to_table(use_threads=True, filter=condition).to_pydict() + result = dataset.to_table(use_threads=True, filter=condition) + # Don't rely on the scanning order + result = result.sort_by('group').to_pydict() - # don't rely on the scanning order assert result['i64'] == [1, 1] assert result['f64'] == [1., 1.] assert sorted(result['group']) == [1, 2] assert sorted(result['key']) == ['xxx', 'yyy'] + # Filtering on a nested field ref + condition = ds.field(('struct', 'b')) == '1' + result = dataset.to_table(use_threads=True, filter=condition) + result = result.sort_by('group').to_pydict() + + assert result['i64'] == [1, 4, 1, 4] + assert result['f64'] == [1.0, 4.0, 1.0, 4.0] + assert result['group'] == [1, 1, 2, 2] + assert result['key'] == ['xxx', 'xxx', 'yyy', 'yyy'] + + # Projecting on a nested field ref expression + projection = { + 'i64': ds.field('i64'), + 'f64': ds.field('f64'), + 'new': ds.field(('struct', 'b')) == '1', + } + result = dataset.to_table(use_threads=True, columns=projection) + result = result.sort_by('i64').to_pydict() + + assert list(result) == ['i64', 'f64', 'new'] + assert result['i64'] == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] + assert result['f64'] == [0.0, 0.0, 1.0, 1.0, + 2.0, 2.0, 3.0, 3.0, 4.0, 4.0] + assert result['new'] == [False, False, True, True, False, False, + False, False, True, True] + @pytest.mark.parquet def test_scanner(dataset, dataset_reader): @@ -808,6 +837,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer): pa.field('f64', pa.float64()), pa.field('str', pa.dictionary(pa.int32(), pa.string())), pa.field('const', pa.int64()), + pa.field('struct', pa.struct({'a': pa.int64(), + 'b': pa.string()})), pa.field('group', pa.int32()), pa.field('key', pa.string()), ]), check_metadata=False) @@ -827,6 +858,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer): pa.array([0, 1, 2, 3, 4], type=pa.int32()), pa.array("0 1 2 3 4".split(), type=pa.string()) ) + expected_struct = pa.array([{'a': i % 3, 'b': str(i % 3)} + for i in range(5)]) iterator = scanner.scan_batches() for (batch, fragment), group, key in zip(iterator, [1, 2], ['xxx', 'yyy']): expected_group = pa.array([group] * 5, type=pa.int32()) @@ -834,18 +867,19 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer): expected_const = pa.array([group - 1] * 5, type=pa.int64()) # Can't compare or really introspect expressions from Python assert fragment.partition_expression is not None - assert batch.num_columns == 6 + assert batch.num_columns == 7 assert batch[0].equals(expected_i64) assert batch[1].equals(expected_f64) assert batch[2].equals(expected_str) assert batch[3].equals(expected_const) - assert batch[4].equals(expected_group) - assert batch[5].equals(expected_key) + assert batch[4].equals(expected_struct) + assert batch[5].equals(expected_group) + assert batch[6].equals(expected_key) table = dataset.to_table() assert isinstance(table, pa.Table) assert len(table) == 10 - assert table.num_columns == 6 + assert table.num_columns == 7 @pytest.mark.parquet @@ -1480,6 +1514,7 @@ def test_partitioning_factory(mockfs): ("f64", pa.float64()), ("str", pa.string()), ("const", pa.int64()), + ("struct", pa.struct({'a': pa.int64(), 'b': pa.string()})), ("group", pa.int32()), ("key", pa.string()), ]) @@ -2047,7 +2082,7 @@ def test_construct_from_mixed_child_datasets(mockfs): table = dataset.to_table() assert len(table) == 20 - assert table.num_columns == 4 + assert table.num_columns == 5 assert len(dataset.children) == 2 for child in dataset.children: