Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion r/R/compute.R
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ register_scalar_function <- function(name, fun, in_type, out_type,
RegisterScalarUDF(name, scalar_function)

# register with dplyr binding (enables its use in mutate(), filter(), etc.)
binding_fun <- function(...) build_expr(name, ...)
binding_fun <- function(...) Expression$create(name, ...)

# inject the value of `name` into the expression to avoid saving this
# execution environment in the binding, which eliminates a warning when the
Expand Down
125 changes: 114 additions & 11 deletions r/R/expression.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,13 @@ Expression$create <- function(function_name,
args = list(...),
options = empty_named_list()) {
assert_that(is.string(function_name))
assert_that(is_list_of(args, "Expression"), msg = "Expression arguments must be Expression objects")
# Make sure all inputs are Expressions
args <- lapply(args, function(x) {
if (!inherits(x, "Expression")) {
x <- Expression$scalar(x)
}
x
})
expr <- compute___expr__call(function_name, args, options)
if (length(args)) {
expr$schema <- unify_schemas(schemas = lapply(args, function(x) x$schema))
Expand All @@ -187,7 +193,10 @@ Expression$field_ref <- function(name) {
compute___expr__field_ref(name)
}
Expression$scalar <- function(x) {
expr <- compute___expr__scalar(Scalar$create(x))
if (!inherits(x, "Scalar")) {
x <- Scalar$create(x)
}
expr <- compute___expr__scalar(x)
expr$schema <- schema()
expr
}
Expand All @@ -208,21 +217,20 @@ build_expr <- function(FUN,
}
if (FUN == "%in%") {
# Special-case %in%, which is different from the Array function name
value_set <- Array$create(args[[2]])
try(
value_set <- cast_or_parse(value_set, args[[1]]$type()),
silent = TRUE
)

expr <- Expression$create("is_in", args[[1]],
options = list(
# If args[[2]] is already an Arrow object (like a scalar),
# this wouldn't work
value_set = Array$create(args[[2]]),
value_set = value_set,
skip_nulls = TRUE
)
)
} else {
args <- lapply(args, function(x) {
if (!inherits(x, "Expression")) {
x <- Expression$scalar(x)
}
x
})
args <- wrap_scalars(args, FUN)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to whitelist the functions this applies to? (Or maybe you already do this and I'm not reading this correctly?) This logic is awesome and very appropriate for most math functions but I wonder if there are some compute functions (maybe binary_repeat) that will stop working when used with build_expr(). I think that user-defined functions also generate their bindings through build_expr() (although they don't have to).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a blocklist rather than an allowlist, and binary_repeat is on it (L285, below). If there are compute functions that don't work with this change, we don't test them.

Do you think we should exclude UDFs from the type matching too?

For functions that do go through build_expr(), the way to skip the type-matching logic is to wrap the value in Expression$scalar(). Only non-Expressions are cast.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really think you should whitelist here...in theory one can use build_expr() for any compute function, although many bindings choose to go directly through Expression$create() instead. Using a blocklist would mean you can only use build_expr() safely for specific functions (in which case you should probably compute what those functions are so that can be documented).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went through the function list on https://arrow.apache.org/docs/cpp/compute.html and evaluated whether you should try to cast scalar inputs to the type of the corresponding column (and remember, if you can't cast the scalar without loss of precision, it doesn't add the cast, so for int + 4.2, 4.2 won't get cast to int so that will go to cast(int, float64) + 4.2 in Acero). For the non-unary scalar functions, all but 4 make sense to try to convert scalars like this. The 4 functions that don't are binary_repeat, list_element, binary_join (kind of an odd case, which we don't use, we use binary_join_element_wise instead), and make_struct. It's around 40-50 functions on the allow side, so it seems that the "don't cast" functions are the exception.

Does that persuade you in favor of blocklist instead of allowlist?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think a whitelist is safer, although feel free to make the change. The build_expr() in the user-defined function code ( https://github.com/apache/arrow/blob/master/r/R/compute.R#L384 ) would have to change to something approaching the previous behaviour since we have no guarantees about those functions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just pulled UDFs out of build_expr in d54de48, and in a followup I'll go further to reduce the usage of build_expr to places where the type matching matters (more of an allowlist, in that sense), pull out the special cases inside of it, and rename it to something like build_simple_expr to make clear that it is a special case and not the default path you should choose.

Sound ok to you?


# In Arrow, "divide" is one function, which does integer division on
# integer inputs and floating-point division on floats
Expand Down Expand Up @@ -258,6 +266,101 @@ build_expr <- function(FUN,
expr
}

wrap_scalars <- function(args, FUN) {
arrow_fun <- .array_function_map[[FUN]] %||% FUN
if (arrow_fun == "if_else") {
# For if_else, the first arg should be a bool Expression, and we don't
# want to consider that when casting the other args to the same type
args[-1] <- wrap_scalars(args[-1], FUN = "")
return(args)
}

is_expr <- map_lgl(args, ~ inherits(., "Expression"))
if (all(is_expr)) {
# No wrapping is required
return(args)
}

args[!is_expr] <- lapply(args[!is_expr], Scalar$create)

# Some special casing by function
# * %/%: we switch behavior based on int vs. dbl in R (see build_expr) so skip
# * binary_repeat, list_element: 2nd arg must be integer, Acero will handle it
if (any(is_expr) && !(arrow_fun %in% c("binary_repeat", "list_element")) && !(FUN %in% "%/%")) {
try(
{
# If the Expression has no Schema embedded, we cannot resolve its
# type here, so this will error, hence the try() wrapping it
# This will also error if length(args[is_expr]) == 0, or
# if there are multiple exprs that do not share a common type.
to_type <- common_type(args[is_expr])
# Try casting to this type, but if the cast fails,
# we'll just keep the original
args[!is_expr] <- lapply(args[!is_expr], cast_or_parse, type = to_type)
},
silent = TRUE
)
}

args[!is_expr] <- lapply(args[!is_expr], Expression$scalar)
args
}

common_type <- function(exprs) {
types <- map(exprs, ~ .$type())
first_type <- types[[1]]
if (length(types) == 1 || all(map_lgl(types, ~ .$Equals(first_type)))) {
# Functions (in our tests) that have multiple exprs to check:
# * case_when
# * pmin/pmax
return(first_type)
}
stop("There is no common type in these expressions")
}

cast_or_parse <- function(x, type) {
to_type_id <- type$id
if (to_type_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) {
# TODO: determine the minimum size of decimal (or integer) required to
# accommodate x
# We would like to keep calculations on decimal if that's what the data has
# so that we don't lose precision. However, there are some limitations
# today, so it makes sense to keep x as double (which is probably is from R)
# and let Acero cast the decimal to double to compute.
# You can specify in your query that x should be decimal or integer if you
# know it to be safe.
# * ARROW-17601: multiply(decimal, decimal) can fail to make output type
return(x)
}

# For most types, just cast.
# But for string -> date/time, we need to call a parsing function
if (x$type_id() %in% c(Type[["STRING"]], Type[["LARGE_STRING"]])) {
if (to_type_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) {
x <- call_function(
"strptime",
x,
options = list(format = "%Y-%m-%d", unit = 0L)
)
} else if (to_type_id == Type[["TIMESTAMP"]]) {
x <- call_function(
"strptime",
x,
options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L)
)
# R assumes timestamps without timezone specified are
# local timezone while Arrow assumes UTC. For consistency
# with R behavior, specify local timezone here.
x <- call_function(
"assume_timezone",
x,
options = list(timezone = Sys.timezone())
)
}
}
x$cast(type)
}

#' @export
Ops.Expression <- function(e1, e2) {
if (.Generic == "!") {
Expand Down
6 changes: 3 additions & 3 deletions r/tests/testthat/test-dataset-dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ test_that("mutate()", {
chr: string
dbl: double
int: int32
twice: double (multiply_checked(int, 2))
twice: int32 (multiply_checked(int, 2))

* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
See $.data for the source Arrow object",
Expand Down Expand Up @@ -219,7 +219,7 @@ test_that("arrange()", {
chr: string
dbl: double
int: int32
twice: double (multiply_checked(int, 2))
twice: int32 (multiply_checked(int, 2))

* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
* Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc]
Expand Down Expand Up @@ -368,7 +368,7 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", {
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
"ProjectNode.*", # output columns
"FilterNode.*", # filter node
"int > 6.*cast.*", # filtering expressions + auto-casting of part
"int > 6.*", # filtering expressions
"SourceNode" # entry point
)
)
Expand Down
4 changes: 2 additions & 2 deletions r/tests/testthat/test-dplyr-collapse.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ test_that("implicit_schema with mutate", {
words = as.character(int)
) %>%
implicit_schema(),
schema(numbers = float64(), words = utf8())
schema(numbers = int32(), words = utf8())
)
})

Expand Down Expand Up @@ -163,7 +163,7 @@ test_that("Properties of collapsed query", {
"Table (query)
lgl: bool
total: int32
extra: double (multiply_checked(total, 5))
extra: int32 (multiply_checked(total, 5))

See $.data for the source Arrow object",
fixed = TRUE
Expand Down
30 changes: 17 additions & 13 deletions r/tests/testthat/test-dplyr-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,25 +217,29 @@ test_that("filter() with between()", {
filter(dbl >= int, dbl <= dbl2)
)

expect_error(
tbl %>%
record_batch() %>%
compare_dplyr_binding(
.input %>%
filter(between(dbl, 1, "2")) %>%
collect()
collect(),
tbl
)

expect_error(
tbl %>%
record_batch() %>%
compare_dplyr_binding(
.input %>%
filter(between(dbl, 1, NA)) %>%
collect()
collect(),
tbl
)

expect_error(
tbl %>%
record_batch() %>%
filter(between(chr, 1, 2)) %>%
collect()
expect_warning(
compare_dplyr_binding(
.input %>%
filter(between(chr, 1, 2)) %>%
collect(),
tbl
),
# the dplyr version warns:
"NAs introduced by coercion"
)
})

Expand Down
2 changes: 1 addition & 1 deletion r/tests/testthat/test-dplyr-mutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ test_that("print a mutated table", {
print(),
"Table (query)
int: int32
twice: double (multiply_checked(int, 2))
twice: int32 (multiply_checked(int, 2))

See $.data for the source Arrow object",
fixed = TRUE
Expand Down
87 changes: 87 additions & 0 deletions r/tests/testthat/test-dplyr-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,90 @@ test_that("collect() is identical to compute() %>% collect()", {
collect()
)
})

test_that("Scalars in expressions match the type of the field, if possible", {
tbl_with_datetime <- tbl
tbl_with_datetime$dates <- as.Date("2022-08-28") + 1:10
tbl_with_datetime$times <- lubridate::ymd_hms("2018-10-07 19:04:05") + 1:10
tab <- Table$create(tbl_with_datetime)

# 5 is double in R but is properly interpreted as int, no cast is added
expect_output(
tab %>%
filter(int == 5) %>%
show_exec_plan(),
"int == 5"
)

# Because 5.2 can't cast to int32 without truncation, we pass as is
# and Acero will cast int to float64
expect_output(
tab %>%
filter(int == 5.2) %>%
show_exec_plan(),
"filter=(cast(int, {to_type=double",
fixed = TRUE
)
expect_equal(
tab %>%
filter(int == 5.2) %>%
nrow(),
0
)

# int == string, this works in dplyr and here too
expect_output(
tab %>%
filter(int == "5") %>%
show_exec_plan(),
"int == 5",
fixed = TRUE
)
expect_equal(
tab %>%
filter(int == "5") %>%
nrow(),
1
)

# Strings automatically parsed to date/timestamp
expect_output(
tab %>%
filter(dates > "2022-09-01") %>%
show_exec_plan(),
"dates > 2022-09-01"
)
compare_dplyr_binding(
.input %>%
filter(dates > "2022-09-01") %>%
collect(),
tbl_with_datetime
)

expect_output(
tab %>%
filter(times > "2018-10-07 19:04:05") %>%
show_exec_plan(),
"times > 2018-10-0. ..:..:05"
)
compare_dplyr_binding(
.input %>%
filter(times > "2018-10-07 19:04:05") %>%
collect(),
tbl_with_datetime
)

tab_with_decimal <- tab %>%
mutate(dec = cast(dbl, decimal(15, 2))) %>%
compute()

# This reproduces the issue on ARROW-17601, found in the TPC-H query 1
# In ARROW-17462, we chose not to auto-cast to decimal to avoid that issue
result <- tab_with_decimal %>%
summarize(
tpc_h_1 = sum(dec * (1 - dec) * (1 + dec), na.rm = TRUE),
as_dbl = sum(dbl * (1 - dbl) * (1 + dbl), na.rm = TRUE)
) %>%
collect()
expect_equal(result$tpc_h_1, result$as_dbl)
})
5 changes: 3 additions & 2 deletions r/tests/testthat/test-expression.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ test_that("C++ expressions", {
# Interprets that as a list type
expect_r6_class(f == c(1L, 2L), "Expression")

expect_error(
# Non-Expression inputs are wrapped in Expression$scalar()
expect_equal(
Expression$create("add", 1, 2),
"Expression arguments must be Expression objects"
Expression$create("add", Expression$scalar(1), Expression$scalar(2))
)
})

Expand Down