diff --git a/r/R/dplyr-funcs-conditional.R b/r/R/dplyr-funcs-conditional.R index 37411ed2616..cd0245eeee1 100644 --- a/r/R/dplyr-funcs-conditional.R +++ b/r/R/dplyr-funcs-conditional.R @@ -90,7 +90,15 @@ register_bindings_conditional <- function() { out }) - register_binding("dplyr::case_when", function(...) { + register_binding("dplyr::case_when", function(..., .default = NULL, .ptype = NULL, .size = NULL) { + if (!is.null(.ptype)) { + arrow_not_supported("`case_when()` with `.ptype` specified") + } + + if (!is.null(.size)) { + arrow_not_supported("`case_when()` with `.size` specified") + } + formulas <- list2(...) n <- length(formulas) if (n == 0) { @@ -113,6 +121,14 @@ register_bindings_conditional <- function() { abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]]))) } } + if (!is.null(.default)) { + if (length(.default) != 1) { + abort(paste0("`.default` must have size 1, not size ", length(.default), ".")) + } + + query[n + 1] <- TRUE + value[n + 1] <- .default + } Expression$create( "case_when", args = c( @@ -124,5 +140,6 @@ register_bindings_conditional <- function() { value ) ) - }) + }, notes = "`.ptype` and `.size` arguments not supported" + ) } diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index b619cfe509b..a62c4f8335a 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -83,7 +83,7 @@ #' Functions can be called either as `pkg::fun()` or just `fun()`, i.e. both #' `str_sub()` and `stringr::str_sub()` work. #' -#' In addition to these functions, you can call any of Arrow's 246 compute +#' In addition to these functions, you can call any of Arrow's 254 compute #' functions directly. Arrow has many functions that don't map to an existing R #' function. In other cases where there is an R function mapping, you can still #' call the Arrow function directly if you don't want the adaptations that the R @@ -99,30 +99,31 @@ #' #' ## base #' -#' * [`-`][-()] #' * [`!`][!()] #' * [`!=`][!=()] -#' * [`*`][*()] -#' * [`/`][/()] -#' * [`&`][&()] -#' * [`%/%`][%/%()] #' * [`%%`][%%()] +#' * [`%/%`][%/%()] #' * [`%in%`][%in%()] -#' * [`^`][^()] +#' * [`&`][&()] +#' * [`*`][*()] #' * [`+`][+()] +#' * [`-`][-()] +#' * [`/`][/()] #' * [`<`][<()] #' * [`<=`][<=()] #' * [`==`][==()] #' * [`>`][>()] #' * [`>=`][>=()] -#' * [`|`][|()] +#' * [`ISOdate()`][base::ISOdate()] +#' * [`ISOdatetime()`][base::ISOdatetime()] +#' * [`^`][^()] #' * [`abs()`][base::abs()] #' * [`acos()`][base::acos()] #' * [`all()`][base::all()] #' * [`any()`][base::any()] -#' * [`as.character()`][base::as.character()] #' * [`as.Date()`][base::as.Date()]: Multiple `tryFormats` not supported in Arrow. #' Consider using the lubridate specialised parsing functions `ymd()`, `ymd()`, etc. +#' * [`as.character()`][base::as.character()] #' * [`as.difftime()`][base::as.difftime()]: only supports `units = "secs"` (the default) #' * [`as.double()`][base::as.double()] #' * [`as.integer()`][base::as.integer()] @@ -153,8 +154,6 @@ #' * [`is.na()`][base::is.na()] #' * [`is.nan()`][base::is.nan()] #' * [`is.numeric()`][base::is.numeric()] -#' * [`ISOdate()`][base::ISOdate()] -#' * [`ISOdatetime()`][base::ISOdatetime()] #' * [`log()`][base::log()] #' * [`log10()`][base::log10()] #' * [`log1p()`][base::log1p()] @@ -186,6 +185,7 @@ #' * [`tolower()`][base::tolower()] #' * [`toupper()`][base::toupper()] #' * [`trunc()`][base::trunc()] +#' * [`|`][|()] #' #' ## bit64 #' @@ -196,7 +196,7 @@ #' #' * [`across()`][dplyr::across()] #' * [`between()`][dplyr::between()] -#' * [`case_when()`][dplyr::case_when()] +#' * [`case_when()`][dplyr::case_when()]: `.ptype` and `.size` arguments not supported #' * [`coalesce()`][dplyr::coalesce()] #' * [`desc()`][dplyr::desc()] #' * [`if_all()`][dplyr::if_all()] @@ -242,8 +242,8 @@ #' * [`format_ISO8601()`][lubridate::format_ISO8601()] #' * [`hour()`][lubridate::hour()] #' * [`is.Date()`][lubridate::is.Date()] -#' * [`is.instant()`][lubridate::is.instant()] #' * [`is.POSIXct()`][lubridate::is.POSIXct()] +#' * [`is.instant()`][lubridate::is.instant()] #' * [`is.timepoint()`][lubridate::is.timepoint()] #' * [`isoweek()`][lubridate::isoweek()] #' * [`isoyear()`][lubridate::isoyear()] diff --git a/r/man/acero.Rd b/r/man/acero.Rd index 6d4476c44c2..d41029c70b7 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -68,7 +68,7 @@ can assume that the function works in Acero just as it does in R. Functions can be called either as \code{pkg::fun()} or just \code{fun()}, i.e. both \code{str_sub()} and \code{stringr::str_sub()} work. -In addition to these functions, you can call any of Arrow's 246 compute +In addition to these functions, you can call any of Arrow's 254 compute functions directly. Arrow has many functions that don't map to an existing R function. In other cases where there is an R function mapping, you can still call the Arrow function directly if you don't want the adaptations that the R @@ -85,30 +85,31 @@ as \code{arrow_ascii_is_decimal}. \subsection{base}{ \itemize{ -\item \code{\link[=-]{-}} \item \code{\link[=!]{!}} \item \code{\link[=!=]{!=}} -\item \code{\link[=*]{*}} -\item \code{\link[=/]{/}} -\item \code{\link[=&]{&}} -\item \code{\link[=\%/\%]{\%/\%}} \item \code{\link[=\%\%]{\%\%}} +\item \code{\link[=\%/\%]{\%/\%}} \item \code{\link[=\%in\%]{\%in\%}} -\item \code{\link[=^]{^}} +\item \code{\link[=&]{&}} +\item \code{\link[=*]{*}} \item \code{\link[=+]{+}} +\item \code{\link[=-]{-}} +\item \code{\link[=/]{/}} \item \code{\link[=<]{<}} \item \code{\link[=<=]{<=}} \item \code{\link[===]{==}} \item \code{\link[=>]{>}} \item \code{\link[=>=]{>=}} -\item \code{\link[=|]{|}} +\item \code{\link[base:ISOdatetime]{ISOdate()}} +\item \code{\link[base:ISOdatetime]{ISOdatetime()}} +\item \code{\link[=^]{^}} \item \code{\link[base:MathFun]{abs()}} \item \code{\link[base:Trig]{acos()}} \item \code{\link[base:all]{all()}} \item \code{\link[base:any]{any()}} -\item \code{\link[base:character]{as.character()}} \item \code{\link[base:as.Date]{as.Date()}}: Multiple \code{tryFormats} not supported in Arrow. Consider using the lubridate specialised parsing functions \code{ymd()}, \code{ymd()}, etc. +\item \code{\link[base:character]{as.character()}} \item \code{\link[base:difftime]{as.difftime()}}: only supports \code{units = "secs"} (the default) \item \code{\link[base:double]{as.double()}} \item \code{\link[base:integer]{as.integer()}} @@ -139,8 +140,6 @@ Consider using the lubridate specialised parsing functions \code{ymd()}, \code{y \item \code{\link[base:NA]{is.na()}} \item \code{\link[base:is.finite]{is.nan()}} \item \code{\link[base:numeric]{is.numeric()}} -\item \code{\link[base:ISOdatetime]{ISOdate()}} -\item \code{\link[base:ISOdatetime]{ISOdatetime()}} \item \code{\link[base:Log]{log()}} \item \code{\link[base:Log]{log10()}} \item \code{\link[base:Log]{log1p()}} @@ -172,6 +171,7 @@ Valid values are "s", "ms" (default), "us", "ns". \item \code{\link[base:chartr]{tolower()}} \item \code{\link[base:chartr]{toupper()}} \item \code{\link[base:Round]{trunc()}} +\item \code{\link[=|]{|}} } } @@ -186,7 +186,7 @@ Valid values are "s", "ms" (default), "us", "ns". \itemize{ \item \code{\link[dplyr:across]{across()}} \item \code{\link[dplyr:between]{between()}} -\item \code{\link[dplyr:case_when]{case_when()}} +\item \code{\link[dplyr:case_when]{case_when()}}: \code{.ptype} and \code{.size} arguments not supported \item \code{\link[dplyr:coalesce]{coalesce()}} \item \code{\link[dplyr:desc]{desc()}} \item \code{\link[dplyr:across]{if_all()}} @@ -234,8 +234,8 @@ Valid values are "s", "ms" (default), "us", "ns". \item \code{\link[lubridate:format_ISO8601]{format_ISO8601()}} \item \code{\link[lubridate:hour]{hour()}} \item \code{\link[lubridate:date_utils]{is.Date()}} -\item \code{\link[lubridate:is.instant]{is.instant()}} \item \code{\link[lubridate:posix_utils]{is.POSIXct()}} +\item \code{\link[lubridate:is.instant]{is.instant()}} \item \code{\link[lubridate:is.instant]{is.timepoint()}} \item \code{\link[lubridate:week]{isoweek()}} \item \code{\link[lubridate:year]{isoyear()}} diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index b3d86da8b41..e60712e9e61 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -176,6 +176,14 @@ test_that("case_when()", { collect(), tbl ) + + compare_dplyr_binding( + .input %>% + mutate(cw = case_when(int > 5 ~ 1, .default = 0)) %>% + collect(), + tbl + ) + compare_dplyr_binding( .input %>% transmute(cw = case_when(chr %in% letters[1:3] ~ 1L) + 41L) %>% @@ -271,6 +279,29 @@ test_that("case_when()", { ) ) + expect_error( + expect_warning( + tbl %>% + arrow_table() %>% + mutate(cw = case_when(int > 5 ~ 1, .default = c(0, 1))) + ), + "`.default` must have size" + ) + + expect_warning( + tbl %>% + arrow_table() %>% + mutate(cw = case_when(int > 5 ~ 1, .ptype = integer())), + "not supported in Arrow" + ) + + expect_warning( + tbl %>% + arrow_table() %>% + mutate(cw = case_when(int > 5 ~ 1, .size = 10)), + "not supported in Arrow" + ) + compare_dplyr_binding( .input %>% transmute(cw = case_when(lgl ~ "abc")) %>%