Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions python/pyarrow/_compute.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ from pyarrow.includes.libarrow cimport *

cdef class FunctionOptions(_Weakrefable):
cdef:
unique_ptr[CFunctionOptions] wrapped
shared_ptr[CFunctionOptions] wrapped

cdef const CFunctionOptions* get_options(self) except NULL
cdef void init(self, unique_ptr[CFunctionOptions] options)
cdef void init(self, const shared_ptr[CFunctionOptions]& sp)

cdef inline shared_ptr[CFunctionOptions] unwrap(self)


cdef CExpression _bind(Expression filter, Schema schema) except *
Expand All @@ -47,7 +49,3 @@ cdef class Expression(_Weakrefable):

@staticmethod
cdef Expression _expr_or_scalar(object expr)

@staticmethod
cdef Expression _call(str function_name, list arguments,
shared_ptr[CFunctionOptions] options=*)
62 changes: 30 additions & 32 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,11 @@ cdef class FunctionOptions(_Weakrefable):
cdef const CFunctionOptions* get_options(self) except NULL:
return self.wrapped.get()

cdef void init(self, unique_ptr[CFunctionOptions] options):
self.wrapped = move(options)
cdef void init(self, const shared_ptr[CFunctionOptions]& sp):
self.wrapped = sp

cdef inline shared_ptr[CFunctionOptions] unwrap(self):
return self.wrapped

def serialize(self):
cdef:
Expand All @@ -560,15 +563,15 @@ cdef class FunctionOptions(_Weakrefable):
shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf)
CResult[unique_ptr[CFunctionOptions]] maybe_options = \
DeserializeFunctionOptions(deref(c_buf))
unique_ptr[CFunctionOptions] c_options
c_options = move(GetResultValue(move(maybe_options)))
shared_ptr[CFunctionOptions] c_options
c_options = to_shared(GetResultValue(move(maybe_options)))
type_name = frombytes(c_options.get().options_type().type_name())
module = globals()
if type_name not in module:
raise ValueError(f'Cannot deserialize "{type_name}"')
klass = module[type_name]
options = klass.__new__(klass)
(<FunctionOptions> options).init(move(c_options))
(<FunctionOptions> options).init(c_options)
return options

def __repr__(self):
Expand Down Expand Up @@ -597,15 +600,17 @@ def _raise_invalid_function_option(value, description, *,
cdef class _CastOptions(FunctionOptions):
cdef CCastOptions* options

cdef void init(self, unique_ptr[CFunctionOptions] options):
FunctionOptions.init(self, move(options))
cdef void init(self, const shared_ptr[CFunctionOptions]& sp):
FunctionOptions.init(self, sp)
self.options = <CCastOptions*> self.wrapped.get()

def _set_options(self, DataType target_type, allow_int_overflow,
allow_time_truncate, allow_time_overflow,
allow_decimal_truncate, allow_float_truncate,
allow_invalid_utf8):
self.init(unique_ptr[CFunctionOptions](new CCastOptions()))
cdef:
shared_ptr[CCastOptions] wrapped = make_shared[CCastOptions]()
self.init(<shared_ptr[CFunctionOptions]> wrapped)
self._set_type(target_type)
if allow_int_overflow is not None:
self.allow_int_overflow = allow_int_overflow
Expand All @@ -626,11 +631,11 @@ cdef class _CastOptions(FunctionOptions):
(<DataType> ensure_type(target_type)).sp_type

def _set_safe(self):
self.init(unique_ptr[CFunctionOptions](
self.init(shared_ptr[CFunctionOptions](
new CCastOptions(CCastOptions.Safe())))

def _set_unsafe(self):
self.init(unique_ptr[CFunctionOptions](
self.init(shared_ptr[CFunctionOptions](
new CCastOptions(CCastOptions.Unsafe())))

def is_safe(self):
Expand Down Expand Up @@ -2020,17 +2025,21 @@ cdef class Expression(_Weakrefable):
return (<Expression> Expression._scalar(expr))

@staticmethod
cdef Expression _call(str function_name, list arguments,
shared_ptr[CFunctionOptions] options=(
<shared_ptr[CFunctionOptions]> nullptr)):
def _call(str function_name, list arguments, FunctionOptions options=None):
cdef:
vector[CExpression] c_arguments
shared_ptr[CFunctionOptions] c_options

for argument in arguments:
if not isinstance(argument, Expression):
raise TypeError("only other expressions allowed as arguments")
c_arguments.push_back((<Expression> argument).expr)

return Expression.wrap(CMakeCallExpression(tobytes(function_name),
move(c_arguments), options))
if options is not None:
c_options = options.unwrap()

return Expression.wrap(CMakeCallExpression(
tobytes(function_name), move(c_arguments), c_options))

def __richcmp__(self, other, int op):
other = Expression._expr_or_scalar(other)
Expand Down Expand Up @@ -2083,32 +2092,21 @@ cdef class Expression(_Weakrefable):

def is_null(self, bint nan_is_null=False):
"""Checks whether the expression is null"""
cdef:
shared_ptr[CFunctionOptions] c_options

c_options.reset(new CNullOptions(nan_is_null))
return Expression._call("is_null", [self], c_options)
options = NullOptions(nan_is_null=nan_is_null)
return Expression._call("is_null", [self], options)

def cast(self, type, bint safe=True):
"""Explicitly change the expression's data type"""
cdef shared_ptr[CCastOptions] c_options
c_options.reset(new CCastOptions(safe))
c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type))
return Expression._call("cast", [self],
<shared_ptr[CFunctionOptions]> c_options)
options = CastOptions.safe(ensure_type(type))
return Expression._call("cast", [self], options)

def isin(self, values):
"""Checks whether the expression is contained in values"""
cdef:
shared_ptr[CFunctionOptions] c_options
CDatum c_values

if not isinstance(values, Array):
values = lib.array(values)

c_values = CDatum(pyarrow_unwrap_array(values))
c_options.reset(new CSetLookupOptions(c_values, True))
return Expression._call("is_in", [self], c_options)
options = SetLookupOptions(values)
return Expression._call("is_in", [self], options)

@staticmethod
def _field(str name not None):
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def wrapper(*args, memory_pool=None):
f"{func_name} takes {arity} positional argument(s), "
f"but {len(args)} were given"
)
if args and isinstance(args[0], Expression):
return Expression._call(func_name, list(args))
return func.call(args, None, memory_pool)
else:
def wrapper(*args, memory_pool=None, options=None, **kwargs):
Expand All @@ -242,6 +244,8 @@ def wrapper(*args, memory_pool=None, options=None, **kwargs):
option_args = ()
options = _handle_options(func_name, options_class, options,
option_args, kwargs)
if args and isinstance(args[0], Expression):
return Expression._call(func_name, list(args), options)
return func.call(args, options, memory_pool)
return wrapper

Expand Down
20 changes: 20 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2628,3 +2628,23 @@ def test_expression_boolean_operators():

with pytest.raises(ValueError, match="cannot be evaluated to python True"):
not true


def test_expression_call_function():
field = pc.field("field")

# no options
assert str(pc.hour(field)) == "hour(field)"

# default options
assert str(pc.round(field)) == "round(field)"
# specified options
assert str(pc.round(field, ndigits=1)) == \
"round(field, {ndigits=1, round_mode=HALF_TO_EVEN})"

# mixed types are not (yet) allowed
with pytest.raises(TypeError):
pc.add(field, 1)

with pytest.raises(TypeError):
pc.add(1, field)
22 changes: 22 additions & 0 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import contextlib
import os
import posixpath
import datetime
import pathlib
import pickle
import textwrap
Expand All @@ -29,6 +30,7 @@
import pytest

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.csv
import pyarrow.feather
import pyarrow.fs as fs
Expand Down Expand Up @@ -2512,6 +2514,26 @@ def test_filter_equal_null(tempdir, dataset_reader):
assert table.num_rows == 0


def test_filter_compute_expression(tempdir, dataset_reader):
table = pa.table({
"A": ["a", "b", None, "a", "c"],
"B": [datetime.datetime(2022, 1, 1, i) for i in range(5)],
"C": [datetime.datetime(2022, 1, i) for i in range(1, 6)],
})
_, path = _create_single_file(tempdir, table)
dataset = ds.dataset(str(path))

filter_ = pc.is_in(ds.field('A'), pa.array(["a", "b"]))
assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 3

filter_ = pc.hour(ds.field('B')) >= 3
assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 2

days = pc.days_between(ds.field('B'), ds.field("C"))
result = dataset_reader.to_table(dataset, columns={"days": days})
assert result["days"].to_pylist() == [0, 1, 2, 3, 4]


def test_dataset_union(multisourcefs):
child = ds.FileSystemDatasetFactory(
multisourcefs, fs.FileSelector('/plain'),
Expand Down