Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions cpp/src/arrow/dataset/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,21 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
return std::to_string(ret);
}

Status VisitFieldRef(const FieldRef& ref) {
if (ref.nested()) {
metadata_->Append("nested_field_ref", std::to_string(ref.nested()->size()));
for (const auto& child : *ref.nested()) {
RETURN_NOT_OK(VisitFieldRef(child));
}
return Status::OK();
}
if (!ref.name()) {
return Status::NotImplemented("Serialization of non-name field_refs");
}
metadata_->Append("field_ref", *ref.name());
return Status::OK();
}

Status Visit(const Expression& expr) {
if (auto lit = expr.literal()) {
if (!lit->is_scalar()) {
Expand All @@ -1092,11 +1107,7 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
}

if (auto ref = expr.field_ref()) {
if (!ref->name()) {
return Status::NotImplemented("Serialization of non-name field_refs");
}
metadata_->Append("field_ref", *ref->name());
return Status::OK();
return VisitFieldRef(*ref);
}

auto call = CallNotNull(expr);
Expand Down Expand Up @@ -1154,9 +1165,13 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {

const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); }

bool ParseInteger(const std::string& s, int32_t* value) {
return internal::ParseValue<Int32Type>(s.data(), s.length(), value);
}

Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
int32_t column_index;
if (!internal::ParseValue<Int32Type>(i.data(), i.length(), &column_index)) {
if (!ParseInteger(i, &column_index)) {
return Status::Invalid("Couldn't parse column_index");
}
if (column_index >= batch_.num_columns()) {
Expand All @@ -1179,6 +1194,26 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
return literal(std::move(scalar));
}

if (key == "nested_field_ref") {
int32_t size;
if (!ParseInteger(value, &size)) {
return Status::Invalid("Couldn't parse nested field ref length");
}
if (size <= 0) {
return Status::Invalid("nested field ref length must be > 0");
}
std::vector<FieldRef> nested;
nested.reserve(size);
while (size-- > 0) {
ARROW_ASSIGN_OR_RAISE(auto ref, GetOne());
if (!ref.field_ref()) {
return Status::Invalid("invalid nested field ref");
}
nested.push_back(*ref.field_ref());
}
return field_ref(FieldRef(std::move(nested)));
}

if (key == "field_ref") {
return field_ref(value);
}
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/dataset/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1180,12 +1180,18 @@ TEST(Expression, SerializationRoundTrips) {

ExpectRoundTrips(field_ref("field"));

ExpectRoundTrips(field_ref(FieldRef("foo", "bar", "baz")));

ExpectRoundTrips(greater(field_ref("a"), literal(0.25)));

ExpectRoundTrips(
or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")),
equal(field_ref("b"), literal("foo bar"))}));

ExpectRoundTrips(or_({equal(field_ref(FieldRef("a", "b")), literal(1)),
not_equal(field_ref("b"), literal("hello")),
equal(field_ref(FieldRef("c", "d")), literal("foo bar"))}));

ExpectRoundTrips(not_(field_ref("alpha")));

ExpectRoundTrips(call("is_in", {literal(1)},
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,11 @@ class ARROW_EXPORT FieldRef {
/// Equivalent to a single index string of indices.
FieldRef(int index) : impl_(FieldPath({index})) {} // NOLINT runtime/explicit

/// Construct a nested FieldRef.
FieldRef(std::vector<FieldRef> refs) { // NOLINT runtime/explicit
Flatten(std::move(refs));
}

/// Convenience constructor for nested FieldRefs: each argument will be used to
/// construct a FieldRef
template <typename A0, typename A1, typename... A>
Expand Down Expand Up @@ -1560,6 +1565,11 @@ class ARROW_EXPORT FieldRef {
const std::string* name() const {
return IsName() ? &util::get<std::string>(impl_) : NULLPTR;
}
const std::vector<FieldRef>* nested() const {
return util::holds_alternative<std::vector<FieldRef>>(impl_)
? &util::get<std::vector<FieldRef>>(impl_)
: NULLPTR;
}

/// \brief Retrieve FieldPath of every child field which matches this FieldRef.
std::vector<FieldPath> FindAll(const Schema& schema) const;
Expand Down
18 changes: 17 additions & 1 deletion python/pyarrow/_dataset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,23 @@ cdef class Expression(_Weakrefable):

@staticmethod
def _field(str name not None):
return Expression.wrap(CMakeFieldExpression(tobytes(name)))
cdef:
CFieldRef c_field

c_field = CFieldRef(<c_string> tobytes(name))
return Expression.wrap(CMakeFieldExpression(c_field))

@staticmethod
def _nested_field(tuple names not None):
cdef:
vector[CFieldRef] nested

if len(names) == 0:
raise ValueError("nested field reference should be non-empty")
nested.reserve(len(names))
for name in names:
nested.push_back(CFieldRef(<c_string> tobytes(name)))
return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested))))

@staticmethod
def _scalar(value):
Expand Down
16 changes: 13 additions & 3 deletions python/pyarrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,26 @@ def field(name):
Stores only the field's name. Type and other information is known only when
the expression is bound to a dataset having an explicit scheme.

Nested references are allowed by passing a tuple of names.
For example ``('foo', 'bar')`` references the field named "bar" inside
the field named "foo".

Parameters
----------
name : string
The name of the field the expression references to.
name : string or tuple
The name of the (possibly nested) field the expression references to.

Returns
-------
field_expr : Expression
"""
return Expression._field(name)
if isinstance(name, str):
return Expression._field(name)
elif isinstance(name, tuple):
return Expression._nested_field(name)
else:
raise TypeError(
f"field reference should be str or tuple, got {type(name)}")


def scalar(value):
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
CFieldRef()
CFieldRef(c_string name)
CFieldRef(int index)
CFieldRef(vector[CFieldRef])
const c_string* name() const

cdef cppclass CFieldRefHash" arrow::FieldRef::Hash":
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/includes/libarrow_dataset.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
"arrow::dataset::literal"(shared_ptr[CScalar] value)

cdef CExpression CMakeFieldExpression \
"arrow::dataset::field_ref"(c_string name)
"arrow::dataset::field_ref"(CFieldRef)

cdef CExpression CMakeCallExpression \
"arrow::dataset::call"(c_string function,
Expand Down
38 changes: 27 additions & 11 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ def mockfs():
list(range(5)),
list(map(float, range(5))),
list(map(str, range(5))),
[i] * 5
[i] * 5,
[{'a': j % 3, 'b': str(j % 3)} for j in range(5)],
]
schema = pa.schema([
pa.field('i64', pa.int64()),
pa.field('f64', pa.float64()),
pa.field('str', pa.string()),
pa.field('const', pa.int64()),
('i64', pa.int64()),
('f64', pa.float64()),
('str', pa.string()),
('const', pa.int64()),
('struct', pa.struct({'a': pa.int64(), 'b': pa.string()})),
])
batch = pa.record_batch(data, schema=schema)
table = pa.Table.from_batches([batch])
Expand Down Expand Up @@ -328,6 +330,11 @@ def test_dataset(dataset):
assert sorted(result['group']) == [1, 2]
assert sorted(result['key']) == ['xxx', 'yyy']

condition = ds.field(('struct', 'b')) == '1'
with pytest.raises(NotImplementedError,
match="Nested field references in scans"):
dataset.to_table(use_threads=True, filter=condition)


def test_scanner(dataset):
scanner = ds.Scanner.from_dataset(dataset,
Expand Down Expand Up @@ -425,7 +432,8 @@ def test_expression_serialization():
d.is_valid(), a.cast(pa.int32(), safe=False),
a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
ds.field('i64') > 5, ds.field('i64') == 5,
ds.field('i64') == 7, ds.field('i64').is_null()]
ds.field('i64') == 7, ds.field('i64').is_null(),
ds.field(('foo', 'bar')) == 'value']
for expr in all_exprs:
assert isinstance(expr, ds.Expression)
restored = pickle.loads(pickle.dumps(expr))
Expand All @@ -439,13 +447,15 @@ def test_expression_construction():
false = ds.scalar(False)
string = ds.scalar("string")
field = ds.field("field")
nested_field = ds.field(("nested", "field"))

zero | one == string
~true == false
for typ in ("bool", pa.bool_()):
field.cast(typ) == true

field.isin([1, 2])
nested_field.isin(["foo", "bar"])

with pytest.raises(TypeError):
field.isin(1)
Expand Down Expand Up @@ -592,6 +602,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
pa.field('f64', pa.float64()),
pa.field('str', pa.dictionary(pa.int32(), pa.string())),
pa.field('const', pa.int64()),
pa.field('struct', pa.struct({'a': pa.int64(),
'b': pa.string()})),
pa.field('group', pa.int32()),
pa.field('key', pa.string()),
]), check_metadata=False)
Expand All @@ -612,23 +624,26 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
pa.array([0, 1, 2, 3, 4], type=pa.int32()),
pa.array("0 1 2 3 4".split(), type=pa.string())
)
expected_struct = pa.array([{'a': i % 3, 'b': str(i % 3)}
for i in range(5)])
for task, group, key in zip(scanner.scan(), [1, 2], ['xxx', 'yyy']):
expected_group = pa.array([group] * 5, type=pa.int32())
expected_key = pa.array([key] * 5, type=pa.string())
expected_const = pa.array([group - 1] * 5, type=pa.int64())
for batch in task.execute():
assert batch.num_columns == 6
assert batch.num_columns == 7
assert batch[0].equals(expected_i64)
assert batch[1].equals(expected_f64)
assert batch[2].equals(expected_str)
assert batch[3].equals(expected_const)
assert batch[4].equals(expected_group)
assert batch[5].equals(expected_key)
assert batch[4].equals(expected_struct)
assert batch[5].equals(expected_group)
assert batch[6].equals(expected_key)

table = dataset.to_table()
assert isinstance(table, pa.Table)
assert len(table) == 10
assert table.num_columns == 6
assert table.num_columns == 7


def test_make_fragment(multisourcefs):
Expand Down Expand Up @@ -1231,6 +1246,7 @@ def test_partitioning_factory(mockfs):
("f64", pa.float64()),
("str", pa.string()),
("const", pa.int64()),
("struct", pa.struct({'a': pa.int64(), 'b': pa.string()})),
("group", pa.int32()),
("key", pa.string()),
])
Expand Down Expand Up @@ -1572,7 +1588,7 @@ def test_construct_from_mixed_child_datasets(mockfs):

table = dataset.to_table()
assert len(table) == 20
assert table.num_columns == 4
assert table.num_columns == 5

assert len(dataset.children) == 2
for child in dataset.children:
Expand Down