diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index c75c5bf189b..b41a4bb105a 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2591,8 +2591,8 @@ def _get_scalar_udf_context(memory_pool, batch_length):
     return context
 
 
-def register_scalar_function(func, function_name, function_doc, in_types,
-                             out_type):
+def register_scalar_function(func, function_name, function_doc, in_arg_types,
+                             in_arg_names, out_types):
     """
     Register a user-defined scalar function.
 
@@ -2624,15 +2624,17 @@ def register_scalar_function(func, function_name, function_doc, in_types,
     function_doc : dict
         A dictionary object with keys "summary" (str), and
         "description" (str).
-    in_types : Dict[str, DataType]
-        A dictionary mapping function argument names to
-        their respective DataType.
-        The argument names will be used to generate
-        documentation for the function. The number of
-        arguments specified here determines the function
-        arity.
-    out_type : DataType
-        Output type of the function.
+    in_arg_types : List[List[DataType]]
+        A list of lists of DataTypes, where each inner list holds
+        the input types of one kernel and must contain one entry
+        per argument.
+    in_arg_names : List[str]
+        A list of argument names, used to generate the function
+        documentation; their number determines the function arity.
+    out_types : List[DataType]
+        A list of output types, one per entry in ``in_arg_types``;
+        the output type of the function can thus vary between
+        kernels.
 
     Examples
     --------
@@ -2647,10 +2649,11 @@ def register_scalar_function(func, function_name, function_doc, in_types,
     ...     return pc.add(array, 1, memory_pool=ctx.memory_pool)
     >>>
     >>> func_name = "py_add_func"
-    >>> in_types = {"array": pa.int64()}
-    >>> out_type = pa.int64()
+    >>> in_types = [[pa.int64()]]
+    >>> in_names = ["array"]
+    >>> out_types = [pa.int64()]
     >>> pc.register_scalar_function(add_constant, func_name, func_doc,
-    ...                             in_types, out_type)
+    ...                             in_types, in_names, out_types)
     >>>
     >>> func = pc.get_function(func_name)
     >>> func.name
@@ -2666,9 +2669,10 @@ def register_scalar_function(func, function_name, function_doc, in_types,
         c_string c_func_name
         CArity c_arity
         CFunctionDoc c_func_doc
+        vector[vector[shared_ptr[CDataType]]] vec_c_in_types
        vector[shared_ptr[CDataType]] c_in_types
         PyObject* c_function
-        shared_ptr[CDataType] c_out_type
+        vector[shared_ptr[CDataType]] c_out_types
         CScalarUdfOptions c_options
 
     if callable(func):
@@ -2680,15 +2684,32 @@ def register_scalar_function(func, function_name, function_doc, in_types,
 
     func_spec = inspect.getfullargspec(func)
     num_args = -1
-    if isinstance(in_types, dict):
-        for in_type in in_types.values():
-            c_in_types.push_back(
-                pyarrow_unwrap_data_type(ensure_type(in_type)))
-        function_doc["arg_names"] = in_types.keys()
-        num_args = len(in_types)
-    else:
+    if not isinstance(in_arg_types, list):
+        raise TypeError(
+            "in_arg_types must be a list of lists of DataType")
+    if not isinstance(in_arg_names, list):
         raise TypeError(
-            "in_types must be a dictionary of DataType")
+            "in_arg_names must be a list of str")
+    if not isinstance(out_types, list):
+        raise TypeError("out_types must be a list of DataType")
+
+    function_doc["arg_names"] = in_arg_names
+    num_args = len(in_arg_names)
+
+    for in_types in in_arg_types:
+        if isinstance(in_types, list):
+            if len(in_arg_names) != len(in_types):
+                raise ValueError(
+                    "in_arg_names and input types per kernel must "
+                    "contain the same number of elements")
+            for in_type in in_types:
+                c_in_types.push_back(
+                    pyarrow_unwrap_data_type(ensure_type(in_type)))
+        else:
+            raise TypeError(
+                "Each element of in_arg_types must be a list of DataType")
+        vec_c_in_types.push_back(move(c_in_types))
+        c_in_types.clear()  # reset the moved-from vector for the next kernel
 
     c_arity = CArity(
         num_args, func_spec.varargs)
@@ -2702,14 +2723,15 @@ def register_scalar_function(func, function_name, function_doc, in_types,
         raise ValueError("Function doc must contain arg_names")
 
     c_func_doc = _make_function_doc(function_doc)
-
-    c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type))
+    for out_type in out_types:
+        c_out_types.push_back(
+            pyarrow_unwrap_data_type(ensure_type(out_type)))
 
     c_options.func_name = c_func_name
     c_options.arity = c_arity
     c_options.func_doc = c_func_doc
-    c_options.input_types = c_in_types
-    c_options.output_type = c_out_type
+    c_options.input_arg_types = vec_c_in_types
+    c_options.output_types = c_out_types
 
     check_status(RegisterScalarFunction(c_function, &_scalar_udf_callback,
                                         c_options))
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index df6a883afe9..326bef9063d 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2814,8 +2814,8 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py":
         c_string func_name
         CArity arity
         CFunctionDoc func_doc
-        vector[shared_ptr[CDataType]] input_types
-        shared_ptr[CDataType] output_type
+        vector[vector[shared_ptr[CDataType]]] input_arg_types
+        vector[shared_ptr[CDataType]] output_types
 
     CStatus RegisterScalarFunction(PyObject* function,
                                    function[CallbackUdf] wrapper, const CScalarUdfOptions& options)
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 81bf47c0ade..9d819ad2308 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -99,22 +99,34 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback
   auto scalar_func = std::make_shared<compute::ScalarFunction>(
       options.func_name, options.arity, options.func_doc);
   Py_INCREF(user_function);
-  std::vector<compute::InputType> input_types;
-  for (const auto& in_dtype : options.input_types) {
-    input_types.emplace_back(in_dtype);
+
+  // a single owned reference to the Python function is shared by all kernels,
+  // matching the single Py_INCREF above
+  auto py_function = std::make_shared<OwnedRefNoGIL>(user_function);
+  const size_t num_kernels = options.input_arg_types.size();
+  // number of input_type variations and output_types must be
+  // equal in size
+  if (num_kernels != options.output_types.size()) {
+    return Status::Invalid("input_arg_types and output_types should be equal in size");
+  }
+  // adding kernels
+  for (size_t idx = 0; idx < num_kernels; idx++) {
+    const auto& opt_input_types = options.input_arg_types[idx];
+    std::vector<compute::InputType> input_types;
+    for (const auto& in_dtype : opt_input_types) {
+      input_types.emplace_back(in_dtype);
+    }
+    const auto& opts_out_type = options.output_types[idx];
+    compute::OutputType output_type(opts_out_type);
+    auto udf_data = std::make_shared<PythonUdf>(wrapper, py_function, opts_out_type);
+    compute::ScalarKernel kernel(
+        compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
+                                       options.arity.is_varargs), PythonUdfExec);
+    kernel.data = std::move(udf_data);
+
+    kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
+    kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
+    RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
   }
-  compute::OutputType output_type(options.output_type);
-  auto udf_data = std::make_shared<PythonUdf>(
-      wrapper, std::make_shared<OwnedRefNoGIL>(user_function), options.output_type);
-  compute::ScalarKernel kernel(
-      compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
-                                     options.arity.is_varargs),
-      PythonUdfExec);
-  kernel.data = std::move(udf_data);
-
-  kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
-  kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
-  RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
   auto registry = compute::GetFunctionRegistry();
   RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
   return Status::OK();
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index 9a3666459fd..2952d890a5d 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -37,8 +37,8 @@ struct ARROW_PYTHON_EXPORT ScalarUdfOptions {
   std::string func_name;
   compute::Arity arity;
   compute::FunctionDoc func_doc;
-  std::vector<std::shared_ptr<DataType>> input_types;
-  std::shared_ptr<DataType> output_type;
+  std::vector<std::vector<std::shared_ptr<DataType>>> input_arg_types;
+  std::vector<std::shared_ptr<DataType>> output_types;
 };
 
 /// \brief A context passed as the first argument of scalar UDF functions.
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index e711619582d..96b6b751d9e 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -54,8 +54,9 @@ def unary_function(ctx, x):
     pc.register_scalar_function(unary_function,
                                 func_name,
                                 unary_doc,
-                                {"array": pa.int64()},
-                                pa.int64())
+                                [[pa.int64()]],
+                                ["array"],
+                                [pa.int64()])
 
     return unary_function, func_name
 
@@ -73,10 +74,13 @@ def binary_function(ctx, m, x):
     pc.register_scalar_function(binary_function,
                                 func_name,
                                 binary_doc,
-                                {"m": pa.int64(),
-                                 "x": pa.int64(),
-                                 },
-                                pa.int64())
+                                [
+                                    [pa.int64(),
+                                     pa.int64(),
+                                     ]
+                                ],
+                                ["m", "x"],
+                                [pa.int64()])
 
     return binary_function, func_name
 
@@ -96,12 +100,15 @@ def ternary_function(ctx, m, x, c):
     pc.register_scalar_function(ternary_function,
                                 func_name,
                                 ternary_doc,
-                                {
-                                    "array1": pa.int64(),
-                                    "array2": pa.int64(),
-                                    "array3": pa.int64(),
-                                },
-                                pa.int64())
+                                [
+                                    [
+                                        pa.int64(),
+                                        pa.int64(),
+                                        pa.int64(),
+                                    ]
+                                ],
+                                ["array1", "array2", "array3"],
+                                [pa.int64()])
 
     return ternary_function, func_name
 
@@ -123,11 +130,14 @@ def varargs_function(ctx, first, *values):
     pc.register_scalar_function(varargs_function,
                                 func_name,
                                 varargs_doc,
-                                {
-                                    "array1": pa.int64(),
-                                    "array2": pa.int64(),
-                                },
-                                pa.int64())
+                                [
+                                    [
+                                        pa.int64(),
+                                        pa.int64(),
+                                    ]
+                                ],
+                                ["array1", "array2"],
+                                [pa.int64()])
 
     return varargs_function, func_name
 
@@ -148,8 +158,9 @@ def nullary_func(context):
     pc.register_scalar_function(nullary_func,
                                 func_name,
                                 func_doc,
-                                {},
-                                pa.int64())
+                                [[]],
+                                [],
+                                [pa.int64()])
 
     return nullary_func, func_name
 
@@ -164,14 +175,15 @@ def wrong_output_type(ctx):
         return 42
 
     func_name = "test_wrong_output_type"
-    in_types = {}
-    out_type = pa.int64()
+    in_types = [[]]
+    in_names = []
+    out_types = [pa.int64()]
     doc = {
         "summary": "return wrong output type",
         "description": ""
     }
     pc.register_scalar_function(wrong_output_type, func_name, doc,
-                                in_types, out_type)
+                                in_types, in_names, out_types)
 
     return wrong_output_type, func_name
 
@@ -184,15 +196,16 @@ def wrong_output_datatype_func_fixture():
     def wrong_output_datatype(ctx, array):
         return pc.call_function("add", [array, 1])
     func_name = "test_wrong_output_datatype"
-    in_types = {"array": pa.int64()}
+    in_types = [[pa.int64()]]
+    in_names = ["array"]
     # The actual output DataType will be int64.
-    out_type = pa.int16()
+    out_types = [pa.int16()]
     doc = {
         "summary": "return wrong output datatype",
         "description": ""
     }
     pc.register_scalar_function(wrong_output_datatype, func_name, doc,
-                                in_types, out_type)
+                                in_types, in_names, out_types)
 
     return wrong_output_datatype, func_name
 
@@ -206,14 +219,15 @@ def wrong_signature():
         return pa.scalar(1, type=pa.int64())
 
     func_name = "test_wrong_signature"
-    in_types = {}
-    out_type = pa.int64()
+    in_types = [[]]
+    in_names = []
+    out_types = [pa.int64()]
     doc = {
         "summary": "UDF with wrong signature",
         "description": ""
     }
     pc.register_scalar_function(wrong_signature, func_name, doc,
-                                in_types, out_type)
+                                in_types, in_names, out_types)
 
     return wrong_signature, func_name
 
@@ -230,7 +244,7 @@ def raising_func(ctx):
         "description": ""
     }
     pc.register_scalar_function(raising_func, func_name, doc,
-                                {}, pa.int64())
+                                [[]], [], [pa.int64()])
 
     return raising_func, func_name
 
@@ -311,48 +325,68 @@ def test_registration_errors():
         "summary": "test udf input",
         "description": "parameters are validated"
     }
-    in_types = {"scalar": pa.int64()}
-    out_type = pa.int64()
+    in_types = [[]]
+    in_names = []
+    out_types = [pa.int64()]
 
     def test_reg_function(context):
         return pa.array([10])
 
     with pytest.raises(TypeError):
         pc.register_scalar_function(test_reg_function,
-                                    None, doc, in_types,
-                                    out_type)
+                                    None, doc, in_types, in_names,
+                                    out_types)
 
     # validate function
     with pytest.raises(TypeError, match="func must be a callable"):
         pc.register_scalar_function(None, "test_none_function",
                                     doc, in_types,
-                                    out_type)
+                                    in_names, out_types)
 
     # validate output type
-    expected_expr = "DataType expected, got "
+    expected_expr = "out_types must be a list of DataType"
     with pytest.raises(TypeError, match=expected_expr):
         pc.register_scalar_function(test_reg_function,
                                     "test_output_function", doc, in_types,
-                                    None)
+                                    in_names, None)
 
     # validate input type
-    expected_expr = "in_types must be a dictionary of DataType"
+    expected_expr = "in_arg_types must be a list of lists of DataType"
     with pytest.raises(TypeError, match=expected_expr):
         pc.register_scalar_function(test_reg_function,
                                     "test_input_function", doc, None,
-                                    out_type)
+                                    in_names, out_types)
 
     # register an already registered function
     # first registration
     pc.register_scalar_function(test_reg_function,
-                                "test_reg_function", doc, {},
-                                out_type)
+                                "test_reg_function", doc, [[]], [],
+                                out_types)
 
     # second registration
     expected_expr = "Already have a function registered with name:" \
         + " test_reg_function"
     with pytest.raises(KeyError, match=expected_expr):
         pc.register_scalar_function(test_reg_function,
-                                    "test_reg_function", doc, {},
-                                    out_type)
+                                    "test_reg_function", doc, [[]], [],
+                                    out_types)
+
+    # mismatching input_types and input_names
+
+    doc = {
+        "summary": "test udf input types and names",
+        "description": "parameters are validated"
+    }
+    in_types = [[pa.int64()]]
+    in_names = ["a1", "a2"]
+    out_types = [pa.int64()]
+
+    def test_inputs_function(context, a1):
+        return pc.add(a1, a1)
+    expected_expr = "in_arg_names and input types per kernel " \
+        + "must contain the same number of elements"
+    with pytest.raises(ValueError, match=expected_expr):
+        pc.register_scalar_function(test_inputs_function,
+                                    "test_inputs_function", doc,
+                                    in_types, in_names, out_types)
 
 
 def test_varargs_function_validation(varargs_func_fixture):
@@ -366,8 +400,9 @@ def test_varargs_function_validation(varargs_func_fixture):
 
 def test_function_doc_validation():
     # validate arity
-    in_types = {"scalar": pa.int64()}
-    out_type = pa.int64()
+    in_types = [[pa.int64()]]
+    in_names = ["scalar"]
+    out_type = [pa.int64()]
 
     # doc with no summary
     func_doc = {
@@ -380,7 +415,7 @@ def add_const(ctx, scalar):
     with pytest.raises(ValueError,
                        match="Function doc must contain a summary"):
         pc.register_scalar_function(add_const, "test_no_summary",
-                                    func_doc, in_types,
+                                    func_doc, in_types, in_names,
                                     out_type)
 
     # doc with no decription
     func_doc = {
         "summary": "test summary",
     }
@@ -391,12 +426,12 @@ def add_const(ctx, scalar):
     with pytest.raises(ValueError,
                        match="Function doc must contain a description"):
         pc.register_scalar_function(add_const, "test_no_desc",
-                                    func_doc, in_types,
+                                    func_doc, in_types, in_names,
                                     out_type)
 
 
 def test_nullary_function(nullary_func_fixture):
-    # XXX the Python compute layer API doesn't let us override batch_length,
+    # the Python compute layer API doesn't let us override batch_length,
     # so only test with the default value of 1.
     check_scalar_function(nullary_func_fixture, [],
                           run_in_dataset=False, batch_length=1)
@@ -435,8 +470,9 @@ def identity(ctx, val):
         return val
 
     func_name = "test_wrong_datatype_declaration"
-    in_types = {"array": pa.int64()}
-    out_type = {}
+    in_types = [[pa.int64()]]
+    in_names = ["array"]
+    out_type = [{}]
     doc = {
         "summary": "test output value",
         "description": "test output"
     }
     with pytest.raises(TypeError, match="DataType expected, got "):
         pc.register_scalar_function(identity, func_name,
-                                doc, in_types, out_type)
+                                doc, in_types, in_names, out_type)
 
 
 def test_wrong_input_type_declaration():
@@ -452,8 +488,9 @@ def identity(ctx, val):
         return val
 
     func_name = "test_wrong_input_type_declaration"
-    in_types = {"array": None}
-    out_type = pa.int64()
+    in_types = [[None]]
+    in_names = ["array"]
+    out_types = [pa.int64()]
     doc = {
         "summary": "test invalid input type",
         "description": "invalid input function"
     }
     with pytest.raises(TypeError, match="DataType expected, got "):
         pc.register_scalar_function(identity, func_name, doc,
-                                in_types, out_type)
+                                in_types, in_names, out_types)
 
 
 def test_udf_context(unary_func_fixture):
@@ -504,3 +541,70 @@ def test_input_lifetime(unary_func_fixture):
     # Calling a UDF should not have kept `v` alive longer than required
     v = None
     assert proxy_pool.bytes_allocated() == 0
+
+
+def test_multi_kernel_registration():
+    def unary_function(ctx, x):
+        return pc.cast(pc.call_function("multiply", [x, 2],
+                                        memory_pool=ctx.memory_pool), x.type)
+    func_name = "y=x*2"
+    unary_doc = {"summary": "multiply by two function",
+                 "description": "test multiply function"}
+    input_types = [
+        [pa.int8()],
+        [pa.int16()],
+        [pa.int32()],
+        [pa.int64()],
+        [pa.float32()],
+        [pa.float64()]
+    ]
+
+    input_names = ["array"]
+
+    output_types = [
+        pa.int8(),
+        pa.int16(),
+        pa.int32(),
+        pa.int64(),
+        pa.float32(),
+        pa.float64()
+    ]
+    pc.register_scalar_function(unary_function,
+                                func_name,
+                                unary_doc,
+                                input_types,
+                                input_names,
+                                output_types)
+
+    for out_type in output_types:
+        assert pc.call_function(func_name,
+                                [pa.array([10, 20], out_type)]) \
+            == pa.array([20, 40], out_type)
+
+
+def test_invalid_multi_kernel_registration():
+    def unary_function(ctx, x):
+        return x
+    func_name = "y=x"
+    unary_doc = {"summary": "pass value function",
+                 "description": "test function"}
+    input_types = [
+        [pa.int8()],
+        [pa.int16()]
+    ]
+
+    input_names = ["array"]
+
+    output_types = [
+        pa.int8(),
+        pa.int16(),
+        pa.int64()
+    ]
+    error_msg = "input_arg_types and output_types should be equal in size"
+    with pytest.raises(pa.lib.ArrowInvalid, match=error_msg):
+        pc.register_scalar_function(unary_function,
+                                    func_name,
+                                    unary_doc,
+                                    input_types,
+                                    input_names,
+                                    output_types)
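
Reviewer note: below is a minimal end-to-end sketch of the multi-kernel API this diff introduces, mirroring test_multi_kernel_registration above. It assumes the patch is applied to a pyarrow build; the names times_two and "times_two_demo" are illustrative only and not part of the patch.

import pyarrow as pa
import pyarrow.compute as pc

def times_two(ctx, x):
    # Cast back to the input type so the result matches the declared out_type
    return pc.cast(pc.multiply(x, 2, memory_pool=ctx.memory_pool), x.type)

doc = {"summary": "multiply by two", "description": "demo of per-type kernels"}

# One inner list of input types per kernel; out_types lines up index-for-index
pc.register_scalar_function(times_two, "times_two_demo", doc,
                            [[pa.int32()], [pa.float64()]],  # in_arg_types
                            ["array"],                       # in_arg_names
                            [pa.int32(), pa.float64()])      # out_types

# Dispatch picks the kernel whose input types match the argument
assert pc.call_function("times_two_demo",
                        [pa.array([1, 2], pa.int32())]).equals(
    pa.array([2, 4], pa.int32()))

Registering one kernel per concrete type, rather than one generic kernel, lets the function participate in normal kernel dispatch, at the cost of keeping in_arg_types and out_types aligned by index.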