diff --git a/r/R/dplyr-eval.R b/r/R/dplyr-eval.R
index a8fb7c43300..2d2fd972a48 100644
--- a/r/R/dplyr-eval.R
+++ b/r/R/dplyr-eval.R
@@ -22,30 +22,51 @@
 arrow_eval <- function(expr, mask) {
   # This yields an Expression as long as the `exprs` are implemented in Arrow.
   # Otherwise, it returns a try-error
-  tryCatch(eval_tidy(expr, mask), error = function(e) {
-    # Look for the cases where bad input was given, i.e. this would fail
-    # in regular dplyr anyway, and let those raise those as errors;
-    # else, for things not supported in Arrow return a "try-error",
-    # which we'll handle differently
-    msg <- conditionMessage(e)
-    if (getOption("arrow.debug", FALSE)) print(msg)
-    patterns <- .cache$i18ized_error_pattern
-    if (is.null(patterns)) {
-      patterns <- i18ize_error_messages()
-      # Memoize it
-      .cache$i18ized_error_pattern <- patterns
-    }
-    if (grepl(patterns, msg)) {
-      stop(e)
-    }
+  tryCatch(
+    eval_tidy(expr, mask),
+    error = function(e) {
+      # Look for the cases where bad input was given, i.e. this would fail
+      # in regular dplyr anyway, and let those raise those as errors;
+      # else, for things not supported in Arrow return a "try-error",
+      # which we'll handle differently
+      msg <- conditionMessage(e)
+      if (getOption("arrow.debug", FALSE)) print(msg)
+      patterns <- .cache$i18ized_error_pattern
+      if (is.null(patterns)) {
+        patterns <- i18ize_error_messages()
+        # Memoize it
+        .cache$i18ized_error_pattern <- patterns
+      }
+      if (grepl(patterns, msg)) {
+        stop(e)
+      }
 
-    out <- structure(msg, class = "try-error", condition = e)
-    if (grepl("not supported.*Arrow", msg) || getOption("arrow.debug", FALSE)) {
-      # One of ours. Mark it so that consumers can handle it differently
-      class(out) <- c("arrow-try-error", class(out))
-    }
-    invisible(out)
-  })
+      out <- structure(msg, class = "try-error", condition = e)
+      if (grepl("not supported.*Arrow", msg) || getOption("arrow.debug", FALSE)) {
+        # One of ours. Mark it so that consumers can handle it differently
+        class(out) <- c("arrow-try-error", class(out))
+      }
+      # Mark the result as a binding error only when the call uses a bare
+      # function name that is absent from the function registry (the names of
+      # the mask's top environment) but registered there under exactly one
+      # `pkg::fun` name, so a consumer can retry with the qualified name.
+      # Non-call expressions and non-symbol call heads are skipped because
+      # `[[1]]` on a bare symbol would raise from inside this handler.
+      call_expr <- expr
+      if (rlang::is_quosure(call_expr)) {
+        call_expr <- rlang::quo_get_expr(call_expr)
+      }
+      if (is.call(call_expr) && is.symbol(call_expr[[1]])) {
+        fn_name <- rlang::as_string(call_expr[[1]])
+        registered <- as.character(names(mask$.top_env))
+        qualified <- registered[endsWith(registered, paste0("::", fn_name))]
+        if (!fn_name %in% registered && length(qualified) == 1) {
+          class(out) <- c("arrow-binding-error", class(out))
+        }
+      }
+      invisible(out)
+    }
+  )
 }
 
 handle_arrow_not_supported <- function(err, lab) {
diff --git a/r/R/dplyr-funcs-datetime.R b/r/R/dplyr-funcs-datetime.R
index 7d11cdc1134..4b3c529c07d 100644
--- a/r/R/dplyr-funcs-datetime.R
+++ b/r/R/dplyr-funcs-datetime.R
@@ -29,7 +29,7 @@ register_bindings_datetime <- function() {
 
 register_bindings_datetime_utility <- function() {
   register_binding("base::strptime", function(x,
-                                        format = "%Y-%m-%d %H:%M:%S",
+                                              format = "%Y-%m-%d %H:%M:%S",
                                               tz = "",
                                               unit = "ms") {
     # Arrow uses unit for time parsing, strptime() does not.
diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R
index 7c4ed99e2ed..38c1dc82b8c 100644
--- a/r/R/dplyr-funcs.R
+++ b/r/R/dplyr-funcs.R
@@ -118,6 +118,7 @@ create_binding_cache <- function() {
   register_bindings_math()
   register_bindings_string()
   register_bindings_type()
+  register_bindings_utils()
 
   # We only create the cache for nse_funcs and not agg_funcs
   .cache$functions <- c(as.list(nse_funcs), arrow_funcs)
diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R
index 653c1e6f25a..b256e7dad23 100644
--- a/r/R/dplyr-mutate.R
+++ b/r/R/dplyr-mutate.R
@@ -54,12 +54,29 @@ mutate.arrow_dplyr_query <- function(.data,
     # (which overwrites the previous name)
     new_var <- names(exprs)[i]
     results[[new_var]] <- arrow_eval(exprs[[i]], mask)
-    if (inherits(results[[new_var]], "try-error")) {
+    if (inherits(results[[new_var]], "arrow-binding-error")) {
+      # The call used a bare function name that is only registered with a
+      # `pkg::` prefix: swap in the qualified name and evaluate once more.
+      # Names are compared literally (no regex) and the retry only happens
+      # when exactly one registered binding matches.
+      expr <- rlang::quo_get_expr(exprs[[i]])
+      registered <- as.character(names(mask$.top_env))
+      suffix <- paste0("::", rlang::expr_text(expr[[1]]))
+      qualified <- registered[endsWith(registered, suffix)]
+      if (length(qualified) == 1) {
+        expr[[1]] <- rlang::parse_expr(qualified)
+        exprs[[i]] <- rlang::quo_set_expr(exprs[[i]], expr)
+        results[[new_var]] <- arrow_eval(exprs[[i]], mask)
+      }
+    }
+
+    # Not chained with `else`: a failed retry above must also be handled here
+    if (inherits(results[[new_var]], "try-error")) {
       msg <- handle_arrow_not_supported(
         results[[new_var]],
         format_expr(exprs[[i]])
       )
       return(abandon_ship(call, .data, msg))
     } else if (!inherits(results[[new_var]], "Expression") &&
       !is.null(results[[new_var]])) {
       # We need some wrapping to handle literal values
diff --git a/r/tests/testthat/test-dplyr-funcs-datetime.R b/r/tests/testthat/test-dplyr-funcs-datetime.R
index f0543736404..64a08ba5ec3 100644
--- a/r/tests/testthat/test-dplyr-funcs-datetime.R
+++ b/r/tests/testthat/test-dplyr-funcs-datetime.R
@@ -90,7 +90,7 @@ test_that("strptime", {
     t_string %>%
       arrow_table() %>%
       mutate(
-        x = strptime(x, format = "%Y-%m-%d %H:%M:%S", tz = "Pacific/Marquesas")
+        x = base::strptime(x, format = "%Y-%m-%d %H:%M:%S", tz = "Pacific/Marquesas")
       ) %>%
       collect(),
     t_stamp_with_pm_tz