diff --git a/r/NAMESPACE b/r/NAMESPACE
index 8b08b940b36..86e65087cba 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -405,6 +405,7 @@ importFrom(purrr,map_dbl)
 importFrom(purrr,map_dfr)
 importFrom(purrr,map_int)
 importFrom(purrr,map_lgl)
+importFrom(purrr,reduce)
 importFrom(rlang,"%||%")
 importFrom(rlang,":=")
 importFrom(rlang,.data)
@@ -426,6 +427,7 @@ importFrom(rlang,env_bind)
 importFrom(rlang,eval_tidy)
 importFrom(rlang,exec)
 importFrom(rlang,expr)
+importFrom(rlang,expr_text)
 importFrom(rlang,f_env)
 importFrom(rlang,f_rhs)
 importFrom(rlang,is_bare_character)
@@ -443,6 +445,7 @@ importFrom(rlang,list2)
 importFrom(rlang,new_data_mask)
 importFrom(rlang,new_environment)
 importFrom(rlang,new_quosure)
+importFrom(rlang,new_quosures)
 importFrom(rlang,parse_expr)
 importFrom(rlang,quo)
 importFrom(rlang,quo_get_env)
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index 298aa94f52b..143f4c191bd 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -17,7 +17,8 @@
 
 #' @importFrom stats quantile median na.omit na.exclude na.pass na.fail
 #' @importFrom R6 R6Class
-#' @importFrom purrr as_mapper map map2 map_chr map2_chr map_dbl map_dfr map_int map_lgl keep imap imap_chr flatten
+#' @importFrom purrr as_mapper map map2 map_chr map2_chr map_dbl map_dfr map_int map_lgl keep imap imap_chr
+#' @importFrom purrr flatten reduce
 #' @importFrom assertthat assert_that is.string
 #' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos quo
 #' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind set_names exec
@@ -25,6 +26,7 @@
 #' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure
 #' @importFrom rlang is_list call2 is_empty as_function as_label arg_match is_symbol is_call call_args
 #' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call f_rhs parse_expr f_env new_quosure
+#' @importFrom rlang new_quosures expr_text
 #' @importFrom tidyselect vars_pull vars_rename vars_select eval_select
 #' @importFrom glue glue
 #' @useDynLib arrow, .registration = TRUE
diff --git a/r/R/dplyr-across.R b/r/R/dplyr-across.R
index 6550978d6f6..d23525ddfb5 100644
--- a/r/R/dplyr-across.R
+++ b/r/R/dplyr-across.R
@@ -23,7 +23,9 @@ expand_across <- function(.data, quos_in) {
     quo_expr <- quo_get_expr(quo_in[[1]])
     quo_env <- quo_get_env(quo_in[[1]])
+    # number of quosures already collected from the arguments before this one
+    n_before <- length(quos_out)
 
-    if (is_call(quo_expr, "across")) {
+    if (is_call(quo_expr, c("across", "if_any", "if_all"))) {
       new_quos <- list()
 
       across_call <- match.call(
@@ -58,9 +60,27 @@ expand_across <- function(.data, quos_in) {
     } else {
       quos_out <- append(quos_out, quo_in)
     }
+
+    # if_any()/if_all() collapse into a single predicate. Fold only the quosures
+    # generated for this argument: conditions collected from earlier arguments
+    # must stay separate, otherwise if_any() would also `|` them together.
+    is_new <- seq_along(quos_out) > n_before
+    if (is_call(quo_expr, c("if_any", "if_all")) && any(is_new)) {
+      op <- if (is_call(quo_expr, "if_any")) "|" else "&"
+      combined <- reduce(quos_out[is_new], combine_if, op = op, envir = quo_env)
+      quos_out <- append(quos_out[!is_new], list(combined))
+    }
   }
 
-  quos_out
+  new_quosures(quos_out)
+}
+
+# combines the expressions of two quosures with `op` ("&" or "|") and wraps the
+# result in a new quosure that uses `envir`
+combine_if <- function(lhs, rhs, op, envir) {
+  # build the call directly: pasting text and re-parsing it would regroup an
+  # operand that contains an operator of lower precedence than `op`
+  new_quosure(call2(op, quo_get_expr(lhs), quo_get_expr(rhs)), envir)
 }
 
 # given a named list of functions and column names, create a list of new quosures
@@ -157,7 +177,7 @@ across_glue_mask <- function(.col, .fn, .caller_env) {
   env(.caller_env, .col = .col, .fn = .fn, col = .col, fn = .fn)
 }
 
-# Substitutes instances of `.` and `.x` with the variable in question
+# Substitutes instances of "." and ".x" with `var`
 as_across_fn_call <- function(fn, var, quo_env) {
   if (is_formula(fn, lhs = FALSE)) {
     expr <- f_rhs(fn)
diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R
index 7db68b43e93..1ef2b6d7e58 100644
--- a/r/R/dplyr-filter.R
+++ b/r/R/dplyr-filter.R
@@ -20,7 +20,7 @@
 
 filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) {
   # TODO something with the .preserve argument
-  filts <- quos(...)
+  filts <- expand_across(.data, quos(...))
   if (length(filts) == 0) {
     # Nothing to do
     return(.data)
diff --git a/r/data-raw/docgen.R b/r/data-raw/docgen.R
index 2e2581c5788..e2c7f94eafc 100644
--- a/r/data-raw/docgen.R
+++ b/r/data-raw/docgen.R
@@ -128,11 +128,15 @@ docs <- arrow:::.cache$docs
 
 # across() is handled by manipulating the quosures, not by nse_funcs
 docs[["dplyr::across"]] <- c(
-  # TODO(ARROW-17387): do filter
-  "not yet supported inside `filter()`;",
   # TODO(ARROW-17384): implement where
-  "and use of `where()` selection helper not yet supported"
+  "Use of `where()` selection helper not yet supported"
 )
+
+# if_any() and if_all() are used instead of across() in filter()
+# they are both handled by manipulating the quosures, not by nse_funcs
+docs[["dplyr::if_any"]] <- character(0)
+docs[["dplyr::if_all"]] <- character(0)
+
 # desc() is a special helper handled inside of arrange()
 docs[["dplyr::desc"]] <- character(0)
 
diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R
index 17e57dd4194..2bb0bf205e9 100644
--- a/r/tests/testthat/helper-expectation.R
+++ b/r/tests/testthat/helper-expectation.R
@@ -323,5 +323,5 @@ split_vector_as_list <- function(vec) {
 }
 
 expect_across_equal <- function(across_expr, expected, tbl) {
-  expect_identical(expand_across(as_adq(tbl), across_expr), as.list(expected))
+  expect_identical(expand_across(as_adq(tbl), across_expr), new_quosures(expected))
 }
diff --git a/r/tests/testthat/test-dplyr-across.R b/r/tests/testthat/test-dplyr-across.R
index d622351a28c..5ded2038c4f 100644
--- a/r/tests/testthat/test-dplyr-across.R
+++ b/r/tests/testthat/test-dplyr-across.R
@@ -278,3 +278,26 @@ test_that("ARROW-14071 - function(x)-style lambda functions are not supported",
     regexp = "Anonymous functions are not yet supported in Arrow"
   )
 })
+
+test_that("if_all() and if_any() are supported", {
+
+  expect_across_equal(
+    quos(if_any(everything(), ~is.na(.x))),
+    quos(is.na(int) | is.na(dbl) | is.na(dbl2) | is.na(lgl) | is.na(false) | is.na(chr) | is.na(fct)),
+    example_data
+  )
+
+  expect_across_equal(
+    quos(if_all(everything(), ~is.na(.x))),
+    quos(is.na(int) & is.na(dbl) & is.na(dbl2) & is.na(lgl) & is.na(false) & is.na(chr) & is.na(fct)),
+    example_data
+  )
+
+  # conditions from other arguments stay separate from the if_any() predicate
+  expect_across_equal(
+    quos(int > 2, if_any(c(dbl, dbl2), ~is.na(.x))),
+    quos(int > 2, is.na(dbl) | is.na(dbl2)),
+    example_data
+  )
+
+})
diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R
index f94450a0257..81a9ba3f6e5 100644
--- a/r/tests/testthat/test-dplyr-filter.R
+++ b/r/tests/testthat/test-dplyr-filter.R
@@ -417,3 +417,36 @@ test_that("filter() with namespaced functions", {
     tbl
   )
 })
+
+test_that("filter() with if_any() and if_all()", {
+
+  compare_dplyr_binding(
+    .input %>%
+      filter(if_any(ends_with("l"), ~ is.na(.))) %>%
+      collect(),
+    tbl
+  )
+
+  compare_dplyr_binding(
+    .input %>%
+      filter(
+        false == FALSE,
+        if_all(everything(), ~ !is.na(.)),
+        int > 2
+      ) %>%
+      collect(),
+    tbl
+  )
+
+  # an earlier condition is ANDed with, not folded into, the if_any() predicate
+  compare_dplyr_binding(
+    .input %>%
+      filter(
+        int > 2,
+        if_any(ends_with("l"), ~ is.na(.))
+      ) %>%
+      collect(),
+    tbl
+  )
+
+})