diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index d18c08e6050..c10867a75b7 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -24,10 +24,12 @@ from pyarrow.includes.libarrow cimport * cdef class FunctionOptions(_Weakrefable): cdef: - unique_ptr[CFunctionOptions] wrapped + shared_ptr[CFunctionOptions] wrapped cdef const CFunctionOptions* get_options(self) except NULL - cdef void init(self, unique_ptr[CFunctionOptions] options) + cdef void init(self, const shared_ptr[CFunctionOptions]& sp) + + cdef inline shared_ptr[CFunctionOptions] unwrap(self) cdef CExpression _bind(Expression filter, Schema schema) except * @@ -47,7 +49,3 @@ cdef class Expression(_Weakrefable): @staticmethod cdef Expression _expr_or_scalar(object expr) - - @staticmethod - cdef Expression _call(str function_name, list arguments, - shared_ptr[CFunctionOptions] options=*) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 28c0c87543a..1ec96d33a02 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -537,8 +537,11 @@ cdef class FunctionOptions(_Weakrefable): cdef const CFunctionOptions* get_options(self) except NULL: return self.wrapped.get() - cdef void init(self, unique_ptr[CFunctionOptions] options): - self.wrapped = move(options) + cdef void init(self, const shared_ptr[CFunctionOptions]& sp): + self.wrapped = sp + + cdef inline shared_ptr[CFunctionOptions] unwrap(self): + return self.wrapped def serialize(self): cdef: @@ -560,15 +563,15 @@ cdef class FunctionOptions(_Weakrefable): shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) CResult[unique_ptr[CFunctionOptions]] maybe_options = \ DeserializeFunctionOptions(deref(c_buf)) - unique_ptr[CFunctionOptions] c_options - c_options = move(GetResultValue(move(maybe_options))) + shared_ptr[CFunctionOptions] c_options + c_options = to_shared(GetResultValue(move(maybe_options))) type_name = frombytes(c_options.get().options_type().type_name()) module = globals() if type_name not in module: raise ValueError(f'Cannot deserialize "{type_name}"') klass = module[type_name] options = klass.__new__(klass) - ( options).init(move(c_options)) + ( options).init(c_options) return options def __repr__(self): @@ -597,15 +600,17 @@ def _raise_invalid_function_option(value, description, *, cdef class _CastOptions(FunctionOptions): cdef CCastOptions* options - cdef void init(self, unique_ptr[CFunctionOptions] options): - FunctionOptions.init(self, move(options)) + cdef void init(self, const shared_ptr[CFunctionOptions]& sp): + FunctionOptions.init(self, sp) self.options = self.wrapped.get() def _set_options(self, DataType target_type, allow_int_overflow, allow_time_truncate, allow_time_overflow, allow_decimal_truncate, allow_float_truncate, allow_invalid_utf8): - self.init(unique_ptr[CFunctionOptions](new CCastOptions())) + cdef: + shared_ptr[CCastOptions] wrapped = make_shared[CCastOptions]() + self.init( wrapped) self._set_type(target_type) if allow_int_overflow is not None: self.allow_int_overflow = allow_int_overflow @@ -626,11 +631,11 @@ cdef class _CastOptions(FunctionOptions): ( ensure_type(target_type)).sp_type def _set_safe(self): - self.init(unique_ptr[CFunctionOptions]( + self.init(shared_ptr[CFunctionOptions]( new CCastOptions(CCastOptions.Safe()))) def _set_unsafe(self): - self.init(unique_ptr[CFunctionOptions]( + self.init(shared_ptr[CFunctionOptions]( new CCastOptions(CCastOptions.Unsafe()))) def is_safe(self): @@ -2020,17 +2025,21 @@ cdef class Expression(_Weakrefable): return ( Expression._scalar(expr)) @staticmethod - cdef Expression _call(str function_name, list arguments, - shared_ptr[CFunctionOptions] options=( - nullptr)): + def _call(str function_name, list arguments, FunctionOptions options=None): cdef: vector[CExpression] c_arguments + shared_ptr[CFunctionOptions] c_options for argument in arguments: + if not isinstance(argument, Expression): + raise TypeError("only other expressions allowed as arguments") c_arguments.push_back(( argument).expr) - return Expression.wrap(CMakeCallExpression(tobytes(function_name), - move(c_arguments), options)) + if options is not None: + c_options = options.unwrap() + + return Expression.wrap(CMakeCallExpression( + tobytes(function_name), move(c_arguments), c_options)) def __richcmp__(self, other, int op): other = Expression._expr_or_scalar(other) @@ -2083,32 +2092,21 @@ cdef class Expression(_Weakrefable): def is_null(self, bint nan_is_null=False): """Checks whether the expression is null""" - cdef: - shared_ptr[CFunctionOptions] c_options - - c_options.reset(new CNullOptions(nan_is_null)) - return Expression._call("is_null", [self], c_options) + options = NullOptions(nan_is_null=nan_is_null) + return Expression._call("is_null", [self], options) def cast(self, type, bint safe=True): """Explicitly change the expression's data type""" - cdef shared_ptr[CCastOptions] c_options - c_options.reset(new CCastOptions(safe)) - c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type)) - return Expression._call("cast", [self], - c_options) + options = CastOptions.safe(ensure_type(type)) + return Expression._call("cast", [self], options) def isin(self, values): """Checks whether the expression is contained in values""" - cdef: - shared_ptr[CFunctionOptions] c_options - CDatum c_values - if not isinstance(values, Array): values = lib.array(values) - c_values = CDatum(pyarrow_unwrap_array(values)) - c_options.reset(new CSetLookupOptions(c_values, True)) - return Expression._call("is_in", [self], c_options) + options = SetLookupOptions(values) + return Expression._call("is_in", [self], options) @staticmethod def _field(str name not None): diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index fdd100b289d..11be3b6ffba 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -227,6 +227,8 @@ def wrapper(*args, memory_pool=None): f"{func_name} takes {arity} positional argument(s), " f"but {len(args)} were given" ) + if args and isinstance(args[0], Expression): + return Expression._call(func_name, list(args)) return func.call(args, None, memory_pool) else: def wrapper(*args, memory_pool=None, options=None, **kwargs): @@ -242,6 +244,8 @@ def wrapper(*args, memory_pool=None, options=None, **kwargs): option_args = () options = _handle_options(func_name, options_class, options, option_args, kwargs) + if args and isinstance(args[0], Expression): + return Expression._call(func_name, list(args), options) return func.call(args, options, memory_pool) return wrapper diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 517f2ff6ad4..e97b2f53eb6 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2628,3 +2628,23 @@ def test_expression_boolean_operators(): with pytest.raises(ValueError, match="cannot be evaluated to python True"): not true + + +def test_expression_call_function(): + field = pc.field("field") + + # no options + assert str(pc.hour(field)) == "hour(field)" + + # default options + assert str(pc.round(field)) == "round(field)" + # specified options + assert str(pc.round(field, ndigits=1)) == \ + "round(field, {ndigits=1, round_mode=HALF_TO_EVEN})" + + # mixed types are not (yet) allowed + with pytest.raises(TypeError): + pc.add(field, 1) + + with pytest.raises(TypeError): + pc.add(1, field) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index f2dccc478d0..bddf8a57d37 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -18,6 +18,7 @@ import contextlib import os import posixpath +import datetime import pathlib import pickle import textwrap @@ -29,6 +30,7 @@ import pytest import pyarrow as pa +import pyarrow.compute as pc import pyarrow.csv import pyarrow.feather import pyarrow.fs as fs @@ -2512,6 +2514,26 @@ def test_filter_equal_null(tempdir, dataset_reader): assert table.num_rows == 0 +def test_filter_compute_expression(tempdir, dataset_reader): + table = pa.table({ + "A": ["a", "b", None, "a", "c"], + "B": [datetime.datetime(2022, 1, 1, i) for i in range(5)], + "C": [datetime.datetime(2022, 1, i) for i in range(1, 6)], + }) + _, path = _create_single_file(tempdir, table) + dataset = ds.dataset(str(path)) + + filter_ = pc.is_in(ds.field('A'), pa.array(["a", "b"])) + assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 3 + + filter_ = pc.hour(ds.field('B')) >= 3 + assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 2 + + days = pc.days_between(ds.field('B'), ds.field("C")) + result = dataset_reader.to_table(dataset, columns={"days": days}) + assert result["days"].to_pylist() == [0, 1, 2, 3, 4] + + def test_dataset_union(multisourcefs): child = ds.FileSystemDatasetFactory( multisourcefs, fs.FileSelector('/plain'),