73 changes: 46 additions & 27 deletions python/pyarrow/_compute.pyx
@@ -2591,8 +2591,8 @@ def _get_scalar_udf_context(memory_pool, batch_length):
     return context


-def register_scalar_function(func, function_name, function_doc, in_types,
-                             out_type):
+def register_scalar_function(func, function_name, function_doc, in_arg_types,
+                             in_arg_names, out_types):
     """
     Register a user-defined scalar function.

@@ -2624,15 +2624,17 @@ def register_scalar_function(func, function_name, function_doc, in_types,
         function_doc : dict
             A dictionary object with keys "summary" (str),
             and "description" (str).
-        in_types : Dict[str, DataType]
-            A dictionary mapping function argument names to
-            their respective DataType.
-            The argument names will be used to generate
-            documentation for the function. The number of
-            arguments specified here determines the function
-            arity.
-        out_type : DataType
-            Output type of the function.
+        in_arg_types : List[List[DataType]]
+            A list of lists of DataTypes, one inner list of
+            input types per kernel. The length of each inner
+            list determines the function arity.
+        in_arg_names : List[str]
+            A list of argument names, used to generate the
+            function documentation.
+        out_types : List[DataType]
+            A list of output types of the function, one per
+            kernel, so that the output type can vary with
+            the input types.

         Examples
         --------

Contributor Author (on the in_arg_types description): is this okay? list of list of usage...
@@ -2647,10 +2649,11 @@
     ...     return pc.add(array, 1, memory_pool=ctx.memory_pool)
     >>>
     >>> func_name = "py_add_func"
-    >>> in_types = {"array": pa.int64()}
-    >>> out_type = pa.int64()
+    >>> in_types = [[pa.int64()]]
+    >>> in_names = ["array"]
+    >>> out_types = [pa.int64()]
     >>> pc.register_scalar_function(add_constant, func_name, func_doc,
-    ...                             in_types, out_type)
+    ...                             in_types, in_names, out_types)
     >>>
     >>> func = pc.get_function(func_name)
     >>> func.name
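With the list-based signature, several kernels can be registered on one
function by extending the lists in parallel. A sketch of a two-kernel
variant (illustrative only, not part of this diff; the name
"py_add_func2" is hypothetical, and add_constant/func_doc are the
objects from the example above):

    >>> in_types = [[pa.int64()], [pa.float64()]]
    >>> in_names = ["array"]
    >>> out_types = [pa.int64(), pa.float64()]
    >>> pc.register_scalar_function(add_constant, "py_add_func2", func_doc,
    ...                             in_types, in_names, out_types)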
@@ -2666,9 +2669,10 @@ def register_scalar_function(func, function_name, function_doc, in_types,
         c_string c_func_name
         CArity c_arity
         CFunctionDoc c_func_doc
+        vector[vector[shared_ptr[CDataType]]] vec_c_in_types
         vector[shared_ptr[CDataType]] c_in_types
         PyObject* c_function
-        shared_ptr[CDataType] c_out_type
+        vector[shared_ptr[CDataType]] c_out_types
         CScalarUdfOptions c_options

     if callable(func):
@@ -2680,15 +2684,30 @@

     func_spec = inspect.getfullargspec(func)
     num_args = -1
-    if isinstance(in_types, dict):
-        for in_type in in_types.values():
-            c_in_types.push_back(
-                pyarrow_unwrap_data_type(ensure_type(in_type)))
-        function_doc["arg_names"] = in_types.keys()
-        num_args = len(in_types)
-    else:
-        raise TypeError(
-            "in_types must be a dictionary of DataType")
+    if not isinstance(in_arg_types, list):
+        raise TypeError(
+            "in_arg_types must be a list of list of DataType")
+    if not isinstance(in_arg_names, list):
+        raise TypeError(
+            "in_arg_names must be a list of str")
+    if not isinstance(out_types, list):
+        raise TypeError("out_types must be a list of DataType")
+
+    function_doc["arg_names"] = in_arg_names
+    num_args = len(in_arg_names)
+
+    for in_types in in_arg_types:
+        # reset the per-kernel buffer; otherwise types from the
+        # previous kernel would accumulate across iterations
+        c_in_types.clear()
+        if isinstance(in_types, list):
+            if len(in_arg_names) != len(in_types):
+                raise ValueError(
+                    "in_arg_names and the input types of each kernel "
+                    "must contain the same number of elements")
+            for in_type in in_types:
+                c_in_types.push_back(
+                    pyarrow_unwrap_data_type(ensure_type(in_type)))
+        else:
+            raise TypeError(
+                "Each element of in_arg_types must be a list of DataType")
+        vec_c_in_types.push_back(move(c_in_types))

     c_arity = CArity(<int> num_args, func_spec.varargs)

@@ -2702,14 +2721,14 @@
         raise ValueError("Function doc must contain arg_names")

     c_func_doc = _make_function_doc(function_doc)

-    c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type))
+    for out_type in out_types:
+        c_out_types.push_back(pyarrow_unwrap_data_type(ensure_type(out_type)))

     c_options.func_name = c_func_name
     c_options.arity = c_arity
     c_options.func_doc = c_func_doc
-    c_options.input_types = c_in_types
-    c_options.output_type = c_out_type
+    c_options.input_arg_types = vec_c_in_types
+    c_options.output_types = c_out_types

     check_status(RegisterScalarFunction(c_function,
                                         <function[CallbackUdf]> &_scalar_udf_callback, c_options))
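Taken together, the checks above tie the three lists to one another:
each inner list of in_arg_types must have as many entries as
in_arg_names, and (enforced on the C++ side below) in_arg_types and
out_types must be the same length. A sketch of the failure mode for a
mismatched inner list, under this branch (the function and names are
illustrative):

    import pyarrow as pa
    import pyarrow.compute as pc

    def identity(ctx, array):
        return array

    doc = {"summary": "demo", "description": "demo"}
    # Two input types in the inner list, but only one argument name:
    # expected to raise ValueError under the validation added above.
    pc.register_scalar_function(identity, "bad_arity_udf", doc,
                                [[pa.int64(), pa.int64()]], ["array"],
                                [pa.int64()])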
4 changes: 2 additions & 2 deletions python/pyarrow/includes/libarrow.pxd
@@ -2814,8 +2814,8 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py":
         c_string func_name
         CArity arity
         CFunctionDoc func_doc
-        vector[shared_ptr[CDataType]] input_types
-        shared_ptr[CDataType] output_type
+        vector[vector[shared_ptr[CDataType]]] input_arg_types
+        vector[shared_ptr[CDataType]] output_types

     CStatus RegisterScalarFunction(PyObject* function,
                                    function[CallbackUdf] wrapper, const CScalarUdfOptions& options)
41 changes: 26 additions & 15 deletions python/pyarrow/src/arrow/python/udf.cc
@@ -99,22 +99,33 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback
   auto scalar_func = std::make_shared<compute::ScalarFunction>(
       options.func_name, options.arity, options.func_doc);
   Py_INCREF(user_function);
-  std::vector<compute::InputType> input_types;
-  for (const auto& in_dtype : options.input_types) {
-    input_types.emplace_back(in_dtype);
-  }
-  compute::OutputType output_type(options.output_type);
-  auto udf_data = std::make_shared<PythonUdf>(
-      wrapper, std::make_shared<OwnedRefNoGIL>(user_function), options.output_type);
-  compute::ScalarKernel kernel(
-      compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
-                                     options.arity.is_varargs),
-      PythonUdfExec);
-  kernel.data = std::move(udf_data);
-
-  kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
-  kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
-  RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
+
+  const size_t num_kernels = options.input_arg_types.size();
+  // the number of input-type variations and the number of
+  // output types must be equal
+  if (num_kernels != options.output_types.size()) {
+    return Status::Invalid("input_arg_types and output_types must be equal in size");
+  }
+  // add one kernel per input/output type combination
+  for (size_t idx = 0; idx < num_kernels; idx++) {
+    const auto& opt_input_types = options.input_arg_types[idx];
+    std::vector<compute::InputType> input_types;
+    for (const auto& in_dtype : opt_input_types) {
+      input_types.emplace_back(in_dtype);
+    }
+    const auto& opts_out_type = options.output_types[idx];
+    compute::OutputType output_type(opts_out_type);
+    auto udf_data = std::make_shared<PythonUdf>(
+        wrapper, std::make_shared<OwnedRefNoGIL>(user_function), opts_out_type);
+    compute::ScalarKernel kernel(
+        compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
+                                       options.arity.is_varargs),
+        PythonUdfExec);
+    kernel.data = std::move(udf_data);
+
+    kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
+    kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
+    RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
+  }
   auto registry = compute::GetFunctionRegistry();
   RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
   return Status::OK();
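Since each (input types, output type) pair becomes its own
compute::ScalarKernel on a single ScalarFunction, dispatch needs no new
machinery: the registry's existing kernel matching picks the kernel
whose signature fits the argument types. A sketch, assuming the
hypothetical "py_add_func2" registered with int64 and float64 kernels
in the docstring example above:

    >>> pc.call_function("py_add_func2", [pa.array([1, 2], type=pa.int64())])
    <pyarrow.lib.Int64Array object at ...>
    [
      2,
      3
    ]
    >>> pc.call_function("py_add_func2", [pa.array([1.5], type=pa.float64())])
    <pyarrow.lib.DoubleArray object at ...>
    [
      2.5
    ]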
4 changes: 2 additions & 2 deletions python/pyarrow/src/arrow/python/udf.h
@@ -37,8 +37,8 @@ struct ARROW_PYTHON_EXPORT ScalarUdfOptions {
   std::string func_name;
   compute::Arity arity;
   compute::FunctionDoc func_doc;
-  std::vector<std::shared_ptr<DataType>> input_types;
-  std::shared_ptr<DataType> output_type;
+  std::vector<std::vector<std::shared_ptr<DataType>>> input_arg_types;
+  std::vector<std::shared_ptr<DataType>> output_types;
 };

/// \brief A context passed as the first argument of scalar UDF functions.