diff --git a/r/NEWS.md b/r/NEWS.md index ebf80ee1d81..a606c03b9cf 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -29,6 +29,7 @@ ## Enhancements +* Arithmetic operations (`+`, `*`, etc.) are supported on Arrays and ChunkedArrays and can be used in filter expressions in Arrow `dplyr` pipelines * Table columns can now be added, replaced, or removed by assigning (`<-`) with either `$` or `[[` * Column names of Tables and RecordBatches can be renamed by assigning `names()` * Large string types can now be written to Parquet files diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 11fd99b2321..c41ef33cb69 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1412,10 +1412,6 @@ Scalar__ToString <- function(s){ .Call(`_arrow_Scalar__ToString` , s) } -Scalar__CastTo <- function(s, t){ - .Call(`_arrow_Scalar__CastTo` , s, t) -} - StructScalar__field <- function(s, i){ .Call(`_arrow_StructScalar__field` , s, i) } diff --git a/r/R/expression.R b/r/R/expression.R index f9e09c2fadd..9a5e575183d 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -57,21 +57,53 @@ build_array_expression <- function(.Generic, e1, e2, ...) 
{ if (.Generic %in% names(.unary_function_map)) { expr <- array_expression(.unary_function_map[[.Generic]], e1) } else { - e1 <- .wrap_arrow(e1, .Generic, e2$type) - e2 <- .wrap_arrow(e2, .Generic, e1$type) + e1 <- .wrap_arrow(e1, .Generic) + e2 <- .wrap_arrow(e2, .Generic) + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (.Generic == "/") { + # TODO: omg so many ways it's wrong to assume these types + e1 <- cast_array_expression(e1, float64()) + e2 <- cast_array_expression(e2, float64()) + } else if (.Generic == "%/%") { + # In R, integer division works like floor(float division) + out <- build_array_expression("/", e1, e2) + return(cast_array_expression(out, int32(), allow_float_truncate = TRUE)) + } else if (.Generic == "%%") { + # {e1 - e2 * ( e1 %/% e2 )} + # ^^^ form doesn't work because Ops.Array evaluates eagerly, + # but we can build that up + quotient <- build_array_expression("%/%", e1, e2) + # this cast is to ensure that the result of this and e1 are the same + # (autocasting only applies to scalars) + base <- cast_array_expression(quotient * e2, e1$type) + return(build_array_expression("-", e1, base)) + } + expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) } expr } -.wrap_arrow <- function(arg, fun, type) { +cast_array_expression <- function(x, to_type, safe = TRUE, ...) { + opts <- list( + to_type = to_type, + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + array_expression("cast", x, options = modifyList(opts, list(...))) +} + +.wrap_arrow <- function(arg, fun) { if (!inherits(arg, c("ArrowObject", "array_expression"))) { # TODO: Array$create if lengths are equal? # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. 
float) if (fun == "%in%") { - arg <- Array$create(arg, type = type) + arg <- Array$create(arg) } else { - arg <- Scalar$create(arg, type = type) + arg <- Scalar$create(arg) } } arg @@ -91,6 +123,15 @@ build_array_expression <- function(.Generic, e1, e2, ...) { "<=" = "less_equal", "&" = "and_kleene", "|" = "or_kleene", + "+" = "add_checked", + "-" = "subtract_checked", + "*" = "multiply_checked", + "/" = "divide_checked", + "%/%" = "divide_checked", + # we don't actually use divide_checked with `%%`, rather it is rewritten to + # use %/% above. + "%%" = "divide_checked", + # TODO: "^" (ARROW-11070) "%in%" = "is_in_meta_binary" ) @@ -104,6 +145,16 @@ eval_array_expression <- function(x) { a } }) + if (length(x$args) == 2L) { + # Insert implicit casts + if (inherits(x$args[[1]], "Scalar")) { + x$args[[1]] <- x$args[[1]]$cast(x$args[[2]]$type) + } else if (inherits(x$args[[2]], "Scalar")) { + x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) + } else if (x$fun == "is_in_meta_binary" && inherits(x$args[[2]], "Array")) { + x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) + } + } call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } @@ -160,7 +211,16 @@ print.array_expression <- function(x, ...) { #' @export Expression <- R6Class("Expression", inherit = ArrowObject, public = list( - ToString = function() dataset___expr__ToString(self) + ToString = function() dataset___expr__ToString(self), + cast = function(to_type, safe = TRUE, ...) { + opts <- list( + to_type = to_type, + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + Expression$create("cast", self, options = modifyList(opts, list(...))) + } ) ) Expression$create <- function(function_name, @@ -196,6 +256,21 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) 
{ if (!inherits(e2, "Expression")) { e2 <- Expression$scalar(e2) } + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (.Generic == "/") { + # TODO: omg so many ways it's wrong to assume these types + e1 <- e1$cast(float64()) + e2 <- e2$cast(float64()) + } else if (.Generic == "%/%") { + # In R, integer division works like floor(float division) + out <- build_dataset_expression("/", e1, e2) + return(out$cast(int32(), allow_float_truncate = TRUE)) + } else if (.Generic == "%%") { + return(e1 - e2 * ( e1 %/% e2 )) + } + expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...) } expr diff --git a/r/R/scalar.R b/r/R/scalar.R index 12f29990e0a..774fe571145 100644 --- a/r/R/scalar.R +++ b/r/R/scalar.R @@ -32,8 +32,14 @@ Scalar <- R6Class("Scalar", # TODO: document the methods public = list( ToString = function() Scalar__ToString(self), - cast = function(target_type) { - Scalar__CastTo(self, as_type(target_type)) + cast = function(target_type, safe = TRUE, ...) 
{ + opts <- list( + to_type = as_type(target_type), + allow_int_overflow = !safe, + allow_time_truncate = !safe, + allow_float_truncate = !safe + ) + call_function("cast", self, options = modifyList(opts, list(...))) }, as_vector = function() Scalar__as_vector(self) ), diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 670bee20665..a66b69d0bab 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -5543,22 +5543,6 @@ extern "C" SEXP _arrow_Scalar__ToString(SEXP s_sexp){ } #endif -// scalar.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, const std::shared_ptr& t); -extern "C" SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type s(s_sexp); - arrow::r::Input&>::type t(t_sexp); - return cpp11::as_sexp(Scalar__CastTo(s, t)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ - Rf_error("Cannot call Scalar__CastTo(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); -} -#endif - // scalar.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr StructScalar__field(const std::shared_ptr& s, int i); @@ -6617,7 +6601,6 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 4}, { "_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, { "_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, - { "_arrow_Scalar__CastTo", (DL_FUNC) &_arrow_Scalar__CastTo, 2}, { "_arrow_StructScalar__field", (DL_FUNC) &_arrow_StructScalar__field, 2}, { "_arrow_StructScalar__GetFieldByName", (DL_FUNC) &_arrow_StructScalar__GetFieldByName, 2}, { "_arrow_Scalar__as_vector", (DL_FUNC) &_arrow_Scalar__as_vector, 1}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 8a75a251d8e..4497f5b59a3 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -185,6 +185,33 @@ std::shared_ptr make_compute_options( cpp11::as_cpp(options["skip_nulls"])); } + // hacky attempt to pass through to_type and other options + if (func_name == "cast") { + using Options = arrow::compute::CastOptions; + auto out = std::make_shared(true); + SEXP to_type = options["to_type"]; + if (!Rf_isNull(to_type) && cpp11::as_cpp>(to_type)) { + out->to_type = cpp11::as_cpp>(to_type); + } + + SEXP allow_float_truncate = options["allow_float_truncate"]; + if (!Rf_isNull(allow_float_truncate) && cpp11::as_cpp(allow_float_truncate)) { + out->allow_float_truncate = cpp11::as_cpp(allow_float_truncate); + } + + SEXP allow_time_truncate = options["allow_time_truncate"]; + if (!Rf_isNull(allow_time_truncate) && cpp11::as_cpp(allow_time_truncate)) { + out->allow_time_truncate = cpp11::as_cpp(allow_time_truncate); + } + + SEXP allow_int_overflow = options["allow_int_overflow"]; + if (!Rf_isNull(allow_int_overflow) && cpp11::as_cpp(allow_int_overflow)) { + out->allow_int_overflow = cpp11::as_cpp(allow_int_overflow); + } + + return out; + } + return nullptr; } diff --git 
a/r/src/scalar.cpp b/r/src/scalar.cpp index 2c2d291b5bf..c0cc396b02d 100644 --- a/r/src/scalar.cpp +++ b/r/src/scalar.cpp @@ -47,12 +47,6 @@ std::string Scalar__ToString(const std::shared_ptr& s) { return s->ToString(); } -// [[arrow::export]] -std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, - const std::shared_ptr& t) { - return ValueOrStop(s->CastTo(t)); -} - // [[arrow::export]] std::shared_ptr StructScalar__field( const std::shared_ptr& s, int i) { diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R new file mode 100644 index 00000000000..d37367d47c8 --- /dev/null +++ b/r/tests/testthat/test-compute-arith.R @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +test_that("Addition", { + a <- Array$create(c(1:4, NA_integer_)) + expect_type_equal(a, int32()) + expect_type_equal(a + 4, int32()) + expect_equal(a + 4, Array$create(c(5:8, NA_integer_))) + expect_identical(as.vector(a + 4), c(5:8, NA_integer_)) + expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) + expect_vector(a + 4L, c(5:8, NA_integer_)) + expect_equal(a + NA_integer_, Array$create(rep(NA_integer_, 5))) + + # overflow errors — this is slightly different from R's `NA` coercion when + # overflowing, but better than the alternative of silently wrapping around + casted <- a$cast(int8()) + expect_error(casted + 127) + expect_error(casted + 200) + + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-8919") + expect_type_equal(a + 4.1, float64()) + expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) +}) + +test_that("Subtraction", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a - 3, Array$create(c(-2:1, NA_integer_))) +}) + +test_that("Multiplication", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a * 2, Array$create(c(1:4 * 2L, NA_integer_))) +}) + +test_that("Division", { + a <- Array$create(c(1:4, NA_integer_)) + expect_equal(a / 2, Array$create(c(1:4 / 2, NA_real_))) + expect_equal(a %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) + expect_equal(a / 2 / 2, Array$create(c(1:4 / 2 / 2, NA_real_))) + expect_equal(a %/% 2 %/% 2, Array$create(c(0L, 0L, 0L, 1L, NA_integer_))) + + b <- a$cast(float64()) + expect_equal(b / 2, Array$create(c(1:4 / 2, NA_real_))) + expect_equal(b %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_))) + + # the behavior of %/% matches R's (i.e. 
the integer of the quotient, not + # simply dividing two integers) + expect_equal(b / 2.2, Array$create(c(1:4 / 2.2, NA_real_))) + # c(1:4) %/% 2.2 != c(1:4) %/% as.integer(2.2) + # c(1:4) %/% 2.2 == c(0L, 0L, 1L, 1L) + # c(1:4) %/% as.integer(2.2) == c(0L, 1L, 1L, 2L) + expect_equal(b %/% 2.2, Array$create(c(0L, 0L, 1L, 1L, NA_integer_))) + + expect_equal(a %% 2, Array$create(c(1L, 0L, 1L, 0L, NA_integer_))) + + expect_equal(b %% 2, Array$create(c(1:4 %% 2, NA_real_))) +}) + +test_that("Dates casting", { + a <- Array$create(c(Sys.Date() + 1:4, NA_integer_)) + + skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-8919") + expect_equal(a + 2, Array$create(c((Sys.Date() + 1:4) + 2, NA_integer_))) +}) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 4c8db531411..5bdbc42410e 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -523,6 +523,113 @@ test_that("filter() on date32 columns", { ) }) +test_that("filter() with expressions", { + ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) + expect_is(ds$format, "ParquetFileFormat") + expect_is(ds$filesystem, "LocalFileSystem") + expect_is(ds, "Dataset") + expect_equivalent( + ds %>% + select(chr, dbl) %>% + filter(dbl * 2 > 14 & dbl - 50 < 3L) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl")], + df2[1:2, c("chr", "dbl")] + ) + ) + + # check division's special casing. 
+ expect_equivalent( + ds %>% + select(chr, dbl) %>% + filter(dbl / 2 > 3.5 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl")], + df2[1:2, c("chr", "dbl")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %/% 2L > 3 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %/% 2 > 3 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2L > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2L > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + skip("Implicit casts aren't being inserted everywhere they need to be (ARROW-8919)") + # Error: NotImplemented: Function multiply_checked has no kernel matching input types (scalar[double], array[int32]) + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(int %% 2 > 0 & dbl < 53) %>% + collect() %>% + arrange(dbl), + rbind( + df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")] + ) + ) + + skip("Implicit casts are only inserted for scalars (ARROW-8919)") + # Error: NotImplemented: Function add_checked has no kernel matching input types (array[double], array[int32]) + expect_equivalent( + ds %>% + select(chr, dbl, int) %>% + filter(dbl + int > 15 & dbl < 53L) %>% + collect() %>% + arrange(dbl), + rbind( + df1[8:10, c("chr", "dbl", "int")], + df2[1:2, c("chr", "dbl", "int")] + ) + ) +}) + test_that("filter scalar validation doesn't crash 
(ARROW-7772)", { expect_error( ds %>% diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index f9a01d8ceb6..a80e17c6f3e 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -27,6 +27,8 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star expr <- rlang::enquo(expr) expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) + skip_msg <- NULL + if (is.null(skip_record_batch)) { via_batch <- rlang::eval_tidy( expr, @@ -34,7 +36,7 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star ) expect_equivalent(via_batch, expected, ...) } else { - skip(skip_record_batch) + skip_msg <- c(skip_msg, skip_record_batch) } if (is.null(skip_table)) { @@ -44,7 +46,11 @@ expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its star ) expect_equivalent(via_table, expected, ...) } else { - skip(skip_table) + skip_msg <- c(skip_msg, skip_table) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collapse = "\n")) } } @@ -133,6 +139,74 @@ test_that("filtering with expression", { ) }) +test_that("filtering with arithmetic", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl %/% 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression + autocasting", { + expect_dplyr_equal( + 
input %>% + filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + test_that("More complex select/filter", { expect_dplyr_equal( input %>% @@ -167,7 +241,7 @@ test_that("Print method", { int: int32 chr: string -* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5L)) +* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5)) See $.data for the source Arrow object', fixed = TRUE ) @@ -276,6 +350,14 @@ test_that("summarize", { summarize(min_int = min(int)), tbl ) + + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + summarize(min_int = min(int) / 2), + tbl + ) }) test_that("mutate", { diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 0c5ef4c12da..3c100812ff1 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -29,7 +29,7 @@ test_that("array_expression print method", { expect_output( print(build_array_expression(">", Array$create(1:5), 4)), # Not ideal but it is informative - "greater(, 4L)", + "greater(, 4)", fixed = TRUE ) }) @@ -66,3 +66,21 @@ test_that("C++ expressions", { # Interprets that as a list type expect_is(f == c(1L, 2L), "Expression") }) + +test_that("Can create an expression", { + a <- Array$create(as.numeric(1:5)) + expr <- array_expression("cast", a, options = list(to_type = int32())) + expect_is(expr, "array_expression") + expect_equal(eval_array_expression(expr), Array$create(1:5)) + + b <- Array$create(0.5:4.5) + bad_expr <- array_expression("cast", b, options = list(to_type = int32())) + expect_is(bad_expr, "array_expression") + expect_error( + eval_array_expression(bad_expr), + "Invalid: Float value .* was truncated converting" + ) + expr <- array_expression("cast", b, options = list(to_type = int32(), 
allow_float_truncate = TRUE)) + expect_is(expr, "array_expression") + expect_equal(eval_array_expression(expr), Array$create(0:4)) +})