diff --git a/r/R/compute.R b/r/R/compute.R index a144e7d678a..1386728ac90 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -379,9 +379,17 @@ register_scalar_function <- function(name, fun, in_type, out_type, RegisterScalarUDF(name, scalar_function) # register with dplyr binding (enables its use in mutate(), filter(), etc.) + binding_fun <- function(...) build_expr(name, ...) + + # inject the value of `name` into the expression to avoid saving this + # execution environment in the binding, which eliminates a warning when the + # same binding is registered twice + body(binding_fun) <- expr_substitute(body(binding_fun), sym("name"), name) + environment(binding_fun) <- asNamespace("arrow") + register_binding( name, - function(...) build_expr(name, ...), + binding_fun, update_cache = TRUE ) diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R index e5f76570616..ee64a09918d 100644 --- a/r/R/dplyr-funcs.R +++ b/r/R/dplyr-funcs.R @@ -75,7 +75,7 @@ register_binding <- function(fun_name, previous_fun <- registry[[unqualified_name]] # if the unqualified name exists in the registry, warn - if (!is.null(previous_fun)) { + if (!is.null(previous_fun) && !identical(fun, previous_fun)) { warn( paste0( "A \"", diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 1ed949e7295..0bfc5172852 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -611,8 +611,8 @@ class RScalarUDFKernelState : public arrow::compute::KernelState { RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver) : exec_func_(exec_func), resolver_(resolver) {} - cpp11::function exec_func_; - cpp11::function resolver_; + cpp11::sexp exec_func_; + cpp11::sexp resolver_; }; arrow::Result ResolveScalarUDFOutputType( @@ -630,7 +630,8 @@ arrow::Result ResolveScalarUDFOutputType( cpp11::to_r6(input_types[i].GetSharedPtr()); } - cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); + cpp11::sexp output_type_sexp = + cpp11::function(state->resolver_)(input_types_sexp); if (!Rf_inherits(output_type_sexp, "DataType")) { cpp11::stop( "Function specified as arrow_scalar_function() out_type argument must " @@ -674,7 +675,8 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context, cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp}; udf_context.names() = {"batch_length", "output_type"}; - cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); + cpp11::sexp func_result_sexp = + cpp11::function(state->exec_func_)(udf_context, args_sexp); if (Rf_inherits(func_result_sexp, "Array")) { auto array = cpp11::as_cpp>(func_result_sexp); diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp index d0c52acc416..8e9df121748 100644 --- a/r/src/recordbatchreader.cpp +++ b/r/src/recordbatchreader.cpp @@ -70,7 +70,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader { arrow::Status ReadNext(std::shared_ptr* batch_out) { auto batch = SafeCallIntoR>([&]() { - cpp11::sexp result_sexp = fun_(); + cpp11::sexp result_sexp = cpp11::function(fun_)(); if (result_sexp == R_NilValue) { return std::shared_ptr(nullptr); } else if (!Rf_inherits(result_sexp, "RecordBatch")) { @@ -94,7 +94,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader { } private: - cpp11::function fun_; + cpp11::sexp fun_; std::shared_ptr schema_; }; diff --git a/r/tests/testthat/test-dplyr-funcs.R b/r/tests/testthat/test-dplyr-funcs.R index 86f984dd32c..48b74c9af43 100644 --- a/r/tests/testthat/test-dplyr-funcs.R +++ b/r/tests/testthat/test-dplyr-funcs.R @@ -35,6 +35,9 @@ test_that("register_binding()/unregister_binding() works", { register_binding("some.pkg2::some_fun", fun2, fake_registry), "A \"some_fun\" binding already exists in the registry and will be overwritten." ) + + # No warning when an identical function is re-registered + expect_silent(register_binding("some.pkg2::some_fun", fun2, fake_registry)) }) test_that("register_binding_agg() works", {