Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions r/R/dplyr-funcs-conditional.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,15 @@ register_bindings_conditional <- function() {
out
})

register_binding("dplyr::case_when", function(...) {
register_binding("dplyr::case_when", function(..., .default = NULL, .ptype = NULL, .size = NULL) {
if (!is.null(.ptype)) {
arrow_not_supported("`case_when()` with `.ptype` specified")
}

if (!is.null(.size)) {
arrow_not_supported("`case_when()` with `.size` specified")
}

formulas <- list2(...)
n <- length(formulas)
if (n == 0) {
Expand All @@ -113,6 +121,14 @@ register_bindings_conditional <- function() {
abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]])))
}
}
if (!is.null(.default)) {
if (length(.default) != 1) {
abort(paste0("`.default` must have size 1, not size ", length(.default), "."))
}

query[n + 1] <- TRUE
value[n + 1] <- .default
}
Expression$create(
"case_when",
args = c(
Expand All @@ -124,5 +140,6 @@ register_bindings_conditional <- function() {
value
)
)
})
}, notes = "`.ptype` and `.size` arguments not supported"
)
}
26 changes: 13 additions & 13 deletions r/R/dplyr-funcs-doc.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
#' Functions can be called either as `pkg::fun()` or just `fun()`, i.e. both
#' `str_sub()` and `stringr::str_sub()` work.
#'
#' In addition to these functions, you can call any of Arrow's 246 compute
#' In addition to these functions, you can call any of Arrow's 254 compute
#' functions directly. Arrow has many functions that don't map to an existing R
#' function. In other cases where there is an R function mapping, you can still
#' call the Arrow function directly if you don't want the adaptations that the R
Expand All @@ -99,30 +99,31 @@
#'
#' ## base
#'
#' * [`-`][-()]
#' * [`!`][!()]
#' * [`!=`][!=()]
#' * [`*`][*()]
#' * [`/`][/()]
#' * [`&`][&()]
#' * [`%/%`][%/%()]
#' * [`%%`][%%()]
#' * [`%/%`][%/%()]
#' * [`%in%`][%in%()]
#' * [`^`][^()]
#' * [`&`][&()]
#' * [`*`][*()]
#' * [`+`][+()]
#' * [`-`][-()]
#' * [`/`][/()]
#' * [`<`][<()]
#' * [`<=`][<=()]
#' * [`==`][==()]
#' * [`>`][>()]
#' * [`>=`][>=()]
#' * [`|`][|()]
#' * [`ISOdate()`][base::ISOdate()]
#' * [`ISOdatetime()`][base::ISOdatetime()]
#' * [`^`][^()]
#' * [`abs()`][base::abs()]
#' * [`acos()`][base::acos()]
#' * [`all()`][base::all()]
#' * [`any()`][base::any()]
#' * [`as.character()`][base::as.character()]
#' * [`as.Date()`][base::as.Date()]: Multiple `tryFormats` not supported in Arrow.
#' Consider using the lubridate specialised parsing functions `ymd()`, `ymd()`, etc.
#' * [`as.character()`][base::as.character()]
#' * [`as.difftime()`][base::as.difftime()]: only supports `units = "secs"` (the default)
#' * [`as.double()`][base::as.double()]
#' * [`as.integer()`][base::as.integer()]
Expand Down Expand Up @@ -153,8 +154,6 @@
#' * [`is.na()`][base::is.na()]
#' * [`is.nan()`][base::is.nan()]
#' * [`is.numeric()`][base::is.numeric()]
#' * [`ISOdate()`][base::ISOdate()]
#' * [`ISOdatetime()`][base::ISOdatetime()]
#' * [`log()`][base::log()]
#' * [`log10()`][base::log10()]
#' * [`log1p()`][base::log1p()]
Expand Down Expand Up @@ -186,6 +185,7 @@
#' * [`tolower()`][base::tolower()]
#' * [`toupper()`][base::toupper()]
#' * [`trunc()`][base::trunc()]
#' * [`|`][|()]
#'
#' ## bit64
#'
Expand All @@ -196,7 +196,7 @@
#'
#' * [`across()`][dplyr::across()]
#' * [`between()`][dplyr::between()]
#' * [`case_when()`][dplyr::case_when()]
#' * [`case_when()`][dplyr::case_when()]: `.ptype` and `.size` arguments not supported
#' * [`coalesce()`][dplyr::coalesce()]
#' * [`desc()`][dplyr::desc()]
#' * [`if_all()`][dplyr::if_all()]
Expand Down Expand Up @@ -242,8 +242,8 @@
#' * [`format_ISO8601()`][lubridate::format_ISO8601()]
#' * [`hour()`][lubridate::hour()]
#' * [`is.Date()`][lubridate::is.Date()]
#' * [`is.instant()`][lubridate::is.instant()]
#' * [`is.POSIXct()`][lubridate::is.POSIXct()]
#' * [`is.instant()`][lubridate::is.instant()]
#' * [`is.timepoint()`][lubridate::is.timepoint()]
#' * [`isoweek()`][lubridate::isoweek()]
#' * [`isoyear()`][lubridate::isoyear()]
Expand Down
26 changes: 13 additions & 13 deletions r/man/acero.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions r/tests/testthat/test-dplyr-funcs-conditional.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ test_that("case_when()", {
collect(),
tbl
)

compare_dplyr_binding(
.input %>%
mutate(cw = case_when(int > 5 ~ 1, .default = 0)) %>%
collect(),
tbl
)

compare_dplyr_binding(
.input %>%
transmute(cw = case_when(chr %in% letters[1:3] ~ 1L) + 41L) %>%
Expand Down Expand Up @@ -271,6 +279,29 @@ test_that("case_when()", {
)
)

expect_error(
expect_warning(
tbl %>%
arrow_table() %>%
mutate(cw = case_when(int > 5 ~ 1, .default = c(0, 1)))
),
"`.default` must have size"
)

expect_warning(
tbl %>%
arrow_table() %>%
mutate(cw = case_when(int > 5 ~ 1, .ptype = integer())),
"not supported in Arrow"
)

expect_warning(
tbl %>%
arrow_table() %>%
mutate(cw = case_when(int > 5 ~ 1, .size = 10)),
"not supported in Arrow"
)

compare_dplyr_binding(
.input %>%
transmute(cw = case_when(lgl ~ "abc")) %>%
Expand Down