Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions cpp/src/arrow/compute/exec/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,21 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
return std::to_string(ret);
}

// Serializes a field reference into the metadata key/value list.
//
// A nested reference is written as a "nested_field_ref" entry whose value is
// the number of children, followed by the entries produced by visiting each
// child in order. Any other reference must be a name and is written as a
// "field_ref" entry; a reference without a name yields NotImplemented.
Status VisitFieldRef(const FieldRef& ref) {
  if (const auto children = ref.nested_refs()) {
    // Record the child count first so the reader knows how many entries follow.
    metadata_->Append("nested_field_ref", std::to_string(children->size()));
    for (const auto& nested_ref : *children) {
      RETURN_NOT_OK(VisitFieldRef(nested_ref));
    }
    return Status::OK();
  }

  const auto field_name = ref.name();
  if (field_name) {
    metadata_->Append("field_ref", *field_name);
    return Status::OK();
  }
  return Status::NotImplemented("Serialization of non-name field_refs");
}

Status Visit(const Expression& expr) {
if (auto lit = expr.literal()) {
if (!lit->is_scalar()) {
Expand All @@ -1004,11 +1019,7 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
}

if (auto ref = expr.field_ref()) {
if (!ref->name()) {
return Status::NotImplemented("Serialization of non-name field_refs");
}
metadata_->Append("field_ref", *ref->name());
return Status::OK();
return VisitFieldRef(*ref);
}

auto call = CallNotNull(expr);
Expand Down Expand Up @@ -1067,10 +1078,13 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {

const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); }

// Parses the whole of `s` as a 32-bit integer into `*value`.
// Returns the result of ParseValue<Int32Type>, i.e. false when parsing fails.
bool ParseInteger(const std::string& s, int32_t* value) {
  const char* text = s.data();
  const auto text_length = s.length();
  return ::arrow::internal::ParseValue<Int32Type>(text, text_length, value);
}

Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
int32_t column_index;
if (!::arrow::internal::ParseValue<Int32Type>(i.data(), i.length(),
&column_index)) {
if (!ParseInteger(i, &column_index)) {
return Status::Invalid("Couldn't parse column_index");
}
if (column_index >= batch_.num_columns()) {
Expand All @@ -1093,6 +1107,26 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
return literal(std::move(scalar));
}

// A "nested_field_ref" entry carries the number of child references as its
// value; the children themselves are obtained from subsequent GetOne() calls.
if (key == "nested_field_ref") {
int32_t size;
if (!ParseInteger(value, &size)) {
return Status::Invalid("Couldn't parse nested field ref length");
}
// Zero or negative child counts are rejected.
if (size <= 0) {
return Status::Invalid("nested field ref length must be > 0");
}
std::vector<FieldRef> nested;
nested.reserve(size);
// Read exactly `size` children; a failure from GetOne() is propagated by
// ARROW_ASSIGN_OR_RAISE, and each child must itself be a field reference.
while (size-- > 0) {
ARROW_ASSIGN_OR_RAISE(auto ref, GetOne());
if (!ref.field_ref()) {
return Status::Invalid("invalid nested field ref");
}
nested.push_back(*ref.field_ref());
}
// Wrap the collected children in a single nested FieldRef expression.
return field_ref(FieldRef(std::move(nested)));
}

if (key == "field_ref") {
return field_ref(value);
}
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/exec/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1376,12 +1376,18 @@ TEST(Expression, SerializationRoundTrips) {

ExpectRoundTrips(field_ref("field"));

ExpectRoundTrips(field_ref(FieldRef("foo", "bar", "baz")));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to spell out field_ref(FieldRef("foo", "bar")) explicitly or is it possible to simply write field_ref("foo", "bar")?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will try to simplify.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently it doesn't work, FieldRef("foo", "bar") needs to be spelled out.

Try 1: `field_ref("foo", "bar")`
```
/Users/alenkafrim/repos/arrow/cpp/src/arrow/compute/exec/expression_test.cc:1379:20: error: no matching function for call to 'field_ref'
  ExpectRoundTrips(field_ref("foo", "bar", "baz"));
                   ^~~~~~~~~
/Users/alenkafrim/repos/arrow/cpp/src/arrow/compute/exec/expression.h:152:12: note: candidate function not viable: requires single argument 'ref', but 3 arguments were provided
Expression field_ref(FieldRef ref);
```
Try 2: `field_ref(("foo", "bar"))`
```
/Users/alenkafrim/repos/arrow/cpp/src/arrow/compute/exec/expression_test.cc:1379:31: warning: expression result unused [-Wunused-value]
  ExpectRoundTrips(field_ref(("foo", "bar", "baz")));
```


ExpectRoundTrips(greater(field_ref("a"), literal(0.25)));

ExpectRoundTrips(
or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")),
equal(field_ref("b"), literal("foo bar"))}));

ExpectRoundTrips(or_({equal(field_ref(FieldRef("a", "b")), literal(1)),
not_equal(field_ref("b"), literal("hello")),
equal(field_ref(FieldRef("c", "d")), literal("foo bar"))}));

ExpectRoundTrips(not_(field_ref("alpha")));

ExpectRoundTrips(call("is_in", {literal(1)},
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,11 @@ class ARROW_EXPORT FieldRef {
/// Equivalent to a single index string of indices.
FieldRef(int index) : impl_(FieldPath({index})) {} // NOLINT runtime/explicit

/// Construct a nested FieldRef from a sequence of child references.
///
/// The vector is taken by value and moved into Flatten().
/// NOTE(review): presumably Flatten() collapses nested children and stores the
/// result in this FieldRef -- confirm against its definition.
FieldRef(std::vector<FieldRef> refs) { // NOLINT runtime/explicit
Flatten(std::move(refs));
}

/// Convenience constructor for nested FieldRefs: each argument will be used to
/// construct a FieldRef
template <typename A0, typename A1, typename... A>
Expand Down
22 changes: 19 additions & 3 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2219,10 +2219,26 @@ cdef class Expression(_Weakrefable):

@staticmethod
def _field(name_or_idx not None):
if isinstance(name_or_idx, str):
return Expression.wrap(CMakeFieldExpression(tobytes(name_or_idx)))
else:
cdef:
CFieldRef c_field

if isinstance(name_or_idx, int):
return Expression.wrap(CMakeFieldExpressionByIndex(name_or_idx))
else:
c_field = CFieldRef(<c_string> tobytes(name_or_idx))
return Expression.wrap(CMakeFieldExpression(c_field))

@staticmethod
def _nested_field(tuple names not None):
    """
    Build an Expression referencing a nested field.

    Parameters
    ----------
    names : tuple
        Non-empty tuple of field names; each one is converted with
        ``tobytes`` and wrapped in its own field reference.

    Raises
    ------
    ValueError
        If ``names`` is empty.
    """
    cdef:
        vector[CFieldRef] c_refs

    n_names = len(names)
    if n_names == 0:
        raise ValueError("nested field reference should be non-empty")

    # Allocate once up front, then append one name-based reference per entry.
    c_refs.reserve(n_names)
    for field_name in names:
        c_refs.push_back(CFieldRef(<c_string> tobytes(field_name)))

    return Expression.wrap(CMakeFieldExpression(CFieldRef(move(c_refs))))

@staticmethod
def _scalar(value):
Expand Down
38 changes: 34 additions & 4 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,22 +591,52 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None):
return call_function("select_k_unstable", [values], options, memory_pool)


def field(name_or_index):
def field(*name_or_index):
"""Reference a column of the dataset.

Stores only the field's name. Type and other information is known only when
the expression is bound to a dataset having an explicit schema.

Nested references are allowed by passing multiple names or a tuple of
names. For example ``('foo', 'bar')`` references the field named "bar"
inside the field named "foo".

Parameters
----------
name_or_index : string or int
The name or index of the field the expression references to.
*name_or_index : string, multiple strings, tuple or int
The name or index of the (possibly nested) field the expression
refers to.

Returns
-------
field_expr : Expression

Examples
--------
>>> import pyarrow.compute as pc
>>> pc.field("a")
<pyarrow.compute.Expression a>
>>> pc.field(1)
<pyarrow.compute.Expression FieldPath(1)>
>>> pc.field(("a", "b"))
<pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
>>> pc.field("a", "b")
<pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
"""
return Expression._field(name_or_index)
n = len(name_or_index)
if n == 1:
if isinstance(name_or_index[0], (str, int)):
return Expression._field(name_or_index[0])
elif isinstance(name_or_index[0], tuple):
return Expression._nested_field(name_or_index[0])
else:
raise TypeError(
"field reference should be str, multiple str, tuple or "
f"integer, got {type(name_or_index[0])}"
)
# In case of multiple strings not supplied in a tuple
else:
return Expression._nested_field(name_or_index)


def scalar(value):
Expand Down
3 changes: 2 additions & 1 deletion python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
CFieldRef()
CFieldRef(c_string name)
CFieldRef(int index)
CFieldRef(vector[CFieldRef])
const c_string* name() const

cdef cppclass CFieldRefHash" arrow::FieldRef::Hash":
Expand Down Expand Up @@ -2411,7 +2412,7 @@ cdef extern from "arrow/compute/exec/expression.h" \
"arrow::compute::literal"(shared_ptr[CScalar] value)

cdef CExpression CMakeFieldExpression \
"arrow::compute::field_ref"(c_string name)
"arrow::compute::field_ref"(CFieldRef)

cdef CExpression CMakeFieldExpressionByIndex \
"arrow::compute::field_ref"(int idx)
Expand Down
8 changes: 7 additions & 1 deletion python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,7 +2652,9 @@ def test_expression_serialization():
d.is_valid(), a.cast(pa.int32(), safe=False),
a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
pc.field('i64') > 5, pc.field('i64') == 5,
pc.field('i64') == 7, pc.field('i64').is_null()]
pc.field('i64') == 7, pc.field('i64').is_null(),
pc.field(('foo', 'bar')) == 'value',
pc.field('foo', 'bar') == 'value']
for expr in all_exprs:
assert isinstance(expr, pc.Expression)
restored = pickle.loads(pickle.dumps(expr))
Expand All @@ -2666,13 +2668,17 @@ def test_expression_construction():
false = pc.scalar(False)
string = pc.scalar("string")
field = pc.field("field")
nested_field = pc.field(("nested", "field"))
nested_field2 = pc.field("nested", "field")

zero | one == string
~true == false
for typ in ("bool", pa.bool_()):
field.cast(typ) == true

field.isin([1, 2])
nested_field.isin(["foo", "bar"])
nested_field2.isin(["foo", "bar"])

with pytest.raises(TypeError):
field.isin(1)
Expand Down
59 changes: 47 additions & 12 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,15 @@ def mockfs():
list(range(5)),
list(map(float, range(5))),
list(map(str, range(5))),
[i] * 5
[i] * 5,
[{'a': j % 3, 'b': str(j % 3)} for j in range(5)],
]
schema = pa.schema([
pa.field('i64', pa.int64()),
pa.field('f64', pa.float64()),
pa.field('str', pa.string()),
pa.field('const', pa.int64()),
('i64', pa.int64()),
('f64', pa.float64()),
('str', pa.string()),
('const', pa.int64()),
('struct', pa.struct({'a': pa.int64(), 'b': pa.string()})),
])
batch = pa.record_batch(data, schema=schema)
table = pa.Table.from_batches([batch])
Expand Down Expand Up @@ -383,14 +385,41 @@ def test_dataset(dataset, dataset_reader):
assert len(table) == 10

condition = ds.field('i64') == 1
result = dataset.to_table(use_threads=True, filter=condition).to_pydict()
result = dataset.to_table(use_threads=True, filter=condition)
# Don't rely on the scanning order
result = result.sort_by('group').to_pydict()

# don't rely on the scanning order
assert result['i64'] == [1, 1]
assert result['f64'] == [1., 1.]
assert sorted(result['group']) == [1, 2]
assert sorted(result['key']) == ['xxx', 'yyy']

# Filtering on a nested field ref
condition = ds.field(('struct', 'b')) == '1'
result = dataset.to_table(use_threads=True, filter=condition)
result = result.sort_by('group').to_pydict()

assert result['i64'] == [1, 4, 1, 4]
assert result['f64'] == [1.0, 4.0, 1.0, 4.0]
assert result['group'] == [1, 1, 2, 2]
assert result['key'] == ['xxx', 'xxx', 'yyy', 'yyy']

# Projecting on a nested field ref expression
projection = {
'i64': ds.field('i64'),
'f64': ds.field('f64'),
'new': ds.field(('struct', 'b')) == '1',
}
result = dataset.to_table(use_threads=True, columns=projection)
result = result.sort_by('i64').to_pydict()

assert list(result) == ['i64', 'f64', 'new']
assert result['i64'] == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
assert result['f64'] == [0.0, 0.0, 1.0, 1.0,
2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
assert result['new'] == [False, False, True, True, False, False,
False, False, True, True]


@pytest.mark.parquet
def test_scanner(dataset, dataset_reader):
Expand Down Expand Up @@ -808,6 +837,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
pa.field('f64', pa.float64()),
pa.field('str', pa.dictionary(pa.int32(), pa.string())),
pa.field('const', pa.int64()),
pa.field('struct', pa.struct({'a': pa.int64(),
'b': pa.string()})),
pa.field('group', pa.int32()),
pa.field('key', pa.string()),
]), check_metadata=False)
Expand All @@ -827,25 +858,28 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
pa.array([0, 1, 2, 3, 4], type=pa.int32()),
pa.array("0 1 2 3 4".split(), type=pa.string())
)
expected_struct = pa.array([{'a': i % 3, 'b': str(i % 3)}
for i in range(5)])
iterator = scanner.scan_batches()
for (batch, fragment), group, key in zip(iterator, [1, 2], ['xxx', 'yyy']):
expected_group = pa.array([group] * 5, type=pa.int32())
expected_key = pa.array([key] * 5, type=pa.string())
expected_const = pa.array([group - 1] * 5, type=pa.int64())
# Can't compare or really introspect expressions from Python
assert fragment.partition_expression is not None
assert batch.num_columns == 6
assert batch.num_columns == 7
assert batch[0].equals(expected_i64)
assert batch[1].equals(expected_f64)
assert batch[2].equals(expected_str)
assert batch[3].equals(expected_const)
assert batch[4].equals(expected_group)
assert batch[5].equals(expected_key)
assert batch[4].equals(expected_struct)
assert batch[5].equals(expected_group)
assert batch[6].equals(expected_key)

table = dataset.to_table()
assert isinstance(table, pa.Table)
assert len(table) == 10
assert table.num_columns == 6
assert table.num_columns == 7


@pytest.mark.parquet
Expand Down Expand Up @@ -1480,6 +1514,7 @@ def test_partitioning_factory(mockfs):
("f64", pa.float64()),
("str", pa.string()),
("const", pa.int64()),
("struct", pa.struct({'a': pa.int64(), 'b': pa.string()})),
("group", pa.int32()),
("key", pa.string()),
])
Expand Down Expand Up @@ -2047,7 +2082,7 @@ def test_construct_from_mixed_child_datasets(mockfs):

table = dataset.to_table()
assert len(table) == 20
assert table.num_columns == 4
assert table.num_columns == 5

assert len(dataset.children) == 2
for child in dataset.children:
Expand Down