From 374402eeae185a6d280737b8c06893713c46ffd5 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Sat, 27 Aug 2022 17:43:03 -0400 Subject: [PATCH 1/8] R scalars should be cast to match the expression type, where appropriate --- r/R/expression.R | 53 +++++++++++++++++++++++--- r/tests/testthat/test-dataset-dplyr.R | 6 +-- r/tests/testthat/test-dplyr-collapse.R | 4 +- r/tests/testthat/test-dplyr-filter.R | 30 ++++++++------- r/tests/testthat/test-dplyr-mutate.R | 2 +- r/tests/testthat/test-dplyr-query.R | 15 ++++++++ 6 files changed, 85 insertions(+), 25 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 99296aab399..1925dc9327d 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -217,12 +217,7 @@ build_expr <- function(FUN, ) ) } else { - args <- lapply(args, function(x) { - if (!inherits(x, "Expression")) { - x <- Expression$scalar(x) - } - x - }) + args <- wrap_scalars(args, FUN) # In Arrow, "divide" is one function, which does integer division on # integer inputs and floating-point division on floats @@ -258,6 +253,52 @@ build_expr <- function(FUN, expr } +wrap_scalars <- function(args, FUN) { + arrow_fun <- .array_function_map[[FUN]] %||% FUN + if (arrow_fun == "if_else") { + # For if_else, the first arg should be a bool Expression, and we don't + # want to consider that when casting the other args to the same type + args[-1] <- wrap_scalars(args[-1], FUN = "") + return(args) + } + + is_expr <- map_lgl(args, ~ inherits(., "Expression")) + if (all(is_expr)) { + # No wrapping is required + return(args) + } + + args[!is_expr] <- lapply(args[!is_expr], Scalar$create) + + # Some special casing by function + # * %/%: we switch behavior based on int vs. dbl in R (see build_expr) so skip + # * binary_repeat, list_element: 2nd arg must be integer, Acero will handle it + if (any(is_expr) && !(arrow_fun %in% c("binary_repeat", "list_element")) && !(FUN %in% "%/%")) { + if (sum(is_expr) == 1) { + # Simple case: just one expr so take its type + try( + { + # If the Expression has no Schema embedded, we cannot resolve its + # type here, so this will error, hence the try() wrapping it + to_type <- args[[which(is_expr)]]$type() + # Try casting to this type, but if the cast fails, + # we'll just keep the original + args[!is_expr] <- lapply(args[!is_expr], function(x) x$cast(to_type)) + }, + silent = TRUE + ) + } else { + # TODO: check if all expression types are the same, and if so, cast to that + # Functions that exercise code that go through here (in our tests): + # * case_when + # * pmin/pmax + } + } + + args[!is_expr] <- lapply(args[!is_expr], Expression$scalar) + args +} + #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R index 04b1a460dc4..44e038586ea 100644 --- a/r/tests/testthat/test-dataset-dplyr.R +++ b/r/tests/testthat/test-dataset-dplyr.R @@ -143,7 +143,7 @@ test_that("mutate()", { chr: string dbl: double int: int32 -twice: double (multiply_checked(int, 2)) +twice: int32 (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) See $.data for the source Arrow object", @@ -219,7 +219,7 @@ test_that("arrange()", { chr: string dbl: double int: int32 -twice: double (multiply_checked(int, 2)) +twice: int32 (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) * Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc] @@ -368,7 +368,7 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", { "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan "ProjectNode.*", # output columns "FilterNode.*", # filter node - "int > 6.*cast.*", # filtering expressions + auto-casting of part + "int > 6.*", # filtering expressions "SourceNode" # entry point ) ) diff --git a/r/tests/testthat/test-dplyr-collapse.R b/r/tests/testthat/test-dplyr-collapse.R index 1809cb6e388..6c5f4c19911 100644 --- a/r/tests/testthat/test-dplyr-collapse.R +++ b/r/tests/testthat/test-dplyr-collapse.R @@ -57,7 +57,7 @@ test_that("implicit_schema with mutate", { words = as.character(int) ) %>% implicit_schema(), - schema(numbers = float64(), words = utf8()) + schema(numbers = int32(), words = utf8()) ) }) @@ -163,7 +163,7 @@ test_that("Properties of collapsed query", { "Table (query) lgl: bool total: int32 -extra: double (multiply_checked(total, 5)) +extra: int32 (multiply_checked(total, 5)) See $.data for the source Arrow object", fixed = TRUE diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index 21a78ee06e4..281ae2abf02 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -217,25 +217,29 @@ test_that("filter() with between()", { filter(dbl >= int, dbl <= dbl2) ) - expect_error( - tbl %>% - record_batch() %>% + compare_dplyr_binding( + .input %>% filter(between(dbl, 1, "2")) %>% - collect() + collect(), + tbl ) - expect_error( - tbl %>% - record_batch() %>% + compare_dplyr_binding( + .input %>% filter(between(dbl, 1, NA)) %>% - collect() + collect(), + tbl ) - expect_error( - tbl %>% - record_batch() %>% - filter(between(chr, 1, 2)) %>% - collect() + expect_warning( + compare_dplyr_binding( + .input %>% + filter(between(chr, 1, 2)) %>% + collect(), + tbl + ), + # the dplyr version warns: + "NAs introduced by coercion" ) }) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index ee13c8be2e3..5d431089ce7 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -458,7 +458,7 @@ test_that("print a mutated table", { print(), "Table (query) int: int32 -twice: double (multiply_checked(int, 2)) +twice: int32 (multiply_checked(int, 2)) See $.data for the source Arrow object", fixed = TRUE diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index ef9a9bcdc14..13b113dc019 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -631,3 +631,18 @@ test_that("collect() is identical to compute() %>% collect()", { collect() ) }) + +test_that("Scalars in expressions match the type of the field, if possible", { + tab <- Table$create(tbl) + expect_output( + tab %>% + filter(int == 4) %>% + show_exec_plan(), + "int == 4" + ) + # TODO: + # * test int == 4.2 (cast to int errors so send as float and let int cast) + # * what about int == "4"? + # * special handling for date == "string", this should work (has separate jira) + # * what about functions where types should not be the same? (timestamp - duration?) +}) From eec7ab1b964ace32e9b5b2af04635e596282515f Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Mon, 29 Aug 2022 15:41:57 -0400 Subject: [PATCH 2/8] Add a few tests and clean up notes --- r/R/expression.R | 2 +- r/tests/testthat/test-dplyr-query.R | 52 ++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 1925dc9327d..8fbda324d0a 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -288,7 +288,7 @@ wrap_scalars <- function(args, FUN) { silent = TRUE ) } else { - # TODO: check if all expression types are the same, and if so, cast to that + # TODO: check if all expr types are the same, and if so, cast to that # Functions that exercise code that go through here (in our tests): # * case_when # * pmin/pmax diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index 13b113dc019..f24b49de85f 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -633,16 +633,52 @@ test_that("collect() is identical to compute() %>% collect()", { }) test_that("Scalars in expressions match the type of the field, if possible", { - tab <- Table$create(tbl) + tbl_with_datetime <- tbl + tbl_with_datetime$dates <- as.Date("2022-08-28") + 1:10 + tbl_with_datetime$times <- lubridate::ymd_hms("2018-10-07 19:04:05") + 1:10 + tab <- Table$create(tbl_with_datetime) + + # 5 is double in R but is properly interpreted as int, no cast is added expect_output( tab %>% - filter(int == 4) %>% + filter(int == 5) %>% show_exec_plan(), - "int == 4" + "int == 5" + ) + + # Because 5.2 can't cast to int32 without truncation, we pass as is + # and Acero will cast int to float64 + expect_output( + tab %>% + filter(int == 5.2) %>% + show_exec_plan(), + "filter=(cast(int, {to_type=double", + fixed = TRUE + ) + expect_equal( + tab %>% + filter(int == 5.2) %>% + nrow(), + 0 ) - # TODO: - # * test int == 4.2 (cast to int errors so send as float and let int cast) - # * what about int == "4"? - # * special handling for date == "string", this should work (has separate jira) - # * what about functions where types should not be the same? (timestamp - duration?) + + # int == string, this works in dplyr and here too + expect_output( + tab %>% + filter(int == "5") %>% + show_exec_plan(), + "int == 5", + fixed = TRUE + ) + expect_equal( + tab %>% + filter(int == "5") %>% + nrow(), + 1 + ) + + skip("Auto casting string to date/timestamp not implemented") + tab %>% + filter(dates > "2022-09-01") %>% + show_exec_plan() }) From 2339e27fe290298b38a93fd944cbad04eab08043 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Mon, 29 Aug 2022 16:48:57 -0400 Subject: [PATCH 3/8] Progress on string-datetime parsing --- r/R/expression.R | 20 ++++++++++++++++++- r/tests/testthat/test-dplyr-query.R | 31 +++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 8fbda324d0a..79b4d76e5e8 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -283,7 +283,25 @@ wrap_scalars <- function(args, FUN) { to_type <- args[[which(is_expr)]]$type() # Try casting to this type, but if the cast fails, # we'll just keep the original - args[!is_expr] <- lapply(args[!is_expr], function(x) x$cast(to_type)) + args[!is_expr] <- lapply(args[!is_expr], function(x) { + if (x$type == string()) { + if (to_type == date32()) { + x <- call_function( + "strptime", + x, + options = list(format = "%Y-%m-%d", unit = 0L) + ) + } else if (to_type$id == Type[["TIMESTAMP"]]) { + x <- call_function( + "strptime", + x, + options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L) + ) + # TODO: assume_timezone? + } + } + x$cast(to_type) + }) }, silent = TRUE ) diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index f24b49de85f..856fd7bdaf3 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -677,8 +677,31 @@ test_that("Scalars in expressions match the type of the field, if possible", { 1 ) - skip("Auto casting string to date/timestamp not implemented") - tab %>% - filter(dates > "2022-09-01") %>% - show_exec_plan() + # Strings automatically parsed to date/timestamp + expect_output( + tab %>% + filter(dates > "2022-09-01") %>% + show_exec_plan(), + "dates > 2022-09-01" + ) + compare_dplyr_binding( + .input %>% + filter(dates > "2022-09-01") %>% + collect(), + tbl_with_datetime + ) + + expect_output( + tab %>% + filter(times > "2018-10-07 19:04:05") %>% + show_exec_plan(), + "times > 2018-10-07 19:04:05" + ) + skip("Timezones?") + compare_dplyr_binding( + .input %>% + filter(times > "2018-10-07 19:04:05") %>% + collect(), + tbl_with_datetime + ) }) From 6393bb59b315d3d6dba7848ea9c2f1f7e7d89411 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 30 Aug 2022 09:31:51 -0400 Subject: [PATCH 4/8] Assume timezone --- r/R/expression.R | 9 ++++++++- r/tests/testthat/test-dplyr-query.R | 3 +-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 79b4d76e5e8..7c651a46d4c 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -297,7 +297,14 @@ wrap_scalars <- function(args, FUN) { x, options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L) ) - # TODO: assume_timezone? + # R assumes timestamps without timezone specified are + # local timezone while Arrow assumes UTC. For consistency + # with R behavior, specify local timezone here. + x <- call_function( + "assume_timezone", + x, + options = list(timezone = Sys.timezone()) + ) } } x$cast(to_type) diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index 856fd7bdaf3..a79bbea7771 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -695,9 +695,8 @@ test_that("Scalars in expressions match the type of the field, if possible", { tab %>% filter(times > "2018-10-07 19:04:05") %>% show_exec_plan(), - "times > 2018-10-07 19:04:05" + "times > 2018-10-0. ..:..:05" ) - skip("Timezones?") compare_dplyr_binding( .input %>% filter(times > "2018-10-07 19:04:05") %>% From b554f10b866485bc6864dca235d0977e0323fe2f Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 31 Aug 2022 10:24:52 -0400 Subject: [PATCH 5/8] Refactor and resolve some TODOs --- r/R/expression.R | 108 ++++++++++++++++++++++++++--------------------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 7c651a46d4c..63744bf918a 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -208,11 +208,15 @@ build_expr <- function(FUN, } if (FUN == "%in%") { # Special-case %in%, which is different from the Array function name + value_set <- Array$create(args[[2]]) + try( + value_set <- cast_or_parse(value_set, args[[1]]$type()), + silent = TRUE + ) + expr <- Expression$create("is_in", args[[1]], options = list( - # If args[[2]] is already an Arrow object (like a scalar), - # this wouldn't work - value_set = Array$create(args[[2]]), + value_set = value_set, skip_nulls = TRUE ) ) @@ -274,56 +278,66 @@ wrap_scalars <- function(args, FUN) { # * %/%: we switch behavior based on int vs. dbl in R (see build_expr) so skip # * binary_repeat, list_element: 2nd arg must be integer, Acero will handle it if (any(is_expr) && !(arrow_fun %in% c("binary_repeat", "list_element")) && !(FUN %in% "%/%")) { - if (sum(is_expr) == 1) { - # Simple case: just one expr so take its type - try( - { - # If the Expression has no Schema embedded, we cannot resolve its - # type here, so this will error, hence the try() wrapping it - to_type <- args[[which(is_expr)]]$type() - # Try casting to this type, but if the cast fails, - # we'll just keep the original - args[!is_expr] <- lapply(args[!is_expr], function(x) { - if (x$type == string()) { - if (to_type == date32()) { - x <- call_function( - "strptime", - x, - options = list(format = "%Y-%m-%d", unit = 0L) - ) - } else if (to_type$id == Type[["TIMESTAMP"]]) { - x <- call_function( - "strptime", - x, - options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L) - ) - # R assumes timestamps without timezone specified are - # local timezone while Arrow assumes UTC. For consistency - # with R behavior, specify local timezone here. - x <- call_function( - "assume_timezone", - x, - options = list(timezone = Sys.timezone()) - ) - } - } - x$cast(to_type) - }) - }, - silent = TRUE - ) - } else { - # TODO: check if all expr types are the same, and if so, cast to that - # Functions that exercise code that go through here (in our tests): - # * case_when - # * pmin/pmax - } + try( + { + # If the Expression has no Schema embedded, we cannot resolve its + # type here, so this will error, hence the try() wrapping it + # This will also error if length(args[is_expr]) == 0, or + # if there are multiple exprs that do not share a common type. + to_type <- common_type(args[is_expr]) + # Try casting to this type, but if the cast fails, + # we'll just keep the original + args[!is_expr] <- lapply(args[!is_expr], cast_or_parse, type = to_type) + }, + silent = TRUE + ) } args[!is_expr] <- lapply(args[!is_expr], Expression$scalar) args } +common_type <- function(exprs) { + types <- map(exprs, ~ .$type()) + first_type <- types[[1]] + if (length(types) == 1 || all(map_lgl(types, ~ .$Equals(first_type)))) { + # Functions (in our tests) that have multiple exprs to check: + # * case_when + # * pmin/pmax + return(first_type) + } + stop("There is no common type in these expressions") +} + +cast_or_parse <- function(x, type) { + # For most types, just cast. + # But for string -> date/time, we need to call a parsing function + if (x$type == string()) { + if (type == date32()) { + x <- call_function( + "strptime", + x, + options = list(format = "%Y-%m-%d", unit = 0L) + ) + } else if (type$id == Type[["TIMESTAMP"]]) { + x <- call_function( + "strptime", + x, + options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L) + ) + # R assumes timestamps without timezone specified are + # local timezone while Arrow assumes UTC. For consistency + # with R behavior, specify local timezone here. + x <- call_function( + "assume_timezone", + x, + options = list(timezone = Sys.timezone()) + ) + } + } + x$cast(type) +} + #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { From 97bfc136def387ef28372cfae8315717e92556f3 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Sat, 3 Sep 2022 11:05:51 -0400 Subject: [PATCH 6/8] Don't auto-cast to Decimal types to avoid compute bug --- r/R/expression.R | 18 ++++++++++++++++-- r/tests/testthat/test-dplyr-query.R | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index 63744bf918a..a172c5cae45 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -310,16 +310,30 @@ common_type <- function(exprs) { } cast_or_parse <- function(x, type) { + type_id <- type$id + if (type_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) { + # TODO: determine the minimum size of decimal (or integer) required to + # accommodate x + # We would like to keep calculations on decimal if that's what the data has + # so that we don't lose precision. However, there are some limitations + # today, so it makes sense to keep x as double (which is probably is from R) + # and let Acero cast the decimal to double to compute. + # You can specify in your query that x should be decimal or integer if you + # know it to be safe. + # * ARROW-17601: multiply(decimal, decimal) can fail to make output type + return(x) + } + # For most types, just cast. # But for string -> date/time, we need to call a parsing function if (x$type == string()) { - if (type == date32()) { + if (type_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) { x <- call_function( "strptime", x, options = list(format = "%Y-%m-%d", unit = 0L) ) - } else if (type$id == Type[["TIMESTAMP"]]) { + } else if (type_id == Type[["TIMESTAMP"]]) { x <- call_function( "strptime", x, diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index a79bbea7771..891e4db54ee 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -703,4 +703,18 @@ test_that("Scalars in expressions match the type of the field, if possible", { collect(), tbl_with_datetime ) + + tab_with_decimal <- tab %>% + mutate(dec = cast(dbl, decimal(15, 2))) %>% + compute() + + # This reproduces the issue on ARROW-17601, found in the TPC-H query 1 + # In ARROW-17462, we chose not to auto-cast to decimal to avoid that issue + result <- tab_with_decimal %>% + summarize( + tpc_h_1 = sum(dec * (1 - dec) * (1 + dec), na.rm = TRUE), + as_dbl = sum(dbl * (1 - dbl) * (1 + dbl), na.rm = TRUE) + ) %>% + collect() + expect_equal(result$tpc_h_1, result$as_dbl) }) From 2b2e15bb947578454803f931fc75d73cead8366b Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Sat, 3 Sep 2022 17:18:23 -0400 Subject: [PATCH 7/8] Avoid some cpp calls --- r/R/expression.R | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index a172c5cae45..8171f06d979 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -187,7 +187,10 @@ Expression$field_ref <- function(name) { compute___expr__field_ref(name) } Expression$scalar <- function(x) { - expr <- compute___expr__scalar(Scalar$create(x)) + if (!inherits(x, "Scalar")) { + x <- Scalar$create(x) + } + expr <- compute___expr__scalar(x) expr$schema <- schema() expr } @@ -310,8 +313,8 @@ common_type <- function(exprs) { } cast_or_parse <- function(x, type) { - type_id <- type$id - if (type_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) { + to_type_id <- type$id + if (to_type_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) { # TODO: determine the minimum size of decimal (or integer) required to # accommodate x # We would like to keep calculations on decimal if that's what the data has @@ -326,14 +329,14 @@ cast_or_parse <- function(x, type) { # For most types, just cast. # But for string -> date/time, we need to call a parsing function - if (x$type == string()) { - if (type_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) { + if (x$type_id() %in% c(Type[["STRING"]], Type[["LARGE_STRING"]])) { + if (to_type_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) { x <- call_function( "strptime", x, options = list(format = "%Y-%m-%d", unit = 0L) ) - } else if (type_id == Type[["TIMESTAMP"]]) { + } else if (to_type_id == Type[["TIMESTAMP"]]) { x <- call_function( "strptime", x, From 30410fae5f39839741de9d1e577d2ac65cb874a4 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 4 Oct 2022 08:32:57 -0400 Subject: [PATCH 8/8] Take UDFs out of build_expr and have Expression ensure Expression inputs --- r/R/compute.R | 2 +- r/R/expression.R | 8 +++++++- r/tests/testthat/test-expression.R | 5 +++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 1386728ac90..6fb62b10a87 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -379,7 +379,7 @@ register_scalar_function <- function(name, fun, in_type, out_type, RegisterScalarUDF(name, scalar_function) # register with dplyr binding (enables its use in mutate(), filter(), etc.) - binding_fun <- function(...) build_expr(name, ...) + binding_fun <- function(...) Expression$create(name, ...) # inject the value of `name` into the expression to avoid saving this # execution environment in the binding, which eliminates a warning when the diff --git a/r/R/expression.R b/r/R/expression.R index 8171f06d979..cdd2ee2596c 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -171,7 +171,13 @@ Expression$create <- function(function_name, args = list(...), options = empty_named_list()) { assert_that(is.string(function_name)) - assert_that(is_list_of(args, "Expression"), msg = "Expression arguments must be Expression objects") + # Make sure all inputs are Expressions + args <- lapply(args, function(x) { + if (!inherits(x, "Expression")) { + x <- Expression$scalar(x) + } + x + }) expr <- compute___expr__call(function_name, args, options) if (length(args)) { expr$schema <- unify_schemas(schemas = lapply(args, function(x) x$schema)) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index c4aab718d90..2b6039b04ce 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -58,9 +58,10 @@ test_that("C++ expressions", { # Interprets that as a list type expect_r6_class(f == c(1L, 2L), "Expression") - expect_error( + # Non-Expression inputs are wrapped in Expression$scalar() + expect_equal( Expression$create("add", 1, 2), - "Expression arguments must be Expression objects" + Expression$create("add", Expression$scalar(1), Expression$scalar(2)) ) })