From a8aafef53ab8609959ef76d75a96ca04035dc36c Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Sat, 17 Jul 2021 01:54:37 -0400 Subject: [PATCH 1/8] Implement case_when() --- r/R/dplyr-functions.R | 46 ++++++++++++++++++++++++++++++++++++++++++- r/src/compute.cpp | 7 +++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index d118eefaa85..42accdf6ddb 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -698,6 +698,50 @@ nse_funcs$if_else <- function(condition, true, false, missing = NULL){ # Although base R ifelse allows `yes` and `no` to be different classes # -nse_funcs$ifelse <- function(test, yes, no){ +nse_funcs$ifelse <- function(test, yes, no) { nse_funcs$if_else(condition = test, true = yes, false = no) } + +nse_funcs$case_when <- function(...) { + formulas <- list2(...) + n <- length(formulas) + if (n == 0) { + abort("No cases provided in case_when()") + } + query <- vector("list", n) + value <- vector("list", n) + mask <- caller_env() + for (i in seq_len(n)) { + f <- formulas[[i]] + if (!inherits(f, "formula")) { + abort("Each argument to case_when() must be a two-sided formula") + } + query[[i]] <- arrow_eval(f[[2]], mask) + value[[i]] <- arrow_eval(f[[3]], mask) + if (!nse_funcs$is.logical(query[[i]])) { + abort("Left side of each formula in case_when() must be a logical expression") + } + # TODO: remove these checks after the case_when kernel supports variable-width + # types (ARROW-13222) + has_bad_r_type <- inherits(value[[i]], c("character", "raw", "list", "factor")) + has_bad_arrow_type <- inherits(value[[i]], "Expression") && + value[[i]]$type_id() %in% Type[c( + "STRING", "BINARY", "LIST", "MAP", "STRUCT", "SPARSE_UNION", + "DENSE_UNION", "DICTIONARY", "EXTENSION", "FIXED_SIZE_LIST", + "LARGE_STRING", "LARGE_BINARY", "LARGE_LIST")] + if (has_bad_r_type || has_bad_arrow_type) { + arrow_not_supported("case_when() with variable-width data types") + } + } + build_expr( + "case_when", + args = c( + build_expr( + "make_struct", + args = query, + options = list(field_names = as.character(seq_along(query))) + ), + value + ) + ) +} diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 2c5ee77c8d0..f524ad76dc4 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -241,6 +241,13 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "make_struct") { + using Options = arrow::compute::MakeStructOptions; + // TODO: accept `field_nullability` and `field_metadata` options + return std::make_shared( + cpp11::as_cpp>(options["field_names"])); + } + if (func_name == "match_substring" || func_name == "match_substring_regex" || func_name == "find_substring" || func_name == "find_substring_regex" || func_name == "match_like") { From e343d04675103d459961470d95beec9a4ea9f4d7 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Sat, 17 Jul 2021 13:07:32 -0400 Subject: [PATCH 2/8] Add tests --- r/tests/testthat/test-dplyr.R | 106 ++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index e99f743690b..be75da71407 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1225,3 +1225,109 @@ test_that("if_else and ifelse", { tbl ) }) + +test_that("case_when()", { + expect_dplyr_equal( + input %>% + transmute(cw = case_when(lgl ~ dbl, !false ~ dbl + dbl2)) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + mutate(cw = case_when(int > 5 ~ 1, TRUE ~ 0)) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + transmute(cw = case_when(chr %in% letters[1:3] ~ 1L) + 41L) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + filter(case_when( + dbl + int - 1.1 == dbl2 ~ TRUE, + NA ~ NA, + TRUE ~ FALSE + ) & !is.na(dbl2)) %>% + collect(), + tbl + ) + + # dplyr::case_when() errors if values on right side of formulas do not have + # exactly the same type, but the Arrow case_when kernel allows compatible types + expect_equal( + tbl %>% + mutate(i64 = as.integer64(1e10)) %>% + Table$create() %>% + transmute(cw = case_when( + is.na(fct) ~ int, + is.na(chr) ~ dbl, + TRUE ~ i64 + )) %>% + collect(), + tbl %>% + transmute( + cw = ifelse(is.na(fct), int, ifelse(is.na(chr), dbl, 1e10)) + ) + ) + + # expected errors (which are caught by abandon_ship() and changed to warnings) + # TODO: Find a way to test these directly without abandon_ship() interfering + expect_error( + # no cases + expect_warning( + tbl %>% + Table$create() %>% + transmute(cw = case_when()), + "case_when" + ) + ) + expect_error( + # argument not a formula + expect_warning( + tbl %>% + Table$create() %>% + transmute(cw = case_when(TRUE ~ FALSE, TRUE)), + "case_when" + ) + ) + expect_error( + # non-logical R scalar on left side of formula + expect_warning( + tbl %>% + Table$create() %>% + transmute(cw = case_when(0L ~ FALSE, TRUE ~ FALSE)), + "case_when" + ) + ) + expect_error( + # non-logical Arrow expression on left side of formula + expect_warning( + tbl %>% + Table$create() %>% + transmute(cw = case_when(int ~ FALSE)), + "case_when" + ) + ) + + skip("case_when does not yet support with variable-width types (ARROW-13222)") + expect_dplyr_equal( + input %>% + transmute( + cw = case_when(lgl ~ verses, !false ~ paste(chr, chr)) + ) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + mutate( + cw = paste0(case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct), "!") + ) %>% + collect(), + tbl + ) +}) From f1ff874cdb69906c6e77c6619b2c2d734d6c3b85 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Sat, 17 Jul 2021 13:12:03 -0400 Subject: [PATCH 3/8] Add one more skipped test --- r/tests/testthat/test-dplyr.R | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index be75da71407..acd86013cb1 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1316,9 +1316,13 @@ test_that("case_when()", { skip("case_when does not yet support with variable-width types (ARROW-13222)") expect_dplyr_equal( input %>% - transmute( - cw = case_when(lgl ~ verses, !false ~ paste(chr, chr)) - ) %>% + transmute(cw = case_when(lgl ~ "abc")) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + transmute(cw = case_when(lgl ~ verses, !false ~ paste(chr, chr))) %>% collect(), tbl ) From 8b98e908173302cec052c878c317cf51c4654722 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Sat, 17 Jul 2021 13:14:28 -0400 Subject: [PATCH 4/8] Add one more error condition test --- r/tests/testthat/test-dplyr.R | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index acd86013cb1..468ad8593bd 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1304,7 +1304,7 @@ test_that("case_when()", { ) ) expect_error( - # non-logical Arrow expression on left side of formula + # non-logical Arrow column reference on left side of formula expect_warning( tbl %>% Table$create() %>% @@ -1312,6 +1312,15 @@ test_that("case_when()", { "case_when" ) ) + expect_error( + # non-logical Arrow expression on left side of formula + expect_warning( + tbl %>% + Table$create() %>% + transmute(cw = case_when(dbl + 3.14159 ~ TRUE)), + "case_when" + ) + ) skip("case_when does not yet support with variable-width types (ARROW-13222)") expect_dplyr_equal( From 52662c75b329f4d70abb8f48ad2b075632757f4c Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Sat, 17 Jul 2021 14:30:45 -0400 Subject: [PATCH 5/8] Remove unnecessary checks --- r/R/dplyr-functions.R | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 42accdf6ddb..d42992006a5 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -721,17 +721,6 @@ nse_funcs$case_when <- function(...) { if (!nse_funcs$is.logical(query[[i]])) { abort("Left side of each formula in case_when() must be a logical expression") } - # TODO: remove these checks after the case_when kernel supports variable-width - # types (ARROW-13222) - has_bad_r_type <- inherits(value[[i]], c("character", "raw", "list", "factor")) - has_bad_arrow_type <- inherits(value[[i]], "Expression") && - value[[i]]$type_id() %in% Type[c( - "STRING", "BINARY", "LIST", "MAP", "STRUCT", "SPARSE_UNION", - "DENSE_UNION", "DICTIONARY", "EXTENSION", "FIXED_SIZE_LIST", - "LARGE_STRING", "LARGE_BINARY", "LARGE_LIST")] - if (has_bad_r_type || has_bad_arrow_type) { - arrow_not_supported("case_when() with variable-width data types") - } } build_expr( "case_when", From 16699bf39b803afeaa0150eb263c6df40491d910 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Sat, 17 Jul 2021 15:19:51 -0400 Subject: [PATCH 6/8] Add NEWS bullet --- r/NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/r/NEWS.md b/r/NEWS.md index a1cd67a2ec3..b7369a4b338 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -26,6 +26,7 @@ * String operations: `strsplit()` and `str_split()`; `strptime()`; `paste()`, `paste0()`, and `str_c()`; `substr()` and `str_sub()`; `str_like()`; `str_pad()`; `stri_reverse()` * Date/time operations: `lubridate` methods such as `year()`, `month()`, `wday()`, and so on * Math: `log()`, trigonometry (`sin()`, `cos()`, et al.), `abs()`, `sign()`, `pmin()`/`pmax()` + * Conditional: `ifelse()` and `if_else()`; `case_when()` * `is.*` functions are supported and can be used inside `relocate()` * The print method for `arrow_dplyr_query` now includes the expression and the resulting type of columns derived by `mutate()`. From a1aefb4d4df46d2caebccae03602ca03eb0e2c63 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Sat, 17 Jul 2021 15:47:21 -0400 Subject: [PATCH 7/8] Note if_else and case_when limitations in NEWS --- r/NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/NEWS.md b/r/NEWS.md index b7369a4b338..9cd7542a012 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -26,7 +26,7 @@ * String operations: `strsplit()` and `str_split()`; `strptime()`; `paste()`, `paste0()`, and `str_c()`; `substr()` and `str_sub()`; `str_like()`; `str_pad()`; `stri_reverse()` * Date/time operations: `lubridate` methods such as `year()`, `month()`, `wday()`, and so on * Math: `log()`, trigonometry (`sin()`, `cos()`, et al.), `abs()`, `sign()`, `pmin()`/`pmax()` - * Conditional: `ifelse()` and `if_else()`; `case_when()` + * Conditional: `ifelse()` and `if_else()` (fixed-precision decimal numbers do not yet work and factors/dictionaries are converted to character strings); `case_when()` (currently works with numeric data types but not character strings, factors/dictionaries, or lists/structs) * `is.*` functions are supported and can be used inside `relocate()` * The print method for `arrow_dplyr_query` now includes the expression and the resulting type of columns derived by `mutate()`. From c7155a6b75f1d542497ccdbeb43eb68601e0f166 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Sat, 17 Jul 2021 17:25:48 -0400 Subject: [PATCH 8/8] Add TODO jiras --- r/src/compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index f524ad76dc4..30821137383 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -243,7 +243,7 @@ std::shared_ptr make_compute_options( if (func_name == "make_struct") { using Options = arrow::compute::MakeStructOptions; - // TODO: accept `field_nullability` and `field_metadata` options + // TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options return std::make_shared( cpp11::as_cpp>(options["field_names"])); }