From 94eff67c913a515e9056d580f82b7cff288cbf36 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 5 Oct 2022 11:11:28 +0530 Subject: [PATCH 01/13] feat(multi-kernel): added multi kernel registration and updated test cases --- python/pyarrow/_compute.pyx | 53 +++++--- python/pyarrow/includes/libarrow.pxd | 4 +- python/pyarrow/src/arrow/python/udf.cc | 42 ++++--- python/pyarrow/src/arrow/python/udf.h | 4 +- python/pyarrow/tests/test_udf.py | 160 ++++++++++++++++++------- 5 files changed, 184 insertions(+), 79 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c75c5bf189b..d5249d1c2e1 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2591,8 +2591,8 @@ def _get_scalar_udf_context(memory_pool, batch_length): return context -def register_scalar_function(func, function_name, function_doc, in_types, - out_type): +def register_scalar_function(func, function_name, function_doc, in_arg_types, + out_types): """ Register a user-defined scalar function. @@ -2624,15 +2624,15 @@ def register_scalar_function(func, function_name, function_doc, in_types, function_doc : dict A dictionary object with keys "summary" (str), and "description" (str). - in_types : Dict[str, DataType] - A dictionary mapping function argument names to + in_arg_types : List[Dict[str, DataType]] + A list of dictionary mapping function argument names to their respective DataType. The argument names will be used to generate documentation for the function. The number of arguments specified here determines the function arity. - out_type : DataType - Output type of the function. + out_types : List[DataType] + Output types of the function. Examples -------- @@ -2666,9 +2666,10 @@ def register_scalar_function(func, function_name, function_doc, in_types, c_string c_func_name CArity c_arity CFunctionDoc c_func_doc + vector[vector[shared_ptr[CDataType]]] vec_c_in_types vector[shared_ptr[CDataType]] c_in_types PyObject* c_function - shared_ptr[CDataType] c_out_type + vector[shared_ptr[CDataType]] c_out_types CScalarUdfOptions c_options if callable(func): @@ -2680,15 +2681,29 @@ def register_scalar_function(func, function_name, function_doc, in_types, func_spec = inspect.getfullargspec(func) num_args = -1 - if isinstance(in_types, dict): - for in_type in in_types.values(): - c_in_types.push_back( - pyarrow_unwrap_data_type(ensure_type(in_type))) - function_doc["arg_names"] = in_types.keys() - num_args = len(in_types) + if not isinstance(in_arg_types, list): + raise TypeError( + "in_arg_types must be a list of dictionaries of DataTypes") + if not isinstance(out_types, list): + raise TypeError("out_types must be a list of DataTypes") + # each input_type dict in input_types list must + # have same arg_names + if isinstance(in_arg_types[0], dict): + function_doc["arg_names"] = in_arg_types[0].keys() + num_args = len(in_arg_types[0]) else: raise TypeError( - "in_types must be a dictionary of DataType") + "Elements in in_arg_types must be a dictionary of DataTypes") + + for in_types in in_arg_types: + if isinstance(in_types, dict): + for in_type in in_types.values(): + c_in_types.push_back( + pyarrow_unwrap_data_type(ensure_type(in_type))) + else: + raise TypeError( + "in_types must be a dictionary of DataType") + vec_c_in_types.push_back(move(c_in_types)) c_arity = CArity( num_args, func_spec.varargs) @@ -2702,14 +2717,18 @@ def register_scalar_function(func, function_name, function_doc, in_types, raise ValueError("Function doc must contain arg_names") c_func_doc = _make_function_doc(function_doc) + for out_type in out_types: + c_out_types.push_back(pyarrow_unwrap_data_type(ensure_type(out_type))) - c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type)) + print("out types : ", len(out_types)) + print("in types : ", vec_c_in_types.size()) + print("in types : ", vec_c_in_types.at(0).size()) c_options.func_name = c_func_name c_options.arity = c_arity c_options.func_doc = c_func_doc - c_options.input_types = c_in_types - c_options.output_type = c_out_type + c_options.input_arg_types = vec_c_in_types + c_options.output_types = c_out_types check_status(RegisterScalarFunction(c_function, &_scalar_udf_callback, c_options)) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index df6a883afe9..326bef9063d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2814,8 +2814,8 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": c_string func_name CArity arity CFunctionDoc func_doc - vector[shared_ptr[CDataType]] input_types - shared_ptr[CDataType] output_type + vector[vector[shared_ptr[CDataType]]] input_arg_types + vector[shared_ptr[CDataType]] output_types CStatus RegisterScalarFunction(PyObject* function, function[CallbackUdf] wrapper, const CScalarUdfOptions& options) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 81bf47c0ade..a024d8d9e25 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -18,6 +18,7 @@ #include "arrow/python/udf.h" #include "arrow/compute/function.h" #include "arrow/python/common.h" +#include "arrow/util/logging.h" namespace arrow { @@ -99,22 +100,33 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback auto scalar_func = std::make_shared( options.func_name, options.arity, options.func_doc); Py_INCREF(user_function); - std::vector input_types; - for (const auto& in_dtype : options.input_types) { - input_types.emplace_back(in_dtype); + + const size_t num_kernels = options.input_arg_types.size(); + // number of input_type variations and output_types must be + // equal in size + if(num_kernels != options.output_types.size()) { + return Status::Invalid("input_arg_types and output_types should be equal in size"); + } + // adding kernels + for(size_t idx=0 ; idx < num_kernels; idx++) { + const auto& opt_input_types = options.input_arg_types[idx]; + std::vector input_types; + for (const auto& in_dtype : opt_input_types) { + input_types.emplace_back(in_dtype); + } + const auto opts_out_type = options.output_types[idx]; + compute::OutputType output_type(opts_out_type); + auto udf_data = std::make_shared( + wrapper, std::make_shared(user_function), opts_out_type); + compute::ScalarKernel kernel( + compute::KernelSignature::Make(std::move(input_types), std::move(output_type), + options.arity.is_varargs), PythonUdfExec); + kernel.data = std::move(udf_data); + + kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; + kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; + RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel))); } - compute::OutputType output_type(options.output_type); - auto udf_data = std::make_shared( - wrapper, std::make_shared(user_function), options.output_type); - compute::ScalarKernel kernel( - compute::KernelSignature::Make(std::move(input_types), std::move(output_type), - options.arity.is_varargs), - PythonUdfExec); - kernel.data = std::move(udf_data); - - kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; - kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; - RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel))); auto registry = compute::GetFunctionRegistry(); RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func))); return Status::OK(); diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 9a3666459fd..2952d890a5d 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -37,8 +37,8 @@ struct ARROW_PYTHON_EXPORT ScalarUdfOptions { std::string func_name; compute::Arity arity; compute::FunctionDoc func_doc; - std::vector> input_types; - std::shared_ptr output_type; + std::vector>> input_arg_types; + std::vector> output_types; }; /// \brief A context passed as the first argument of scalar UDF functions. diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e711619582d..9b3f47eb735 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -54,8 +54,8 @@ def unary_function(ctx, x): pc.register_scalar_function(unary_function, func_name, unary_doc, - {"array": pa.int64()}, - pa.int64()) + [{"array": pa.int64()}], + [pa.int64()]) return unary_function, func_name @@ -73,10 +73,12 @@ def binary_function(ctx, m, x): pc.register_scalar_function(binary_function, func_name, binary_doc, - {"m": pa.int64(), - "x": pa.int64(), - }, - pa.int64()) + [ + {"m": pa.int64(), + "x": pa.int64(), + } + ], + [pa.int64()]) return binary_function, func_name @@ -96,12 +98,12 @@ def ternary_function(ctx, m, x, c): pc.register_scalar_function(ternary_function, func_name, ternary_doc, - { + [{ "array1": pa.int64(), "array2": pa.int64(), "array3": pa.int64(), - }, - pa.int64()) + }], + [pa.int64()]) return ternary_function, func_name @@ -123,11 +125,11 @@ def varargs_function(ctx, first, *values): pc.register_scalar_function(varargs_function, func_name, varargs_doc, - { + [{ "array1": pa.int64(), "array2": pa.int64(), - }, - pa.int64()) + }], + [pa.int64()]) return varargs_function, func_name @@ -148,8 +150,8 @@ def nullary_func(context): pc.register_scalar_function(nullary_func, func_name, func_doc, - {}, - pa.int64()) + [{}], + [pa.int64()]) return nullary_func, func_name @@ -164,14 +166,14 @@ def wrong_output_type(ctx): return 42 func_name = "test_wrong_output_type" - in_types = {} - out_type = pa.int64() + in_types = [{}] + out_types = [pa.int64()] doc = { "summary": "return wrong output type", "description": "" } pc.register_scalar_function(wrong_output_type, func_name, doc, - in_types, out_type) + in_types, out_types) return wrong_output_type, func_name @@ -184,15 +186,15 @@ def wrong_output_datatype_func_fixture(): def wrong_output_datatype(ctx, array): return pc.call_function("add", [array, 1]) func_name = "test_wrong_output_datatype" - in_types = {"array": pa.int64()} + in_types = [{"array": pa.int64()}] # The actual output DataType will be int64. - out_type = pa.int16() + out_types = [pa.int16()] doc = { "summary": "return wrong output datatype", "description": "" } pc.register_scalar_function(wrong_output_datatype, func_name, doc, - in_types, out_type) + in_types, out_types) return wrong_output_datatype, func_name @@ -206,14 +208,14 @@ def wrong_signature(): return pa.scalar(1, type=pa.int64()) func_name = "test_wrong_signature" - in_types = {} - out_type = pa.int64() + in_types = [{}] + out_types = [pa.int64()] doc = { "summary": "UDF with wrong signature", "description": "" } pc.register_scalar_function(wrong_signature, func_name, doc, - in_types, out_type) + in_types, out_types) return wrong_signature, func_name @@ -230,7 +232,7 @@ def raising_func(ctx): "description": "" } pc.register_scalar_function(raising_func, func_name, doc, - {}, pa.int64()) + [{}], [pa.int64()]) return raising_func, func_name @@ -311,8 +313,8 @@ def test_registration_errors(): "summary": "test udf input", "description": "parameters are validated" } - in_types = {"scalar": pa.int64()} - out_type = pa.int64() + in_types = [{"scalar": pa.int64()}] + out_types = [pa.int64()] def test_reg_function(context): return pa.array([10]) @@ -320,39 +322,39 @@ def test_reg_function(context): with pytest.raises(TypeError): pc.register_scalar_function(test_reg_function, None, doc, in_types, - out_type) + out_types) # validate function with pytest.raises(TypeError, match="func must be a callable"): pc.register_scalar_function(None, "test_none_function", doc, in_types, - out_type) + out_types) # validate output type - expected_expr = "DataType expected, got " + expected_expr = "out_types must be a list of DataTypes" with pytest.raises(TypeError, match=expected_expr): pc.register_scalar_function(test_reg_function, "test_output_function", doc, in_types, None) # validate input type - expected_expr = "in_types must be a dictionary of DataType" + expected_expr = "in_arg_types must be a list of dictionaries of DataTypes" with pytest.raises(TypeError, match=expected_expr): pc.register_scalar_function(test_reg_function, "test_input_function", doc, None, - out_type) + out_types) # register an already registered function # first registration pc.register_scalar_function(test_reg_function, - "test_reg_function", doc, {}, - out_type) + "test_reg_function", doc, [{}], + out_types) # second registration expected_expr = "Already have a function registered with name:" \ + " test_reg_function" with pytest.raises(KeyError, match=expected_expr): pc.register_scalar_function(test_reg_function, - "test_reg_function", doc, {}, - out_type) + "test_reg_function", doc, [{}], + out_types) def test_varargs_function_validation(varargs_func_fixture): @@ -366,8 +368,8 @@ def test_varargs_function_validation(varargs_func_fixture): def test_function_doc_validation(): # validate arity - in_types = {"scalar": pa.int64()} - out_type = pa.int64() + in_types = [{"scalar": pa.int64()}] + out_type = [pa.int64()] # doc with no summary func_doc = { @@ -396,7 +398,7 @@ def add_const(ctx, scalar): def test_nullary_function(nullary_func_fixture): - # XXX the Python compute layer API doesn't let us override batch_length, + # the Python compute layer API doesn't let us override batch_length, # so only test with the default value of 1. check_scalar_function(nullary_func_fixture, [], run_in_dataset=False, batch_length=1) @@ -435,8 +437,8 @@ def identity(ctx, val): return val func_name = "test_wrong_datatype_declaration" - in_types = {"array": pa.int64()} - out_type = {} + in_types = [{"array": pa.int64()}] + out_type = [{}] doc = { "summary": "test output value", "description": "test output" @@ -452,8 +454,8 @@ def identity(ctx, val): return val func_name = "test_wrong_input_type_declaration" - in_types = {"array": None} - out_type = pa.int64() + in_types = [{"array": None}] + out_types = [pa.int64()] doc = { "summary": "test invalid input type", "description": "invalid input function" @@ -461,7 +463,7 @@ def identity(ctx, val): with pytest.raises(TypeError, match="DataType expected, got "): pc.register_scalar_function(identity, func_name, doc, - in_types, out_type) + in_types, out_types) def test_udf_context(unary_func_fixture): @@ -504,3 +506,75 @@ def test_input_lifetime(unary_func_fixture): # Calling a UDF should not have kept `v` alive longer than required v = None assert proxy_pool.bytes_allocated() == 0 + + +def test_multi_kernel_registration(): + """ + Register a unary scalar function. + """ + def unary_function(ctx, x): + return pc.cast(pc.call_function("multiply", [x, 2], + memory_pool=ctx.memory_pool), x.type) + func_name = "y=x*1" + unary_doc = {"summary": "add function", + "description": "test add function"} + input_types = [ + {"array": pa.int8()}, + {"array": pa.int16()}, + {"array": pa.int32()}, + {"array": pa.int64()}, + {"array": pa.float32()}, + {"array": pa.float64()} + ] + + output_types = [ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64(), + pa.float32(), + pa.float64() + ] + pc.register_scalar_function(unary_function, + func_name, + unary_doc, + input_types, + output_types) + + for out_type in output_types: + assert pc.call_function(func_name, + [pa.array([10, 20], out_type)]) \ + == pa.array([20, 40], out_type) + + +def test_invalid_multi_kernel_registration(): + """ + Register a unary scalar function. + """ + def unary_function(ctx, x): + return pc.cast(pc.call_function("multiply", [x, 2], + memory_pool=ctx.memory_pool), x.type) + func_name = "y=x*1" + unary_doc = {"summary": "add function", + "description": "test add function"} + input_types = [ + {"array": pa.int8()}, + {"array": pa.int16()}, + {"array": pa.float32()}, + {"array": pa.float64()} + ] + + output_types = [ + pa.int8(), + pa.int16(), + pa.int64(), + pa.float32(), + pa.float64() + ] + error_msg = "input_arg_types and output_types should be equal in size" + with pytest.raises(pa.lib.ArrowInvalid, match=error_msg): + pc.register_scalar_function(unary_function, + func_name, + unary_doc, + input_types, + output_types) From bf126b3c8c50d5b7b13d1b20649720cea5835137 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 5 Oct 2022 11:13:33 +0530 Subject: [PATCH 02/13] fix(docs): updated args --- python/pyarrow/_compute.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index d5249d1c2e1..0372a745f6d 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2647,8 +2647,8 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, ... return pc.add(array, 1, memory_pool=ctx.memory_pool) >>> >>> func_name = "py_add_func" - >>> in_types = {"array": pa.int64()} - >>> out_type = pa.int64() + >>> in_types = [{"array": pa.int64()}] + >>> out_type = [pa.int64()] >>> pc.register_scalar_function(add_constant, func_name, func_doc, ... in_types, out_type) >>> From 65f5772d9e9840089cb53b83d79aed8d6a50818d Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 5 Oct 2022 11:29:14 +0530 Subject: [PATCH 03/13] fix(docs): fixed error messages --- python/pyarrow/_compute.pyx | 10 +++------- python/pyarrow/tests/test_udf.py | 31 ++++++++++--------------------- 2 files changed, 13 insertions(+), 28 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 0372a745f6d..5ced4fee8b1 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2683,7 +2683,7 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, num_args = -1 if not isinstance(in_arg_types, list): raise TypeError( - "in_arg_types must be a list of dictionaries of DataTypes") + "in_arg_types must be a list of dictionaries of DataType") if not isinstance(out_types, list): raise TypeError("out_types must be a list of DataTypes") # each input_type dict in input_types list must @@ -2693,7 +2693,7 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, num_args = len(in_arg_types[0]) else: raise TypeError( - "Elements in in_arg_types must be a dictionary of DataTypes") + "Elements in in_arg_types must be a dictionary of DataType") for in_types in in_arg_types: if isinstance(in_types, dict): @@ -2702,7 +2702,7 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, pyarrow_unwrap_data_type(ensure_type(in_type))) else: raise TypeError( - "in_types must be a dictionary of DataType") + "Elements in in_arg_types must be a dictionary of DataType") vec_c_in_types.push_back(move(c_in_types)) c_arity = CArity( num_args, func_spec.varargs) @@ -2720,10 +2720,6 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, for out_type in out_types: c_out_types.push_back(pyarrow_unwrap_data_type(ensure_type(out_type))) - print("out types : ", len(out_types)) - print("in types : ", vec_c_in_types.size()) - print("in types : ", vec_c_in_types.at(0).size()) - c_options.func_name = c_func_name c_options.arity = c_arity c_options.func_doc = c_func_doc diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 9b3f47eb735..4040d08934c 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -337,7 +337,7 @@ def test_reg_function(context): None) # validate input type - expected_expr = "in_arg_types must be a list of dictionaries of DataTypes" + expected_expr = "in_arg_types must be a list of dictionaries of DataType" with pytest.raises(TypeError, match=expected_expr): pc.register_scalar_function(test_reg_function, "test_input_function", doc, None, @@ -509,15 +509,12 @@ def test_input_lifetime(unary_func_fixture): def test_multi_kernel_registration(): - """ - Register a unary scalar function. - """ def unary_function(ctx, x): return pc.cast(pc.call_function("multiply", [x, 2], memory_pool=ctx.memory_pool), x.type) - func_name = "y=x*1" - unary_doc = {"summary": "add function", - "description": "test add function"} + func_name = "y=x*2" + unary_doc = {"summary": "multiply by two function", + "description": "test multiply function"} input_types = [ {"array": pa.int8()}, {"array": pa.int16()}, @@ -548,28 +545,20 @@ def unary_function(ctx, x): def test_invalid_multi_kernel_registration(): - """ - Register a unary scalar function. - """ def unary_function(ctx, x): - return pc.cast(pc.call_function("multiply", [x, 2], - memory_pool=ctx.memory_pool), x.type) - func_name = "y=x*1" - unary_doc = {"summary": "add function", - "description": "test add function"} + return x + func_name = "y=x" + unary_doc = {"summary": "pass value function", + "description": "test function"} input_types = [ {"array": pa.int8()}, - {"array": pa.int16()}, - {"array": pa.float32()}, - {"array": pa.float64()} + {"array": pa.int16()} ] output_types = [ pa.int8(), pa.int16(), - pa.int64(), - pa.float32(), - pa.float64() + pa.int64() ] error_msg = "input_arg_types and output_types should be equal in size" with pytest.raises(pa.lib.ArrowInvalid, match=error_msg): From caabab4911f1cdc1e27f99f1d7196a857e0fce5c Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 5 Oct 2022 11:29:55 +0530 Subject: [PATCH 04/13] fix(minor): naming exception --- python/pyarrow/_compute.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 5ced4fee8b1..5d7ccbc6873 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2685,7 +2685,7 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, raise TypeError( "in_arg_types must be a list of dictionaries of DataType") if not isinstance(out_types, list): - raise TypeError("out_types must be a list of DataTypes") + raise TypeError("out_types must be a list of DataType") # each input_type dict in input_types list must # have same arg_names if isinstance(in_arg_types[0], dict): From 52b70953d3680149443cc155770cae7db94fac6e Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 5 Oct 2022 11:32:39 +0530 Subject: [PATCH 05/13] fix(docs): update docstring --- python/pyarrow/_compute.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 5d7ccbc6873..45769c0fa88 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2632,7 +2632,9 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, arguments specified here determines the function arity. out_types : List[DataType] - Output types of the function. + A list of output types of the function. + Corresponding to the input types, the output type of + the function can be varied. Examples -------- From 9a5d52153995160a5dede295b088e7dea893c9d5 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 5 Oct 2022 11:35:27 +0530 Subject: [PATCH 06/13] fix(typo): fixed minor typos --- python/pyarrow/tests/test_udf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 4040d08934c..6bb023eed7b 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -330,7 +330,7 @@ def test_reg_function(context): out_types) # validate output type - expected_expr = "out_types must be a list of DataTypes" + expected_expr = "out_types must be a list of DataType" with pytest.raises(TypeError, match=expected_expr): pc.register_scalar_function(test_reg_function, "test_output_function", doc, in_types, From 43ec148a407760106007be23fcfe39a225bb7d70 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Tue, 11 Oct 2022 11:29:29 +0530 Subject: [PATCH 07/13] fix(rebase) --- python/pyarrow/src/arrow/python/udf.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index a024d8d9e25..5ac20a9ac11 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -18,7 +18,6 @@ #include "arrow/python/udf.h" #include "arrow/compute/function.h" #include "arrow/python/common.h" -#include "arrow/util/logging.h" namespace arrow { From 90792bf90df5cc1dc4271ae449b98cc56879c3ae Mon Sep 17 00:00:00 2001 From: vibhatha Date: Tue, 3 Jan 2023 23:00:46 +0530 Subject: [PATCH 08/13] fix(reviews): addressing reviews for input_type passing --- python/pyarrow/_compute.pyx | 46 ++++++----- python/pyarrow/src/arrow/python/udf.cc | 2 +- python/pyarrow/tests/test_udf.py | 108 +++++++++++++++---------- 3 files changed, 90 insertions(+), 66 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 45769c0fa88..b41a4bb105a 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2592,7 +2592,7 @@ def _get_scalar_udf_context(memory_pool, batch_length): def register_scalar_function(func, function_name, function_doc, in_arg_types, - out_types): + in_arg_names, out_types): """ Register a user-defined scalar function. @@ -2624,13 +2624,13 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, function_doc : dict A dictionary object with keys "summary" (str), and "description" (str). - in_arg_types : List[Dict[str, DataType]] - A list of dictionary mapping function argument names to - their respective DataType. - The argument names will be used to generate - documentation for the function. The number of - arguments specified here determines the function - arity. + in_arg_types : List[List[DataType]] + A list of list of DataTypes which includes input types for + each kernel. The number of arguments specified here + determines the function arity. + in_arg_names: List[str] + A list of str which contains the names of the arguments used to + generate the function documentation. out_types : List[DataType] A list of output types of the function. Corresponding to the input types, the output type of @@ -2649,10 +2649,11 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, ... return pc.add(array, 1, memory_pool=ctx.memory_pool) >>> >>> func_name = "py_add_func" - >>> in_types = [{"array": pa.int64()}] + >>> in_types = [[pa.int64()]] + >>> in_names = ["array"] >>> out_type = [pa.int64()] >>> pc.register_scalar_function(add_constant, func_name, func_doc, - ... in_types, out_type) + ... in_types, in_names, out_type) >>> >>> func = pc.get_function(func_name) >>> func.name @@ -2685,26 +2686,27 @@ def register_scalar_function(func, function_name, function_doc, in_arg_types, num_args = -1 if not isinstance(in_arg_types, list): raise TypeError( - "in_arg_types must be a list of dictionaries of DataType") + "in_arg_types must be a list of list of DataType") + if not isinstance(in_arg_names, list): + raise TypeError( + "in_arg_names must be a list of str") if not isinstance(out_types, list): raise TypeError("out_types must be a list of DataType") - # each input_type dict in input_types list must - # have same arg_names - if isinstance(in_arg_types[0], dict): - function_doc["arg_names"] = in_arg_types[0].keys() - num_args = len(in_arg_types[0]) - else: - raise TypeError( - "Elements in in_arg_types must be a dictionary of DataType") + + function_doc["arg_names"] = in_arg_names + num_args = len(in_arg_names) for in_types in in_arg_types: - if isinstance(in_types, dict): - for in_type in in_types.values(): + if isinstance(in_types, list): + if len(in_arg_names) != len(in_types): + raise ValueError( + "in_arg_names and input types per kernel must contain same number of elements") + for in_type in in_types: c_in_types.push_back( pyarrow_unwrap_data_type(ensure_type(in_type))) else: raise TypeError( - "Elements in in_arg_types must be a dictionary of DataType") + "Elements in in_arg_types must be a list of DataType") vec_c_in_types.push_back(move(c_in_types)) c_arity = CArity( num_args, func_spec.varargs) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 5ac20a9ac11..c105107826a 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -113,7 +113,7 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback for (const auto& in_dtype : opt_input_types) { input_types.emplace_back(in_dtype); } - const auto opts_out_type = options.output_types[idx]; + const auto& opts_out_type = options.output_types[idx]; compute::OutputType output_type(opts_out_type); auto udf_data = std::make_shared( wrapper, std::make_shared(user_function), opts_out_type); diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 6bb023eed7b..b898b7ab5f0 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -54,7 +54,8 @@ def unary_function(ctx, x): pc.register_scalar_function(unary_function, func_name, unary_doc, - [{"array": pa.int64()}], + [[pa.int64()]], + ["array"], [pa.int64()]) return unary_function, func_name @@ -74,10 +75,11 @@ def binary_function(ctx, m, x): func_name, binary_doc, [ - {"m": pa.int64(), - "x": pa.int64(), - } + [pa.int64(), + pa.int64(), + ] ], + ["m", "x"], [pa.int64()]) return binary_function, func_name @@ -98,11 +100,14 @@ def ternary_function(ctx, m, x, c): pc.register_scalar_function(ternary_function, func_name, ternary_doc, - [{ - "array1": pa.int64(), - "array2": pa.int64(), - "array3": pa.int64(), - }], + [ + [ + pa.int64(), + pa.int64(), + pa.int64(), + ] + ], + ["array1", "array2", "array3"], [pa.int64()]) return ternary_function, func_name @@ -125,10 +130,13 @@ def varargs_function(ctx, first, *values): pc.register_scalar_function(varargs_function, func_name, varargs_doc, - [{ - "array1": pa.int64(), - "array2": pa.int64(), - }], + [ + [ + pa.int64(), + pa.int64(), + ] + ], + ["array1", "array2"], [pa.int64()]) return varargs_function, func_name @@ -150,7 +158,8 @@ def nullary_func(context): pc.register_scalar_function(nullary_func, func_name, func_doc, - [{}], + [[]], + [], [pa.int64()]) return nullary_func, func_name @@ -166,14 +175,15 @@ def wrong_output_type(ctx): return 42 func_name = "test_wrong_output_type" - in_types = [{}] + in_types = [[]] + in_names = [] out_types = [pa.int64()] doc = { "summary": "return wrong output type", "description": "" } pc.register_scalar_function(wrong_output_type, func_name, doc, - in_types, out_types) + in_types, in_names, out_types) return wrong_output_type, func_name @@ -186,7 +196,8 @@ def wrong_output_datatype_func_fixture(): def wrong_output_datatype(ctx, array): return pc.call_function("add", [array, 1]) func_name = "test_wrong_output_datatype" - in_types = [{"array": pa.int64()}] + in_types = [[pa.int64()]] + in_names = ["array"] # The actual output DataType will be int64. out_types = [pa.int16()] doc = { @@ -194,7 +205,7 @@ def wrong_output_datatype(ctx, array): "description": "" } pc.register_scalar_function(wrong_output_datatype, func_name, doc, - in_types, out_types) + in_types, in_names, out_types) return wrong_output_datatype, func_name @@ -208,14 +219,15 @@ def wrong_signature(): return pa.scalar(1, type=pa.int64()) func_name = "test_wrong_signature" - in_types = [{}] + in_types = [[]] + in_names = [] out_types = [pa.int64()] doc = { "summary": "UDF with wrong signature", "description": "" } pc.register_scalar_function(wrong_signature, func_name, doc, - in_types, out_types) + in_types, in_names, out_types) return wrong_signature, func_name @@ -232,7 +244,7 @@ def raising_func(ctx): "description": "" } pc.register_scalar_function(raising_func, func_name, doc, - [{}], [pa.int64()]) + [[]], [], [pa.int64()]) return raising_func, func_name @@ -313,7 +325,8 @@ def test_registration_errors(): "summary": "test udf input", "description": "parameters are validated" } - in_types = [{"scalar": pa.int64()}] + in_types = [[pa.int64()]] + in_names = ["scalar"] out_types = [pa.int64()] def test_reg_function(context): @@ -327,33 +340,33 @@ def test_reg_function(context): # validate function with pytest.raises(TypeError, match="func must be a callable"): pc.register_scalar_function(None, "test_none_function", doc, in_types, - out_types) + in_names, out_types) # validate output type expected_expr = "out_types must be a list of DataType" with pytest.raises(TypeError, match=expected_expr): pc.register_scalar_function(test_reg_function, "test_output_function", doc, in_types, - None) + in_names, None) # validate input type - expected_expr = "in_arg_types must be a list of dictionaries of DataType" + expected_expr = "in_arg_types must be a list of DataType" with pytest.raises(TypeError, match=expected_expr): pc.register_scalar_function(test_reg_function, "test_input_function", doc, None, - out_types) + in_names, out_types) # register an already registered function # first registration pc.register_scalar_function(test_reg_function, - "test_reg_function", doc, [{}], + "test_reg_function", doc, [[]], [], out_types) # second registration expected_expr = "Already have a function registered with name:" \ + " test_reg_function" with pytest.raises(KeyError, match=expected_expr): pc.register_scalar_function(test_reg_function, - "test_reg_function", doc, [{}], + "test_reg_function", doc, [[]], [], out_types) @@ -368,7 +381,8 @@ def test_varargs_function_validation(varargs_func_fixture): def test_function_doc_validation(): # validate arity - in_types = [{"scalar": pa.int64()}] + in_types = [[pa.int64()]] + in_names = ["scalar"] out_type = [pa.int64()] # doc with no summary @@ -382,7 +396,7 @@ def add_const(ctx, scalar): with pytest.raises(ValueError, match="Function doc must contain a summary"): pc.register_scalar_function(add_const, "test_no_summary", - func_doc, in_types, + func_doc, in_types, in_names, out_type) # doc with no decription @@ -393,7 +407,7 @@ def add_const(ctx, scalar): with pytest.raises(ValueError, match="Function doc must contain a description"): pc.register_scalar_function(add_const, "test_no_desc", - func_doc, in_types, + func_doc, in_types, in_names, out_type) @@ -437,7 +451,8 @@ def identity(ctx, val): return val func_name = "test_wrong_datatype_declaration" - in_types = [{"array": pa.int64()}] + in_types = [[pa.int64()]] + in_names = ["array"] out_type = [{}] doc = { "summary": "test output value", @@ -446,7 +461,7 @@ def identity(ctx, val): with pytest.raises(TypeError, match="DataType expected, got "): pc.register_scalar_function(identity, func_name, - doc, in_types, out_type) + doc, in_types, in_names, out_type) def test_wrong_input_type_declaration(): @@ -454,7 +469,8 @@ def identity(ctx, val): return val func_name = "test_wrong_input_type_declaration" - in_types = [{"array": None}] + in_types = [[None]] + in_names = ["array"] out_types = [pa.int64()] doc = { "summary": "test invalid input type", @@ -463,7 +479,7 @@ def identity(ctx, val): with pytest.raises(TypeError, match="DataType expected, got "): pc.register_scalar_function(identity, func_name, doc, - in_types, out_types) + in_types, in_names, out_types) def test_udf_context(unary_func_fixture): @@ -516,14 +532,16 @@ def unary_function(ctx, x): unary_doc = {"summary": "multiply by two function", "description": "test multiply function"} input_types = [ - {"array": pa.int8()}, - {"array": pa.int16()}, - {"array": pa.int32()}, - {"array": pa.int64()}, - {"array": pa.float32()}, - {"array": pa.float64()} + [pa.int8()], + [pa.int16()], + [pa.int32()], + [pa.int64()], + [pa.float32()], + [pa.float64()] ] + input_names = ["array"] + output_types = [ pa.int8(), pa.int16(), @@ -536,6 +554,7 @@ def unary_function(ctx, x): func_name, unary_doc, input_types, + input_names, output_types) for out_type in output_types: @@ -551,10 +570,12 @@ def unary_function(ctx, x): unary_doc = {"summary": "pass value function", "description": "test function"} input_types = [ - {"array": pa.int8()}, - {"array": pa.int16()} + [pa.int8()], + [pa.int16()] ] + input_names = ["array"] + output_types = [ pa.int8(), pa.int16(), @@ -566,4 +587,5 @@ def unary_function(ctx, x): func_name, unary_doc, input_types, + input_names, output_types) From 4ea2ed2ae3f6af93c4dbe68b89acc03f72ac7e59 Mon Sep 17 00:00:00 2001 From: vibhatha Date: Tue, 3 Jan 2023 23:11:51 +0530 Subject: [PATCH 09/13] feat(test): adding test case for mismatch input types and args --- python/pyarrow/tests/test_udf.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index b898b7ab5f0..c94cdf7f389 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -325,8 +325,8 @@ def test_registration_errors(): "summary": "test udf input", "description": "parameters are validated" } - in_types = [[pa.int64()]] - in_names = ["scalar"] + in_types = [] + in_names = [] out_types = [pa.int64()] def test_reg_function(context): @@ -334,7 +334,7 @@ def test_reg_function(context): with pytest.raises(TypeError): pc.register_scalar_function(test_reg_function, - None, doc, in_types, + None, doc, in_types, in_names, out_types) # validate function @@ -369,6 +369,25 @@ def test_reg_function(context): "test_reg_function", doc, [[]], [], out_types) + # mismatching input_types and input_names + + doc = { + "summary": "test udf input types and names", + "description": "parameters are validated" + } + in_types = [[pa.int64()]] + in_names = ["a1", "a2"] + out_types = [pa.int64()] + + def test_inputs_function(context, a1): + return pc.add(a1, a1) + expected_expr = "in_arg_names and input types per kernel must contain same " \ + + "number of elements" + with pytest.raises(ValueError, match=expected_expr): + pc.register_scalar_function(test_inputs_function, + "test_inputs_function", doc, in_types, in_names, + out_types) + def test_varargs_function_validation(varargs_func_fixture): _, func_name = varargs_func_fixture From 73c3d06efa58a6cb6af046a5b1a55fb2f074b4df Mon Sep 17 00:00:00 2001 From: vibhatha Date: Tue, 3 Jan 2023 23:30:25 +0530 Subject: [PATCH 10/13] fix(format) --- python/pyarrow/tests/test_udf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index c94cdf7f389..c1f26170aff 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -381,12 +381,12 @@ def test_reg_function(context): def test_inputs_function(context, a1): return pc.add(a1, a1) - expected_expr = "in_arg_names and input types per kernel must contain same " \ - + "number of elements" + expected_expr = "in_arg_names and input types per kernel " \ + + "must contain same number of elements" with pytest.raises(ValueError, match=expected_expr): pc.register_scalar_function(test_inputs_function, - "test_inputs_function", doc, in_types, in_names, - out_types) + "test_inputs_function", doc, + in_types, in_names, out_types) def test_varargs_function_validation(varargs_func_fixture): From e6bd3bcfced504845a56196a1471bc1397e1bc41 Mon Sep 17 00:00:00 2001 From: vibhatha Date: Tue, 3 Jan 2023 23:32:02 +0530 Subject: [PATCH 11/13] fix(format): v2 --- python/pyarrow/tests/test_udf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index c1f26170aff..608aa411d4f 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -325,7 +325,7 @@ def test_registration_errors(): "summary": "test udf input", "description": "parameters are validated" } - in_types = [] + in_types = [[]] in_names = [] out_types = [pa.int64()] From 37df214547e7ba3bc2c8e6aa1c9184e9a2588608 Mon Sep 17 00:00:00 2001 From: vibhatha Date: Wed, 4 Jan 2023 06:55:40 +0530 Subject: [PATCH 12/13] fix(test): error message updated --- python/pyarrow/tests/test_udf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 608aa411d4f..96b6b751d9e 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -350,7 +350,7 @@ def test_reg_function(context): in_names, None) # validate input type - expected_expr = "in_arg_types must be a list of DataType" + expected_expr = "in_arg_types must be a list of list of DataType" with pytest.raises(TypeError, match=expected_expr): pc.register_scalar_function(test_reg_function, "test_input_function", doc, None, From 96842d832dde72ab2e0e028d2c89efde808f8b5c Mon Sep 17 00:00:00 2001 From: vibhatha Date: Wed, 4 Jan 2023 06:57:44 +0530 Subject: [PATCH 13/13] fix(format) --- python/pyarrow/src/arrow/python/udf.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index c105107826a..9d819ad2308 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -107,7 +107,7 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback return Status::Invalid("input_arg_types and output_types should be equal in size"); } // adding kernels - for(size_t idx=0 ; idx < num_kernels; idx++) { + for(size_t idx=0; idx < num_kernels; idx++) { const auto& opt_input_types = options.input_arg_types[idx]; std::vector input_types; for (const auto& in_dtype : opt_input_types) {