From b88252203c6bc901a00c5119bb5f6ca6249d74f4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 10 Sep 2021 10:59:43 -0300 Subject: [PATCH 1/7] implement log(x, base = b) -> logb mapping --- r/R/dplyr-functions.R | 4 ++-- r/tests/testthat/test-dplyr.R | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index d2f7892aee8..60d6062f6e7 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -705,8 +705,8 @@ nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { if (base == 10) { return(Expression$create("log10_checked", x)) } - # ARROW-13345 - arrow_not_supported("`base` values other than exp(1), 2 and 10") + + Expression$create("logb_checked", x, Expression$scalar(base)) } nse_funcs$if_else <- function(condition, true, false, missing = NULL) { diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index d3a9994b5f1..3c5fb9aea11 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1035,10 +1035,11 @@ test_that("log functions", { df ) - expect_error( - nse_funcs$log(Expression$scalar(x), base = 5), - "`base` values other than exp(1), 2 and 10 not supported by Arrow", - fixed = TRUE + expect_dplyr_equal( + input %>% + mutate(y = log(x, base = 5)) %>% + collect(), + df ) expect_dplyr_equal( From a85a62f7ff74df014493c22e493835c792433dac Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 10 Sep 2021 14:39:35 -0300 Subject: [PATCH 2/7] support two-column input for log(, base = ) --- r/R/dplyr-functions.R | 4 ++++ r/tests/testthat/test-dplyr.R | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 60d6062f6e7..6d4637bc5e7 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -694,6 +694,10 @@ nse_funcs$wday <- function(x, label = FALSE, abbr = TRUE, week_start = getOption } nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { + if (inherits(base, "Expression")) { + return(Expression$create("logb_checked", x, base)) + } + if (base == exp(1)) { return(Expression$create("ln_checked", x)) } diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 3c5fb9aea11..85e47c63ddf 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1042,6 +1042,15 @@ test_that("log functions", { df ) + expect_dplyr_equal( + input %>% + # suppress 'NaNs produced' warning on the first row of df + # that evaluates to NaN (R raises warning but Arrow does not) + suppressWarnings(mutate(., y = log(x, base = x))) %>% + collect(), + df + ) + expect_dplyr_equal( input %>% mutate(y = logb(x)) %>% From 812f941a7720d3e73236d032e8cbf54db171353e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 10 Sep 2021 16:56:33 -0300 Subject: [PATCH 3/7] more explicit testing of base = 1 --- r/tests/testthat/test-dplyr.R | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 85e47c63ddf..9ebe27e570d 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1035,6 +1035,7 @@ test_that("log functions", { df ) + # test log(, base = (length-1)) expect_dplyr_equal( input %>% mutate(y = log(x, base = 5)) %>% @@ -1042,15 +1043,35 @@ test_that("log functions", { df ) + # test log(, base = Expression) expect_dplyr_equal( input %>% - # suppress 'NaNs produced' warning on the first row of df - # that evaluates to NaN (R raises warning but Arrow does not) - suppressWarnings(mutate(., y = log(x, base = x))) %>% + # test cases where base = 1 below + filter(x != 1) %>% + mutate(y = log(x, base = x)) %>% collect(), df ) + # log(1, base = 1) is NaN in both R and Arrow + # suppress the R warning because R warns but Arrow does not + suppressWarnings( + expect_dplyr_equal( + input %>% + mutate(y = log(x, base = y)) %>% + collect(), + tibble(x = 1, y = 1) + ) + ) + + # log(n != 1, base = 1) is Inf in R and Arrow + expect_dplyr_equal( + input %>% + mutate(y = log(x, base = y)) %>% + collect(), + tibble(x = 10, y = 1) + ) + expect_dplyr_equal( input %>% mutate(y = logb(x)) %>% From 6175c7319ea8a541fb4c5c38b053ad2a6eaffb71 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 10 Sep 2021 21:34:35 -0300 Subject: [PATCH 4/7] error for length(base) != 1 --- r/R/dplyr-functions.R | 4 ++++ r/tests/testthat/test-dplyr.R | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 6d4637bc5e7..8bf39da4130 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -698,6 +698,10 @@ nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { return(Expression$create("logb_checked", x, base)) } + if (!is.numeric(base) || length(base) != 1) { + arrow_not_supported("base with length != 1") + } + if (base == exp(1)) { return(Expression$create("ln_checked", x)) } diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 9ebe27e570d..7f6a7c4a4b2 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1035,7 +1035,7 @@ test_that("log functions", { df ) - # test log(, base = (length-1)) + # test log(, base = (length == 1)) expect_dplyr_equal( input %>% mutate(y = log(x, base = 5)) %>% @@ -1043,6 +1043,13 @@ test_that("log functions", { df ) + # test log(, base = (length != 1)) + expect_error( + nse_funcs$log(Expression$scalar(10), base = 5:6), + "base with length != 1 not supported by Arrow", + fixed = TRUE + ) + # test log(, base = Expression) expect_dplyr_equal( input %>% From b43897b93cd27af7caaf4f6054bd287e18d68c14 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 13 Sep 2021 21:52:02 -0300 Subject: [PATCH 5/7] Update r/R/dplyr-functions.R Co-authored-by: Neal Richardson --- r/R/dplyr-functions.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 8bf39da4130..3f796a02090 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -699,7 +699,7 @@ nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { } if (!is.numeric(base) || length(base) != 1) { - arrow_not_supported("base with length != 1") + arrow_not_supported("base must be either a column in the data or a length-1 scalar; other values") } if (base == exp(1)) { From 23b5fc74e817a2f803d49f02450629c954277ae4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 14 Sep 2021 09:26:04 -0300 Subject: [PATCH 6/7] align treatment of `x` and `base` with handling of other binary functions (either can be scalar or Expression) --- r/R/dplyr-functions.R | 11 ++++++++++- r/tests/testthat/test-dplyr.R | 11 +++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 3f796a02090..808956efe15 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -694,12 +694,21 @@ nse_funcs$wday <- function(x, label = FALSE, abbr = TRUE, week_start = getOption } nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { + # like other binary functions, either `x` or `base` can be Expression or double(1) + if (is.numeric(x) && length(x) == 1) { + x <- Expression$scalar(x) + } else if (!inherits(x, "Expression")) { + arrow_not_supported("x must be a column or a length-1 numeric; other values") + } + + # handle `base` differently because we use the simpler ln, log2, and log10 + # functions for specific scalar base values if (inherits(base, "Expression")) { return(Expression$create("logb_checked", x, base)) } if (!is.numeric(base) || length(base) != 1) { - arrow_not_supported("base must be either a column in the data or a length-1 scalar; other values") + arrow_not_supported("base must be a column or a length-1 numeric; other values") } if (base == exp(1)) { diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 7f6a7c4a4b2..999d6ec0a88 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1045,8 +1045,15 @@ test_that("log functions", { # test log(, base = (length != 1)) expect_error( - nse_funcs$log(Expression$scalar(10), base = 5:6), - "base with length != 1 not supported by Arrow", + nse_funcs$log(10, base = 5:6), + "base must be a column or a length-1 numeric; other values not supported by Arrow", + fixed = TRUE + ) + + # test log(x = (length != 1)) + expect_error( + nse_funcs$log(10:11), + "x must be a column or a length-1 numeric; other values not supported by Arrow", fixed = TRUE ) From 350402c5cf57c30490b1927afbae152d349697ec Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 14 Sep 2021 11:59:30 -0300 Subject: [PATCH 7/7] Update r/tests/testthat/test-dplyr.R Co-authored-by: Neal Richardson --- r/tests/testthat/test-dplyr.R | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 999d6ec0a88..ab7296b7818 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -1062,7 +1062,10 @@ test_that("log functions", { input %>% # test cases where base = 1 below filter(x != 1) %>% - mutate(y = log(x, base = x)) %>% + mutate( + y = log(x, base = x), + z = log(2, base = x) + ) %>% collect(), df )