Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions cpp/src/arrow/python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,66 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef {
}
};

template <typename Fn>
struct BoundFunction;

template <typename... Args>
struct BoundFunction<void(PyObject*, Args...)> {
// We bind `cdef void fn(object, ...)` to get a `Status(...)`
// where the Status contains any Python error raised by `fn`
using Unbound = void(PyObject*, Args...);
using Bound = Status(Args...);

// Bind `unbound` to `bound_arg`; `bound_arg` is passed back to `unbound`
// as its first argument on every Invoke().
//
// The initializers are listed in member declaration order (unbound_ first,
// then bound_arg_). Members are always initialized in declaration order, so
// listing them differently is misleading and triggers -Wreorder.
BoundFunction(Unbound* unbound, PyObject* bound_arg)
    : unbound_(unbound), bound_arg_(bound_arg) {}

// Call the wrapped function with the bound Python object followed by `args`.
//
// Returns Status::OK() unless RETURN_IF_PYERROR() returns early; per the
// struct-level comment, the returned Status then carries the Python error
// raised by the wrapped function.
Status Invoke(Args... args) const {
// `lock` is declared before anything else so the GIL is taken before the
// wrapped function runs; this may be called from arbitrary C++ code.
PyAcquireGIL lock;
unbound_(bound_arg_.obj(), std::forward<Args>(args)...);
// Must come after the call: the wrapped function returns void, so this
// check is the only way a failure is reported to the caller.
RETURN_IF_PYERROR();
return Status::OK();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is meant to be called from arbitrary C++ code, it should probably take the GIL before anything else. Perhaps reuse SafeCallIntoPython?

Copy link
Member Author

@bkietz bkietz Jun 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SafeCallIntoPython assumes that we're returning a Result<T> or Status, which unbound_ doesn't (it returns T or void).

Additionally, gcc 4.8 cannot expand a parameter pack inside a lambda, so the call cannot simply be wrapped in one.

I'll use PyAcquireGIL instead.


Unbound* unbound_;
OwnedRefNoGIL bound_arg_;
};

// Specialization for wrapped functions that return a value.
template <typename Return, typename... Args>
struct BoundFunction<Return(PyObject*, Args...)> {
  // We bind `cdef Return fn(object, ...)` to get a `Result<Return>(...)`
  // where the Result contains any Python error raised by `fn` or the
  // return value from `fn`.
  using Unbound = Return(PyObject*, Args...);
  using Bound = Result<Return>(Args...);

  // Bind `unbound` to `bound_arg`; `bound_arg` is passed back to `unbound`
  // as its first argument on every Invoke().
  //
  // The initializers follow member declaration order (unbound_, then
  // bound_arg_), which is the order in which they are actually initialized;
  // listing them the other way round triggers -Wreorder.
  BoundFunction(Unbound* unbound, PyObject* bound_arg)
      : unbound_(unbound), bound_arg_(bound_arg) {}

  // Call the wrapped function with the bound Python object followed by
  // `args`, and return its result unless RETURN_IF_PYERROR() returns early.
  Result<Return> Invoke(Args... args) const {
    // Declared first so the GIL is taken before the wrapped function runs.
    PyAcquireGIL lock;
    Return ret = unbound_(bound_arg_.obj(), std::forward<Args>(args)...);
    // Checked after the call so that a Python error takes precedence over
    // whatever value the wrapped function returned.
    RETURN_IF_PYERROR();
    return ret;
  }

  Unbound* unbound_;
  OwnedRefNoGIL bound_arg_;
};

// Wrap `unbound` and `bound_arg` into a std::function<OutFn>.
//
// `OutFn` must be exactly the `Bound` signature of the matching BoundFunction
// specialization (Status(Args...) for void functions, Result<Return>(Args...)
// otherwise); any other signature is rejected at compile time.
template <typename OutFn, typename Return, typename... Args>
std::function<OutFn> BindFunction(Return (*unbound)(PyObject*, Args...),
                                  PyObject* bound_arg) {
  using Binder = BoundFunction<Return(PyObject*, Args...)>;
  static_assert(std::is_same<typename Binder::Bound, OutFn>::value,
                "requested bound function of unsupported type");

  // Take a reference before handing `bound_arg` to the binder
  // (presumably the binder's OwnedRefNoGIL member releases it later --
  // confirm against OwnedRefNoGIL).
  Py_XINCREF(bound_arg);
  std::shared_ptr<Binder> binder = std::make_shared<Binder>(unbound, bound_arg);

  // The closure shares ownership of the binder, so copies of the returned
  // std::function all refer to the same bound state.
  std::function<OutFn> callable = [binder](Args... args) {
    return binder->Invoke(std::forward<Args>(args)...);
  };
  return callable;
}

// A temporary conversion of a Python object to a bytes area.
struct PyBytesView {
const char* bytes;
Expand Down
1 change: 0 additions & 1 deletion dev/archery/archery/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def _apply_options(cmd, options):
help="Specify Arrow source directory")
# toolchain
@cpp_toolchain_options
@java_toolchain_options
@click.option("--build-type", default=None, type=build_type,
help="CMake's CMAKE_BUILD_TYPE")
@click.option("--warn-level", default="production", type=warn_level_type,
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ cdef extern from "arrow/result.h" namespace "arrow" nogil:

cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
T GetResultValue[T](CResult[T]) except *
cdef function[F] BindFunction[F](void* unbound, object bound, ...)


cdef inline object PyObject_to_object(PyObject* o):
Expand Down
68 changes: 68 additions & 0 deletions python/pyarrow/tests/bound_function_visit_strings.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# distutils: language=c++
# cython: language_level = 3

import pyarrow as pa
from pyarrow.lib cimport *
from pyarrow.lib import frombytes, tobytes

# basic test to roundtrip through a BoundFunction

ctypedef CStatus visit_string_cb(const c_string&)

# Inline C++ helper for this test module, exposed to Cython as CVisitStrings.
# VisitStrings calls `cb` on each element of `strs` in order; each call goes
# through RETURN_NOT_OK, so a failing callback Status ends the visit early.
# If every callback succeeds it returns Status::OK().
cdef extern from * namespace "arrow::py" nogil:
"""
#include <functional>
#include <string>
#include <vector>

#include "arrow/status.h"

namespace arrow {
namespace py {

Status VisitStrings(const std::vector<std::string>& strs,
std::function<Status(const std::string&)> cb) {
for (const std::string& str : strs) {
RETURN_NOT_OK(cb(str));
}
return Status::OK();
}

} // namespace py
} // namespace arrow
"""
cdef CStatus CVisitStrings" arrow::py::VisitStrings"(
vector[c_string], function[visit_string_cb])


cdef void _visit_strings_impl(py_cb, const c_string& s) except *:
    # Trampoline handed to BindFunction: convert the C++ string with
    # frombytes() and pass the result to the Python callable `py_cb`.
    # `except *` lets any exception raised by `py_cb` propagate.
    converted = frombytes(s)
    py_cb(converted)


def _visit_strings(strings, cb):
    """
    Call ``cb`` on each element of ``strings`` by way of the C++
    ``VisitStrings`` helper, round-tripping the callback through
    ``BindFunction``.

    The Status returned by the C++ helper is passed to ``check_status``.
    """
    cdef:
        function[visit_string_cb] bound_cb
        vector[c_string] encoded

    # Bind the Python callable to the C-level trampoline.
    bound_cb = BindFunction[visit_string_cb](&_visit_strings_impl, cb)

    # Convert every input with tobytes() before crossing into C++.
    for item in strings:
        encoded.push_back(tobytes(item))

    check_status(CVisitStrings(encoded, bound_cb))
49 changes: 43 additions & 6 deletions python/pyarrow/tests/test_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@


here = os.path.dirname(os.path.abspath(__file__))
# Value of the PYARROW_TEST_LD_PATH environment variable ('' when unset);
# substituted into setup_template by the tests below.
test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '')
# Extra compiler options for the generated setup.py: -std=c++11 on POSIX,
# none elsewhere.
compiler_opts = ['-std=c++11'] if os.name == 'posix' else []


setup_template = """if 1:
Expand Down Expand Up @@ -82,18 +87,12 @@ def test_cython_api(tmpdir):
# Fail early if cython is not found
import cython # noqa

test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '')

with tmpdir.as_cwd():
# Set up temporary workspace
pyx_file = 'pyarrow_cython_example.pyx'
shutil.copyfile(os.path.join(here, pyx_file),
os.path.join(str(tmpdir), pyx_file))
# Create setup.py file
if os.name == 'posix':
compiler_opts = ['-std=c++11']
else:
compiler_opts = []
setup_code = setup_template.format(pyx_file=pyx_file,
compiler_opts=compiler_opts,
test_ld_path=test_ld_path)
Expand Down Expand Up @@ -141,3 +140,41 @@ def test_cython_api(tmpdir):
subprocess.check_call([sys.executable, '-c', code],
stdout=subprocess.PIPE,
env=subprocess_env)


@pytest.mark.cython
def test_visit_strings(tmpdir):
    """Build the BoundFunction example extension and round-trip a callback.

    Checks that a Python callable bound with BindFunction is called once per
    string, in order, and that an exception raised by the callable comes back
    out of the C++ visitor as the original ValueError.
    """
    pyx_file = 'bound_function_visit_strings.pyx'
    workspace = str(tmpdir)

    def raise_on_b(s):
        if s == 'b':
            raise ValueError('wtf')

    with tmpdir.as_cwd():
        # Copy the Cython source into the temporary workspace.
        shutil.copyfile(os.path.join(here, pyx_file),
                        os.path.join(workspace, pyx_file))

        # Write the setup.py used to build the extension.
        setup_code = setup_template.format(pyx_file=pyx_file,
                                           compiler_opts=compiler_opts,
                                           test_ld_path=test_ld_path)
        with open('setup.py', 'w') as setup_file:
            setup_file.write(setup_code)

        # Build the extension in place, in a subprocess.
        subprocess_env = test_util.get_modified_env_with_pythonpath()
        build_cmd = [sys.executable, 'setup.py', 'build_ext', '--inplace']
        subprocess.check_call(build_cmd, env=subprocess_env)

        # Make the freshly built module importable, then load it.
        sys.path.insert(0, workspace)
        mod = __import__('bound_function_visit_strings')

        # Every string is visited exactly once, in order.
        strings = ['a', 'b', 'c']
        visited = []
        mod._visit_strings(strings, visited.append)
        assert visited == strings

        # An exception raised inside the callback propagates to the caller.
        with pytest.raises(ValueError, match="wtf"):
            mod._visit_strings(strings, raise_on_b)