diff --git a/cpp/src/arrow/python/common.h b/cpp/src/arrow/python/common.h
index 8560fa2d6f4..24dcb130a26 100644
--- a/cpp/src/arrow/python/common.h
+++ b/cpp/src/arrow/python/common.h
@@ -185,6 +185,79 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef {
   }
 };
 
+// BoundFunction adapts a C function pointer taking a leading PyObject* into a
+// callable with that argument pre-bound. Only the specializations below exist.
+template <typename Fn>
+struct BoundFunction;
+
+template <typename... Args>
+struct BoundFunction<void(PyObject*, Args...)> {
+  // We bind `cdef void fn(object, ...)` to get a `Status(...)`
+  // where the Status contains any Python error raised by `fn`
+  using Unbound = void(PyObject*, Args...);
+  using Bound = Status(Args...);
+
+  // Initializers follow member declaration order, which is the order in which
+  // members are actually initialized (keeps -Wreorder quiet).
+  BoundFunction(Unbound* unbound, PyObject* bound_arg)
+      : unbound_(unbound), bound_arg_(bound_arg) {}
+
+  // Calls `unbound_` with the bound object under PyAcquireGIL, then turns a
+  // pending Python error into an error Status via RETURN_IF_PYERROR.
+  Status Invoke(Args... args) const {
+    PyAcquireGIL lock;
+    unbound_(bound_arg_.obj(), std::forward<Args>(args)...);
+    RETURN_IF_PYERROR();
+    return Status::OK();
+  }
+
+  Unbound* unbound_;
+  OwnedRefNoGIL bound_arg_;
+};
+
+template <typename Return, typename... Args>
+struct BoundFunction<Return(PyObject*, Args...)> {
+  // We bind `cdef Return fn(object, ...)` to get a `Result<Return>(...)`
+  // where the Result contains any Python error raised by `fn` or the
+  // return value from `fn`.
+  using Unbound = Return(PyObject*, Args...);
+  using Bound = Result<Return>(Args...);
+
+  // Initializers follow member declaration order (see -Wreorder).
+  BoundFunction(Unbound* unbound, PyObject* bound_arg)
+      : unbound_(unbound), bound_arg_(bound_arg) {}
+
+  // Same as the void specialization, but hands back the value produced by
+  // `unbound_` when no Python error is pending.
+  Result<Return> Invoke(Args... args) const {
+    PyAcquireGIL lock;
+    Return ret = unbound_(bound_arg_.obj(), std::forward<Args>(args)...);
+    RETURN_IF_PYERROR();
+    return ret;
+  }
+
+  Unbound* unbound_;
+  OwnedRefNoGIL bound_arg_;
+};
+
+// Wraps `unbound` and `bound_arg` into a std::function<OutFn>. OutFn must be
+// exactly the `Bound` signature of the matching BoundFunction specialization.
+template <typename OutFn, typename Return, typename... Args>
+std::function<OutFn> BindFunction(Return (*unbound)(PyObject*, Args...),
+                                  PyObject* bound_arg) {
+  using Fn = BoundFunction<Return(PyObject*, Args...)>;
+
+  static_assert(std::is_same<typename Fn::Bound, OutFn>::value,
+                "requested bound function of unsupported type");
+
+  // Take a reference on behalf of the BoundFunction's OwnedRefNoGIL member
+  // (presumably released when the last copy of the returned function dies).
+  Py_XINCREF(bound_arg);
+  auto bound_fn = std::make_shared<Fn>(unbound, bound_arg);
+  return
+      [bound_fn](Args... args) { return bound_fn->Invoke(std::forward<Args>(args)...); };
+}
+
 // A temporary conversion of a Python object to a bytes area.
struct PyBytesView {
   const char* bytes;
diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py
index fefd7b02ed4..7fef9edb4b9 100644
--- a/dev/archery/archery/cli.py
+++ b/dev/archery/archery/cli.py
@@ -115,7 +115,6 @@ def _apply_options(cmd, options):
               help="Specify Arrow source directory")
 # toolchain
 @cpp_toolchain_options
-@java_toolchain_options
 @click.option("--build-type", default=None, type=build_type,
               help="CMake's CMAKE_BUILD_TYPE")
 @click.option("--warn-level", default="production", type=warn_level_type,
diff --git a/python/pyarrow/includes/common.pxd b/python/pyarrow/includes/common.pxd
index 3f67a3256cc..902eaafbbbd 100644
--- a/python/pyarrow/includes/common.pxd
+++ b/python/pyarrow/includes/common.pxd
@@ -128,6 +128,8 @@ cdef extern from "arrow/result.h" namespace "arrow" nogil:
 
 cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
     T GetResultValue[T](CResult[T]) except *
+    # F is the bound signature; `bound` becomes the leading object argument
+    cdef function[F] BindFunction[F](void* unbound, object bound, ...)
 
 
 cdef inline object PyObject_to_object(PyObject* o):
diff --git a/python/pyarrow/tests/bound_function_visit_strings.pyx b/python/pyarrow/tests/bound_function_visit_strings.pyx
new file mode 100644
index 00000000000..90437be8cde
--- /dev/null
+++ b/python/pyarrow/tests/bound_function_visit_strings.pyx
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language=c++
+# cython: language_level = 3
+
+import pyarrow as pa
+from pyarrow.lib cimport *
+from pyarrow.lib import frombytes, tobytes
+
+# basic test to roundtrip through a BoundFunction
+
+ctypedef CStatus visit_string_cb(const c_string&)
+
+cdef extern from * namespace "arrow::py" nogil:
+    """
+    #include <functional>
+    #include <string>
+    #include <vector>
+
+    #include "arrow/status.h"
+
+    namespace arrow {
+    namespace py {
+
+    Status VisitStrings(const std::vector<std::string>& strs,
+                        std::function<Status(const std::string&)> cb) {
+      for (const std::string& str : strs) {
+        RETURN_NOT_OK(cb(str));
+      }
+      return Status::OK();
+    }
+
+    }  // namespace py
+    }  // namespace arrow
+    """
+    cdef CStatus CVisitStrings" arrow::py::VisitStrings"(
+        vector[c_string], function[visit_string_cb])
+
+
+# Trampoline handed to BindFunction: receives the bound Python callable
+# first, then the C++ string, which is passed on through frombytes().
+cdef void _visit_strings_impl(py_cb, const c_string& s) except *:
+    py_cb(frombytes(s))
+
+
+def _visit_strings(strings, cb):
+    """
+    Call `cb` on each element of `strings` via the C++ VisitStrings helper.
+
+    `cb` is wrapped with BindFunction; check_status() is applied to the
+    Status returned by VisitStrings.
+    """
+    cdef:
+        function[visit_string_cb] c_cb
+        vector[c_string] c_strings
+
+    c_cb = BindFunction[visit_string_cb](&_visit_strings_impl, cb)
+    for s in strings:
+        c_strings.push_back(tobytes(s))
+
+    check_status(CVisitStrings(c_strings, c_cb))
diff --git a/python/pyarrow/tests/test_cython.py b/python/pyarrow/tests/test_cython.py
index b852981ba39..e202b417a18 100644
--- a/python/pyarrow/tests/test_cython.py
+++ b/python/pyarrow/tests/test_cython.py
@@ -27,6 +27,11 @@
 
 
 here = os.path.dirname(os.path.abspath(__file__))
+test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '')
+if os.name == 'posix':
+    compiler_opts = ['-std=c++11']
+else:
+    compiler_opts = []
 
 
 setup_template = """if 1:
@@ -82,18 +87,12 @@
def test_cython_api(tmpdir):
     # Fail early if cython is not found
     import cython  # noqa
 
-    test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '')
-
     with tmpdir.as_cwd():
         # Set up temporary workspace
         pyx_file = 'pyarrow_cython_example.pyx'
         shutil.copyfile(os.path.join(here, pyx_file),
                         os.path.join(str(tmpdir), pyx_file))
         # Create setup.py file
-        if os.name == 'posix':
-            compiler_opts = ['-std=c++11']
-        else:
-            compiler_opts = []
         setup_code = setup_template.format(pyx_file=pyx_file,
                                            compiler_opts=compiler_opts,
                                            test_ld_path=test_ld_path)
@@ -141,3 +140,52 @@ def test_cython_api(tmpdir):
         subprocess.check_call([sys.executable, '-c', code],
                               stdout=subprocess.PIPE,
                               env=subprocess_env)
+
+
+@pytest.mark.cython
+def test_visit_strings(tmpdir):
+    """
+    Build the bound_function_visit_strings extension and check that a
+    Python callback bound with BindFunction is invoked for every string
+    and that an exception it raises reaches the caller.
+    """
+    with tmpdir.as_cwd():
+        # Set up temporary workspace
+        pyx_file = 'bound_function_visit_strings.pyx'
+        shutil.copyfile(os.path.join(here, pyx_file),
+                        os.path.join(str(tmpdir), pyx_file))
+        # Create setup.py file
+        setup_code = setup_template.format(pyx_file=pyx_file,
+                                           compiler_opts=compiler_opts,
+                                           test_ld_path=test_ld_path)
+        with open('setup.py', 'w') as f:
+            f.write(setup_code)
+
+        subprocess_env = test_util.get_modified_env_with_pythonpath()
+
+        # Compile extension module
+        subprocess.check_call([sys.executable, 'setup.py',
+                               'build_ext', '--inplace'],
+                              env=subprocess_env)
+
+        # Import the freshly built module without leaving the temporary
+        # directory on sys.path for tests that run later in the session.
+        sys.path.insert(0, str(tmpdir))
+        try:
+            mod = __import__('bound_function_visit_strings')
+        finally:
+            sys.path.remove(str(tmpdir))
+
+        strings = ['a', 'b', 'c']
+        visited = []
+        mod._visit_strings(strings, visited.append)
+
+        assert visited == strings
+
+        def raise_on_b(s):
+            if s == 'b':
+                raise ValueError('wtf')
+
+        # Only the call under test belongs inside the raises() block
+        with pytest.raises(ValueError, match="wtf"):
+            mod._visit_strings(strings, raise_on_b)