Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions r/R/dplyr-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,23 @@ nse_funcs$wday <- function(x, label = FALSE, abbr = TRUE, week_start = getOption
}

# Translate a dplyr `log()` / `logb()` call into an Arrow compute Expression.
#
# Arguments:
#   x    - an Expression, or a length-1 numeric (wrapped with Expression$scalar()).
#   base - an Expression, or a length-1 numeric; defaults to exp(1).
#
# Returns the value of Expression$create() for one of the "*_checked" log
# kernels named below. Any other kind of `x` or `base` is reported through
# arrow_not_supported().
nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) {
# like other binary functions, either `x` or `base` can be Expression or double(1)
# A bare length-1 numeric `x` is promoted to a scalar Expression so that the
# Expression$create() calls below always receive an Expression for `x`.
if (is.numeric(x) && length(x) == 1) {
x <- Expression$scalar(x)
} else if (!inherits(x, "Expression")) {
arrow_not_supported("x must be a column or a length-1 numeric; other values")
}

# handle `base` differently because we use the simpler ln, log2, and log10
# functions for specific scalar base values
# An Expression `base` cannot be compared against constants here, so it goes
# straight to the general two-argument kernel.
if (inherits(base, "Expression")) {
return(Expression$create("logb_checked", x, base))
}

# From here on `base` must be a plain length-1 numeric.
if (!is.numeric(base) || length(base) != 1) {
arrow_not_supported("base must be a column or a length-1 numeric; other values")
}

# Exact comparison: this matches the default `base = exp(1)` (natural log).
if (base == exp(1)) {
return(Expression$create("ln_checked", x))
}
Expand All @@ -705,8 +722,8 @@ nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) {
# Base 10 uses the dedicated unary kernel.
if (base == 10) {
return(Expression$create("log10_checked", x))
}
# NOTE(review): the ARROW-13345 comment and the arrow_not_supported() call
# below look like the removed side of a diff, with the Expression$create()
# call after them being the replacement -- confirm against the actual file.
# ARROW-13345
arrow_not_supported("`base` values other than exp(1), 2 and 10")

# General case: wrap the numeric `base` as a scalar Expression and use the
# two-argument kernel.
Expression$create("logb_checked", x, Expression$scalar(base))
}

nse_funcs$if_else <- function(condition, true, false, missing = NULL) {
Expand Down
52 changes: 50 additions & 2 deletions r/tests/testthat/test-dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -1035,12 +1035,60 @@ test_that("log functions", {
df
)

# test log(, base = (length == 1))
expect_dplyr_equal(
input %>%
mutate(y = log(x, base = 5)) %>%
collect(),
df
)

# test log(, base = (length != 1))
expect_error(
nse_funcs$log(Expression$scalar(x), base = 5),
"`base` values other than exp(1), 2 and 10 not supported by Arrow",
nse_funcs$log(10, base = 5:6),
"base must be a column or a length-1 numeric; other values not supported by Arrow",
fixed = TRUE
)

# test log(x = (length != 1))
expect_error(
nse_funcs$log(10:11),
"x must be a column or a length-1 numeric; other values not supported by Arrow",
fixed = TRUE
)

# test log(, base = Expression)
expect_dplyr_equal(
input %>%
# test cases where base = 1 below
filter(x != 1) %>%
mutate(
y = log(x, base = x),
z = log(2, base = x)
) %>%
collect(),
df
)

# log(1, base = 1) is NaN in both R and Arrow
# suppress the R warning because R warns but Arrow does not
suppressWarnings(
expect_dplyr_equal(
input %>%
mutate(y = log(x, base = y)) %>%
collect(),
tibble(x = 1, y = 1)
)
)

# log(n != 1, base = 1) is Inf in R and Arrow
expect_dplyr_equal(
input %>%
mutate(y = log(x, base = y)) %>%
collect(),
tibble(x = 10, y = 1)
)

expect_dplyr_equal(
input %>%
mutate(y = logb(x)) %>%
Expand Down