diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index d2f7892aee8..808956efe15 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -694,6 +694,23 @@ nse_funcs$wday <- function(x, label = FALSE, abbr = TRUE, week_start = getOption } nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { + # like other binary functions, either `x` or `base` can be Expression or double(1) + if (is.numeric(x) && length(x) == 1) { + x <- Expression$scalar(x) + } else if (!inherits(x, "Expression")) { + arrow_not_supported("x must be a column or a length-1 numeric; other values") + } + + # handle `base` differently because we use the simpler ln, log2, and log10 + # functions for specific scalar base values + if (inherits(base, "Expression")) { + return(Expression$create("logb_checked", x, base)) + } + + if (!is.numeric(base) || length(base) != 1) { + arrow_not_supported("base must be a column or a length-1 numeric; other values") + } + if (base == exp(1)) { return(Expression$create("ln_checked", x)) } @@ -705,8 +722,8 @@ nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { if (base == 10) { return(Expression$create("log10_checked", x)) } - # ARROW-13345 - arrow_not_supported("`base` values other than exp(1), 2 and 10") + + Expression$create("logb_checked", x, Expression$scalar(base)) } nse_funcs$if_else <- function(condition, true, false, missing = NULL) { diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index d3a9994b5f1..ab7296b7818 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1035,12 +1035,60 @@ test_that("log functions", { df ) + # test log(, base = (length == 1)) + expect_dplyr_equal( + input %>% + mutate(y = log(x, base = 5)) %>% + collect(), + df + ) + + # test log(, base = (length != 1)) expect_error( - nse_funcs$log(Expression$scalar(x), base = 5), - "`base` values other than exp(1), 2 and 10 not supported by Arrow", + nse_funcs$log(10, base = 5:6), + "base must be a column or a length-1 numeric; other values not supported by Arrow", fixed = TRUE ) + # test log(x = (length != 1)) + expect_error( + nse_funcs$log(10:11), + "x must be a column or a length-1 numeric; other values not supported by Arrow", + fixed = TRUE + ) + + # test log(, base = Expression) + expect_dplyr_equal( + input %>% + # test cases where base = 1 below + filter(x != 1) %>% + mutate( + y = log(x, base = x), + z = log(2, base = x) + ) %>% + collect(), + df + ) + + # log(1, base = 1) is NaN in both R and Arrow + # suppress the R warning because R warns but Arrow does not + suppressWarnings( + expect_dplyr_equal( + input %>% + mutate(y = log(x, base = y)) %>% + collect(), + tibble(x = 1, y = 1) + ) + ) + + # log(n != 1, base = 1) is Inf in R and Arrow + expect_dplyr_equal( + input %>% + mutate(y = log(x, base = y)) %>% + collect(), + tibble(x = 10, y = 1) + ) + expect_dplyr_equal( input %>% mutate(y = logb(x)) %>%