diff --git a/r/R/compute.R b/r/R/compute.R index 1386728ac90..6fb62b10a87 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -379,7 +379,7 @@ register_scalar_function <- function(name, fun, in_type, out_type, RegisterScalarUDF(name, scalar_function) # register with dplyr binding (enables its use in mutate(), filter(), etc.) - binding_fun <- function(...) build_expr(name, ...) + binding_fun <- function(...) Expression$create(name, ...) # inject the value of `name` into the expression to avoid saving this # execution environment in the binding, which eliminates a warning when the diff --git a/r/R/expression.R b/r/R/expression.R index 99296aab399..cdd2ee2596c 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -171,7 +171,13 @@ Expression$create <- function(function_name, args = list(...), options = empty_named_list()) { assert_that(is.string(function_name)) - assert_that(is_list_of(args, "Expression"), msg = "Expression arguments must be Expression objects") + # Make sure all inputs are Expressions + args <- lapply(args, function(x) { + if (!inherits(x, "Expression")) { + x <- Expression$scalar(x) + } + x + }) expr <- compute___expr__call(function_name, args, options) if (length(args)) { expr$schema <- unify_schemas(schemas = lapply(args, function(x) x$schema)) @@ -187,7 +193,10 @@ Expression$field_ref <- function(name) { compute___expr__field_ref(name) } Expression$scalar <- function(x) { - expr <- compute___expr__scalar(Scalar$create(x)) + if (!inherits(x, "Scalar")) { + x <- Scalar$create(x) + } + expr <- compute___expr__scalar(x) expr$schema <- schema() expr } @@ -208,21 +217,20 @@ build_expr <- function(FUN, } if (FUN == "%in%") { # Special-case %in%, which is different from the Array function name + value_set <- Array$create(args[[2]]) + try( + value_set <- cast_or_parse(value_set, args[[1]]$type()), + silent = TRUE + ) + expr <- Expression$create("is_in", args[[1]], options = list( - # If args[[2]] is already an Arrow object (like a scalar), - # this 
wouldn't work - value_set = Array$create(args[[2]]), + value_set = value_set, skip_nulls = TRUE ) ) } else { - args <- lapply(args, function(x) { - if (!inherits(x, "Expression")) { - x <- Expression$scalar(x) - } - x - }) + args <- wrap_scalars(args, FUN) # In Arrow, "divide" is one function, which does integer division on # integer inputs and floating-point division on floats @@ -258,6 +266,101 @@ build_expr <- function(FUN, expr } +wrap_scalars <- function(args, FUN) { + arrow_fun <- .array_function_map[[FUN]] %||% FUN + if (arrow_fun == "if_else") { + # For if_else, the first arg should be a bool Expression, and we don't + # want to consider that when casting the other args to the same type + args[-1] <- wrap_scalars(args[-1], FUN = "") + return(args) + } + + is_expr <- map_lgl(args, ~ inherits(., "Expression")) + if (all(is_expr)) { + # No wrapping is required + return(args) + } + + args[!is_expr] <- lapply(args[!is_expr], Scalar$create) + + # Some special casing by function + # * %/%: we switch behavior based on int vs. dbl in R (see build_expr) so skip + # * binary_repeat, list_element: 2nd arg must be integer, Acero will handle it + if (any(is_expr) && !(arrow_fun %in% c("binary_repeat", "list_element")) && !(FUN %in% "%/%")) { + try( + { + # If the Expression has no Schema embedded, we cannot resolve its + # type here, so this will error, hence the try() wrapping it + # This will also error if length(args[is_expr]) == 0, or + # if there are multiple exprs that do not share a common type. 
+ to_type <- common_type(args[is_expr]) + # Try casting to this type, but if the cast fails, + # we'll just keep the original + args[!is_expr] <- lapply(args[!is_expr], cast_or_parse, type = to_type) + }, + silent = TRUE + ) + } + + args[!is_expr] <- lapply(args[!is_expr], Expression$scalar) + args +} + +common_type <- function(exprs) { + types <- map(exprs, ~ .$type()) + first_type <- types[[1]] + if (length(types) == 1 || all(map_lgl(types, ~ .$Equals(first_type)))) { + # Functions (in our tests) that have multiple exprs to check: + # * case_when + # * pmin/pmax + return(first_type) + } + stop("There is no common type in these expressions") +} + +cast_or_parse <- function(x, type) { + to_type_id <- type$id + if (to_type_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) { + # TODO: determine the minimum size of decimal (or integer) required to + # accommodate x + # We would like to keep calculations on decimal if that's what the data has + # so that we don't lose precision. However, there are some limitations + # today, so it makes sense to keep x as double (which it probably is from R) + # and let Acero cast the decimal to double to compute. + # You can specify in your query that x should be decimal or integer if you + # know it to be safe. + # * ARROW-17601: multiply(decimal, decimal) can fail to make output type + return(x) + } + + # For most types, just cast. + # But for string -> date/time, we need to call a parsing function + if (x$type_id() %in% c(Type[["STRING"]], Type[["LARGE_STRING"]])) { + if (to_type_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) { + x <- call_function( + "strptime", + x, + options = list(format = "%Y-%m-%d", unit = 0L) + ) + } else if (to_type_id == Type[["TIMESTAMP"]]) { + x <- call_function( + "strptime", + x, + options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L) + ) + # R assumes timestamps without timezone specified are + # local timezone while Arrow assumes UTC. 
For consistency + # with R behavior, specify local timezone here. + x <- call_function( + "assume_timezone", + x, + options = list(timezone = Sys.timezone()) + ) + } + } + x$cast(type) +} + #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R index 04b1a460dc4..44e038586ea 100644 --- a/r/tests/testthat/test-dataset-dplyr.R +++ b/r/tests/testthat/test-dataset-dplyr.R @@ -143,7 +143,7 @@ test_that("mutate()", { chr: string dbl: double int: int32 -twice: double (multiply_checked(int, 2)) +twice: int32 (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) See $.data for the source Arrow object", @@ -219,7 +219,7 @@ test_that("arrange()", { chr: string dbl: double int: int32 -twice: double (multiply_checked(int, 2)) +twice: int32 (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) * Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc] @@ -368,7 +368,7 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", { "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan "ProjectNode.*", # output columns "FilterNode.*", # filter node - "int > 6.*cast.*", # filtering expressions + auto-casting of part + "int > 6.*", # filtering expressions "SourceNode" # entry point ) ) diff --git a/r/tests/testthat/test-dplyr-collapse.R b/r/tests/testthat/test-dplyr-collapse.R index 1809cb6e388..6c5f4c19911 100644 --- a/r/tests/testthat/test-dplyr-collapse.R +++ b/r/tests/testthat/test-dplyr-collapse.R @@ -57,7 +57,7 @@ test_that("implicit_schema with mutate", { words = as.character(int) ) %>% implicit_schema(), - schema(numbers = float64(), words = utf8()) + schema(numbers = int32(), words = utf8()) ) }) @@ -163,7 +163,7 @@ test_that("Properties of collapsed query", { "Table (query) lgl: bool total: int32 -extra: double 
(multiply_checked(total, 5)) +extra: int32 (multiply_checked(total, 5)) See $.data for the source Arrow object", fixed = TRUE diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index 21a78ee06e4..281ae2abf02 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -217,25 +217,29 @@ test_that("filter() with between()", { filter(dbl >= int, dbl <= dbl2) ) - expect_error( - tbl %>% - record_batch() %>% + compare_dplyr_binding( + .input %>% filter(between(dbl, 1, "2")) %>% - collect() + collect(), + tbl ) - expect_error( - tbl %>% - record_batch() %>% + compare_dplyr_binding( + .input %>% filter(between(dbl, 1, NA)) %>% - collect() + collect(), + tbl ) - expect_error( - tbl %>% - record_batch() %>% - filter(between(chr, 1, 2)) %>% - collect() + expect_warning( + compare_dplyr_binding( + .input %>% + filter(between(chr, 1, 2)) %>% + collect(), + tbl + ), + # the dplyr version warns: + "NAs introduced by coercion" ) }) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index ee13c8be2e3..5d431089ce7 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -458,7 +458,7 @@ test_that("print a mutated table", { print(), "Table (query) int: int32 -twice: double (multiply_checked(int, 2)) +twice: int32 (multiply_checked(int, 2)) See $.data for the source Arrow object", fixed = TRUE diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index ef9a9bcdc14..891e4db54ee 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -631,3 +631,90 @@ test_that("collect() is identical to compute() %>% collect()", { collect() ) }) + +test_that("Scalars in expressions match the type of the field, if possible", { + tbl_with_datetime <- tbl + tbl_with_datetime$dates <- as.Date("2022-08-28") + 1:10 + tbl_with_datetime$times <- lubridate::ymd_hms("2018-10-07 19:04:05") + 
1:10 + tab <- Table$create(tbl_with_datetime) + + # 5 is double in R but is properly interpreted as int, no cast is added + expect_output( + tab %>% + filter(int == 5) %>% + show_exec_plan(), + "int == 5" + ) + + # Because 5.2 can't cast to int32 without truncation, we pass as is + # and Acero will cast int to float64 + expect_output( + tab %>% + filter(int == 5.2) %>% + show_exec_plan(), + "filter=(cast(int, {to_type=double", + fixed = TRUE + ) + expect_equal( + tab %>% + filter(int == 5.2) %>% + nrow(), + 0 + ) + + # int == string, this works in dplyr and here too + expect_output( + tab %>% + filter(int == "5") %>% + show_exec_plan(), + "int == 5", + fixed = TRUE + ) + expect_equal( + tab %>% + filter(int == "5") %>% + nrow(), + 1 + ) + + # Strings automatically parsed to date/timestamp + expect_output( + tab %>% + filter(dates > "2022-09-01") %>% + show_exec_plan(), + "dates > 2022-09-01" + ) + compare_dplyr_binding( + .input %>% + filter(dates > "2022-09-01") %>% + collect(), + tbl_with_datetime + ) + + expect_output( + tab %>% + filter(times > "2018-10-07 19:04:05") %>% + show_exec_plan(), + "times > 2018-10-0. 
..:..:05" + ) + compare_dplyr_binding( + .input %>% + filter(times > "2018-10-07 19:04:05") %>% + collect(), + tbl_with_datetime + ) + + tab_with_decimal <- tab %>% + mutate(dec = cast(dbl, decimal(15, 2))) %>% + compute() + + # This reproduces the issue on ARROW-17601, found in the TPC-H query 1 + # In ARROW-17462, we chose not to auto-cast to decimal to avoid that issue + result <- tab_with_decimal %>% + summarize( + tpc_h_1 = sum(dec * (1 - dec) * (1 + dec), na.rm = TRUE), + as_dbl = sum(dbl * (1 - dbl) * (1 + dbl), na.rm = TRUE) + ) %>% + collect() + expect_equal(result$tpc_h_1, result$as_dbl) +}) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index c4aab718d90..2b6039b04ce 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -58,9 +58,10 @@ test_that("C++ expressions", { # Interprets that as a list type expect_r6_class(f == c(1L, 2L), "Expression") - expect_error( + # Non-Expression inputs are wrapped in Expression$scalar() + expect_equal( Expression$create("add", 1, 2), - "Expression arguments must be Expression objects" + Expression$create("add", Expression$scalar(1), Expression$scalar(2)) ) })