From 9dd2317f8691e1ab719cbf1e7d1feba71510e3b5 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 9 Dec 2021 16:28:02 +0100 Subject: [PATCH 1/4] ARROW-12060: [Python] Enable calling compute functions on Expressions --- python/pyarrow/_compute.pyx | 16 ++++++++++++++++ python/pyarrow/compute.py | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 28c0c87543a..6892ab94398 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2019,6 +2019,22 @@ cdef class Expression(_Weakrefable): return ( expr) return ( Expression._scalar(expr)) + @staticmethod + def _call_function(str function_name, arguments, options=None): + cdef: + vector[CExpression] c_arguments + shared_ptr[CFunctionOptions] c_options=( + nullptr) + + for argument in arguments: + c_arguments.push_back(( argument).expr) + + # if options is not None: + # c_options = make_shared[CFunctionOptions](options.get_options()) + + return Expression.wrap(CMakeCallExpression( + tobytes(function_name), move(c_arguments), c_options)) + @staticmethod cdef Expression _call(str function_name, list arguments, shared_ptr[CFunctionOptions] options=( diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index fdd100b289d..09aeee44ed7 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -89,6 +89,20 @@ from pyarrow.vendored import docscrape +ds = None + + +def has_dataset_module(): + global ds + + if ds is None: + try: + import pyarrow.dataset as ds + except ImportError: + ds = False + return ds + + def _get_arg_names(func): return func._doc.arg_names @@ -227,6 +241,9 @@ def wrapper(*args, memory_pool=None): f"{func_name} takes {arity} positional argument(s), " f"but {len(args)} were given" ) + if has_dataset_module(): + if isinstance(args[0], ds.Expression): + return ds.Expression._call_function(func_name, args) return func.call(args, None, memory_pool) else: def wrapper(*args, memory_pool=None, options=None, **kwargs): @@ -240,6 +257,11 @@ def wrapper(*args, memory_pool=None, options=None, **kwargs): args = args[:arity] else: option_args = () + if has_dataset_module(): + if isinstance(args[0], ds.Expression): + return ds.Expression._call_function( + func_name, args, options + ) options = _handle_options(func_name, options_class, options, option_args, kwargs) return func.call(args, options, memory_pool) From 3b4b0fcec2769385dcb11221dd64de729481e3e4 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 17 Jan 2022 16:40:00 +0100 Subject: [PATCH 2/4] clean-up + add tests --- python/pyarrow/_compute.pyx | 5 ++++- python/pyarrow/compute.py | 28 ++++++---------------------- python/pyarrow/tests/test_compute.py | 20 ++++++++++++++++++++ python/pyarrow/tests/test_dataset.py | 23 +++++++++++++++++++++++ 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 6892ab94398..294cb5a190d 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2020,13 +2020,16 @@ cdef class Expression(_Weakrefable): return ( Expression._scalar(expr)) @staticmethod - def _call_function(str function_name, arguments, options=None): + def _call_function(str function_name, arguments, + FunctionOptions options=None): cdef: vector[CExpression] c_arguments shared_ptr[CFunctionOptions] c_options=( nullptr) for argument in arguments: + if not isinstance(argument, Expression): + raise TypeError("only other expressions allowed as arguments") c_arguments.push_back(( argument).expr) # if options is not None: diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 09aeee44ed7..3698bacfb0c 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -89,20 +89,6 @@ from pyarrow.vendored import docscrape -ds = None - - -def has_dataset_module(): - global ds - - if ds is None: - try: - import pyarrow.dataset as ds - except ImportError: - ds = False - return ds - - def _get_arg_names(func): return func._doc.arg_names @@ -241,9 +227,8 @@ def wrapper(*args, memory_pool=None): f"{func_name} takes {arity} positional argument(s), " f"but {len(args)} were given" ) - if has_dataset_module(): - if isinstance(args[0], ds.Expression): - return ds.Expression._call_function(func_name, args) + if isinstance(args[0], Expression): + return Expression._call_function(func_name, args) return func.call(args, None, memory_pool) else: def wrapper(*args, memory_pool=None, options=None, **kwargs): @@ -257,13 +242,12 @@ def wrapper(*args, memory_pool=None, options=None, **kwargs): args = args[:arity] else: option_args = () - if has_dataset_module(): - if isinstance(args[0], ds.Expression): - return ds.Expression._call_function( - func_name, args, options - ) options = _handle_options(func_name, options_class, options, option_args, kwargs) + if isinstance(args[0], Expression): + return Expression._call_function( + func_name, args, options + ) return func.call(args, options, memory_pool) return wrapper diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 517f2ff6ad4..02b9bd29e81 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2628,3 +2628,23 @@ def test_expression_boolean_operators(): with pytest.raises(ValueError, match="cannot be evaluated to python True"): not true + + +def test_expression_call_function(): + field = pc.field("field") + + # no options + assert str(pc.hour(field)) == "hour(field)" + + # default options + assert str(pc.round(field)) == "round(field)" + # TODO + # specified options + # pc.round(field, ndigits=1) + + # mixed types are not (yet) allowed + with pytest.raises(TypeError): + pc.add(field, 1) + + with pytest.raises(TypeError): + pc.add(1, field) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index f2dccc478d0..7663c08fcf0 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -18,6 +18,7 @@ import contextlib import os import posixpath +import datetime import pathlib import pickle import textwrap @@ -29,6 +30,7 @@ import pytest import pyarrow as pa +import pyarrow.compute as pc import pyarrow.csv import pyarrow.feather import pyarrow.fs as fs @@ -2512,6 +2514,27 @@ def test_filter_equal_null(tempdir, dataset_reader): assert table.num_rows == 0 +def test_filter_compute_expression(tempdir, dataset_reader): + table = pa.table({ + "A": ["a", "b", None, "a", "c"], + "B": [datetime.datetime(2022, 1, 1, i) for i in range(5)], + "C": [datetime.datetime(2022, 1, i) for i in range(1, 6)], + }) + _, path = _create_single_file(tempdir, table) + dataset = ds.dataset(str(path)) + + # TODO needs options + # filter_ = pc.is_in(ds.field('A'), ["a", "b"]) + # assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 3 + + filter_ = pc.hour(ds.field('B')) >= 3 + assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 2 + + days = pc.days_between(ds.field('B'), ds.field("C")) + result = dataset_reader.to_table(dataset, columns={"days": days}) + assert result["days"].to_pylist() == [0, 1, 2, 3, 4] + + def test_dataset_union(multisourcefs): child = ds.FileSystemDatasetFactory( multisourcefs, fs.FileSelector('/plain'), From b28f6dfab8f939b43e8c1c75abbd975cf99e789d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 17 Jan 2022 19:57:34 +0100 Subject: [PATCH 3/4] fix options - FunctionOptions.wrapped from unique_ptr -> shared_ptr --- python/pyarrow/_compute.pxd | 10 ++-- python/pyarrow/_compute.pyx | 68 ++++++++++------------------ python/pyarrow/compute.py | 10 ++-- python/pyarrow/tests/test_compute.py | 4 +- python/pyarrow/tests/test_dataset.py | 5 +- 5 files changed, 36 insertions(+), 61 deletions(-) diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index d18c08e6050..c10867a75b7 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -24,10 +24,12 @@ from pyarrow.includes.libarrow cimport * cdef class FunctionOptions(_Weakrefable): cdef: - unique_ptr[CFunctionOptions] wrapped + shared_ptr[CFunctionOptions] wrapped cdef const CFunctionOptions* get_options(self) except NULL - cdef void init(self, unique_ptr[CFunctionOptions] options) + cdef void init(self, const shared_ptr[CFunctionOptions]& sp) + + cdef inline shared_ptr[CFunctionOptions] unwrap(self) cdef CExpression _bind(Expression filter, Schema schema) except * @@ -47,7 +49,3 @@ cdef class Expression(_Weakrefable): @staticmethod cdef Expression _expr_or_scalar(object expr) - - @staticmethod - cdef Expression _call(str function_name, list arguments, - shared_ptr[CFunctionOptions] options=*) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 294cb5a190d..cd8cafd879a 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -537,8 +537,11 @@ cdef class FunctionOptions(_Weakrefable): cdef const CFunctionOptions* get_options(self) except NULL: return self.wrapped.get() - cdef void init(self, unique_ptr[CFunctionOptions] options): - self.wrapped = move(options) + cdef void init(self, const shared_ptr[CFunctionOptions]& sp): + self.wrapped = sp + + cdef inline shared_ptr[CFunctionOptions] unwrap(self): + return self.wrapped def serialize(self): cdef: @@ -560,15 +563,15 @@ cdef class FunctionOptions(_Weakrefable): shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) CResult[unique_ptr[CFunctionOptions]] maybe_options = \ DeserializeFunctionOptions(deref(c_buf)) - unique_ptr[CFunctionOptions] c_options - c_options = move(GetResultValue(move(maybe_options))) + shared_ptr[CFunctionOptions] c_options + c_options = to_shared(GetResultValue(move(maybe_options))) type_name = frombytes(c_options.get().options_type().type_name()) module = globals() if type_name not in module: raise ValueError(f'Cannot deserialize "{type_name}"') klass = module[type_name] options = klass.__new__(klass) - ( options).init(move(c_options)) + ( options).init(c_options) return options def __repr__(self): @@ -597,15 +600,17 @@ def _raise_invalid_function_option(value, description, *, cdef class _CastOptions(FunctionOptions): cdef CCastOptions* options - cdef void init(self, unique_ptr[CFunctionOptions] options): - FunctionOptions.init(self, move(options)) + cdef void init(self, const shared_ptr[CFunctionOptions]& sp): + FunctionOptions.init(self, sp) self.options = self.wrapped.get() def _set_options(self, DataType target_type, allow_int_overflow, allow_time_truncate, allow_time_overflow, allow_decimal_truncate, allow_float_truncate, allow_invalid_utf8): - self.init(unique_ptr[CFunctionOptions](new CCastOptions())) + cdef: + shared_ptr[CCastOptions] wrapped = make_shared[CCastOptions]() + self.init( wrapped) self._set_type(target_type) if allow_int_overflow is not None: self.allow_int_overflow = allow_int_overflow @@ -626,11 +631,11 @@ cdef class _CastOptions(FunctionOptions): ( ensure_type(target_type)).sp_type def _set_safe(self): - self.init(unique_ptr[CFunctionOptions]( + self.init(shared_ptr[CFunctionOptions]( new CCastOptions(CCastOptions.Safe()))) def _set_unsafe(self): - self.init(unique_ptr[CFunctionOptions]( + self.init(shared_ptr[CFunctionOptions]( new CCastOptions(CCastOptions.Unsafe()))) def is_safe(self): @@ -2020,8 +2025,7 @@ cdef class Expression(_Weakrefable): return ( Expression._scalar(expr)) @staticmethod - def _call_function(str function_name, arguments, - FunctionOptions options=None): + def _call(str function_name, list arguments, FunctionOptions options=None): cdef: vector[CExpression] c_arguments shared_ptr[CFunctionOptions] c_options=( @@ -2032,25 +2036,12 @@ cdef class Expression(_Weakrefable): raise TypeError("only other expressions allowed as arguments") c_arguments.push_back(( argument).expr) - # if options is not None: - # c_options = make_shared[CFunctionOptions](options.get_options()) + if options is not None: + c_options = options.unwrap() return Expression.wrap(CMakeCallExpression( tobytes(function_name), move(c_arguments), c_options)) - @staticmethod - cdef Expression _call(str function_name, list arguments, - shared_ptr[CFunctionOptions] options=( - nullptr)): - cdef: - vector[CExpression] c_arguments - - for argument in arguments: - c_arguments.push_back(( argument).expr) - - return Expression.wrap(CMakeCallExpression(tobytes(function_name), - move(c_arguments), options)) - def __richcmp__(self, other, int op): other = Expression._expr_or_scalar(other) return Expression._call({ @@ -2102,32 +2093,21 @@ cdef class Expression(_Weakrefable): def is_null(self, bint nan_is_null=False): """Checks whether the expression is null""" - cdef: - shared_ptr[CFunctionOptions] c_options - - c_options.reset(new CNullOptions(nan_is_null)) - return Expression._call("is_null", [self], c_options) + options = NullOptions(nan_is_null=nan_is_null) + return Expression._call("is_null", [self], options) def cast(self, type, bint safe=True): """Explicitly change the expression's data type""" - cdef shared_ptr[CCastOptions] c_options - c_options.reset(new CCastOptions(safe)) - c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type)) - return Expression._call("cast", [self], - c_options) + options = CastOptions.safe(ensure_type(type)) + return Expression._call("cast", [self], options) def isin(self, values): """Checks whether the expression is contained in values""" - cdef: - shared_ptr[CFunctionOptions] c_options - CDatum c_values - if not isinstance(values, Array): values = lib.array(values) - c_values = CDatum(pyarrow_unwrap_array(values)) - c_options.reset(new CSetLookupOptions(c_values, True)) - return Expression._call("is_in", [self], c_options) + options = SetLookupOptions(values) + return Expression._call("is_in", [self], options) @staticmethod def _field(str name not None): diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 3698bacfb0c..11be3b6ffba 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -227,8 +227,8 @@ def wrapper(*args, memory_pool=None): f"{func_name} takes {arity} positional argument(s), " f"but {len(args)} were given" ) - if isinstance(args[0], Expression): - return Expression._call_function(func_name, args) + if args and isinstance(args[0], Expression): + return Expression._call(func_name, list(args)) return func.call(args, None, memory_pool) else: def wrapper(*args, memory_pool=None, options=None, **kwargs): @@ -244,10 +244,8 @@ def wrapper(*args, memory_pool=None, options=None, **kwargs): option_args = () options = _handle_options(func_name, options_class, options, option_args, kwargs) - if isinstance(args[0], Expression): - return Expression._call_function( - func_name, args, options - ) + if args and isinstance(args[0], Expression): + return Expression._call(func_name, list(args), options) return func.call(args, options, memory_pool) return wrapper diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 02b9bd29e81..e97b2f53eb6 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2638,9 +2638,9 @@ def test_expression_call_function(): # default options assert str(pc.round(field)) == "round(field)" - # TODO # specified options - # pc.round(field, ndigits=1) + assert str(pc.round(field, ndigits=1)) == \ + "round(field, {ndigits=1, round_mode=HALF_TO_EVEN})" # mixed types are not (yet) allowed with pytest.raises(TypeError): diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 7663c08fcf0..bddf8a57d37 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -2523,9 +2523,8 @@ def test_filter_compute_expression(tempdir, dataset_reader): _, path = _create_single_file(tempdir, table) dataset = ds.dataset(str(path)) - # TODO needs options - # filter_ = pc.is_in(ds.field('A'), ["a", "b"]) - # assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 3 + filter_ = pc.is_in(ds.field('A'), pa.array(["a", "b"])) + assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 3 filter_ = pc.hour(ds.field('B')) >= 3 assert dataset_reader.to_table(dataset, filter=filter_).num_rows == 2 From b1a4d36ea3c4740c7bc9ec9a264613b8613884aa Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 18 Jan 2022 17:33:58 +0100 Subject: [PATCH 4/4] simplify c_options initialization --- python/pyarrow/_compute.pyx | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index cd8cafd879a..1ec96d33a02 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2028,8 +2028,7 @@ cdef class Expression(_Weakrefable): def _call(str function_name, list arguments, FunctionOptions options=None): cdef: vector[CExpression] c_arguments - shared_ptr[CFunctionOptions] c_options=( - nullptr) + shared_ptr[CFunctionOptions] c_options for argument in arguments: if not isinstance(argument, Expression):