diff --git a/r/R/array.R b/r/R/array.R index 061c42189b5..8c9a29b6680 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -270,13 +270,7 @@ FixedSizeListArray <- R6Class("FixedSizeListArray", inherit = Array, length.Array <- function(x) x$length() #' @export -is.na.Array <- function(x) { - if (x$type == null()) { - rep(TRUE, length(x)) - } else { - !Array__Mask(x) - } -} +is.na.Array <- function(x) shared_ptr(Array, call_function("is_null", x)) #' @export as.vector.Array <- function(x, mode) x$as_vector() @@ -287,7 +281,7 @@ filter_rows <- function(x, i, keep_na = TRUE, ...) { nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like if (inherits(i, "array_expression")) { # Evaluate it - i <- as.vector(i) + i <- eval_array_expression(i) } if (is.logical(i)) { if (isTRUE(i)) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 5a2c952e77b..a98a6cba1f6 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -60,10 +60,6 @@ Array__View <- function(array, type){ .Call(`_arrow_Array__View` , array, type) } -Array__Mask <- function(array){ - .Call(`_arrow_Array__Mask` , array) -} - Array__Validate <- function(array){ invisible(.Call(`_arrow_Array__Validate` , array)) } diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index 5592d5437a7..d2475eb9a76 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -128,7 +128,7 @@ length.ChunkedArray <- function(x) x$length() as.vector.ChunkedArray <- function(x, mode) x$as_vector() #' @export -is.na.ChunkedArray <- function(x) unlist(lapply(x$chunks, is.na)) +is.na.ChunkedArray <- function(x) shared_ptr(ChunkedArray, call_function("is_null", x)) #' @export `[.ChunkedArray` <- filter_rows diff --git a/r/R/compute.R b/r/R/compute.R index f242a58e854..60a1a46976c 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -19,9 +19,19 @@ #' @include chunked-array.R #' @include scalar.R -call_function <- function(function_name, ..., options = list()) { +call_function <- function(function_name, ..., args = list(...), options = empty_named_list()) { assert_that(is.string(function_name)) - compute__CallFunction(function_name, list(...), options) + assert_that(is.list(options), !is.null(names(options))) + + datum_classes <- c("Array", "ChunkedArray", "RecordBatch", "Table", "Scalar") + valid_args <- map_lgl(args, ~inherits(., datum_classes)) + if (!all(valid_args)) { + # Lame, just pick one to report + first_bad <- min(which(!valid_args)) + stop("Argument ", first_bad, " is of class ", head(class(args[[first_bad]]), 1), " but it must be one of ", oxford_paste(datum_classes, "or"), call. = FALSE) + } + + compute__CallFunction(function_name, args, options) } #' @export diff --git a/r/R/dplyr.R b/r/R/dplyr.R index bf5d3c688c0..4d0cd5f58f6 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -59,7 +59,12 @@ print.arrow_dplyr_query <- function(x, ...) { cat(fields, "\n", sep = "") cat("\n") if (!isTRUE(x$filtered_rows)) { - cat("* Filter: ", x$filtered_rows$ToString(), "\n", sep = "") + if (query_on_dataset(x)) { + filter_string <- x$filtered_rows$ToString() + } else { + filter_string <- .format_array_expression(x$filtered_rows) + } + cat("* Filter: ", filter_string, "\n", sep = "") } if (length(x$group_by_vars)) { cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "") @@ -202,13 +207,13 @@ filter_mask <- function(.data) { } else { comp_func <- function(operator) { force(operator) - function(e1, e2) array_expression(operator, e1, e2) + function(e1, e2) build_array_expression(operator, e1, e2) } var_binder <- function(x) .data$.data[[x]] } # First add the functions - func_names <- set_names(c(names(comparison_function_map), "&", "|", "%in%")) + func_names <- set_names(names(.array_function_map)) env_bind(f_env, !!!lapply(func_names, comp_func)) # Then add the column references # Renaming is handled automatically by the named list diff --git a/r/R/expression.R b/r/R/expression.R index 338e15260a2..092c7caed80 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -17,30 +17,148 @@ #' @include arrowExports.R -array_expression <- function(FUN, ...) { - structure(list(fun = FUN, args = list(...)), class = "array_expression") +array_expression <- function(FUN, + ..., + args = list(...), + options = empty_named_list(), + result_class = .guess_result_class(args[[1]])) { + structure( + list( + fun = FUN, + args = args, + options = options, + result_class = result_class + ), + class = "array_expression" + ) } #' @export -Ops.Array <- function(e1, e2) array_expression(.Generic, e1, e2) +Ops.Array <- function(e1, e2) { + if (.Generic %in% names(.array_function_map)) { + expr <- build_array_expression(.Generic, e1, e2, result_class = "Array") + eval_array_expression(expr) + } else { + stop("Unsupported operation on Array: ", .Generic, call. = FALSE) + } +} #' @export -Ops.ChunkedArray <- Ops.Array +Ops.ChunkedArray <- function(e1, e2) { + if (.Generic %in% names(.array_function_map)) { + expr <- build_array_expression(.Generic, e1, e2, result_class = "ChunkedArray") + eval_array_expression(expr) + } else { + stop("Unsupported operation on ChunkedArray: ", .Generic, call. = FALSE) + } +} #' @export -Ops.array_expression <- Ops.Array +Ops.array_expression <- function(e1, e2) { + if (.Generic == "!") { + build_array_expression(.Generic, e1, result_class = e1$result_class) + } else { + build_array_expression(.Generic, e1, e2, result_class = e1$result_class) + } +} + +build_array_expression <- function(.Generic, e1, e2, ...) { + if (.Generic %in% names(.unary_function_map)) { + expr <- array_expression(.unary_function_map[[.Generic]], e1) + } else { + e1 <- .wrap_arrow(e1, .Generic, e2$type) + e2 <- .wrap_arrow(e2, .Generic, e1$type) + expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...) + } + expr +} + +.wrap_arrow <- function(arg, fun, type) { + if (!inherits(arg, c("ArrowObject", "array_expression"))) { + # TODO: Array$create if lengths are equal? + # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. float) + if (fun == "%in%") { + arg <- Array$create(arg, type = type) + } else { + arg <- Scalar$create(arg, type = type) + } + } + arg +} + +.unary_function_map <- list( + "!" = "invert", + "is.na" = "is_null" +) + +.binary_function_map <- list( + "==" = "equal", + "!=" = "not_equal", + ">" = "greater", + ">=" = "greater_equal", + "<" = "less", + "<=" = "less_equal", + "&" = "and_kleene", + "|" = "or_kleene", + "%in%" = "is_in_meta_binary" +) + +.array_function_map <- c(.unary_function_map, .binary_function_map) + +.guess_result_class <- function(arg) { + # HACK HACK HACK delete this when call_function returns an ArrowObject itself + if (inherits(arg, "ArrowObject")) { + return(class(arg)[1]) + } else if (inherits(arg, "array_expression")) { + return(arg$result_class) + } else { + stop("Not implemented") + } +} + +eval_array_expression <- function(x) { + x$args <- lapply(x$args, function (a) { + if (inherits(a, "array_expression")) { + eval_array_expression(a) + } else { + a + } + }) + ptr <- call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) + shared_ptr(get(x$result_class), ptr) +} #' @export is.na.array_expression <- function(x) array_expression("is.na", x) #' @export as.vector.array_expression <- function(x, ...) { - x$args <- lapply(x$args, as.vector) - do.call(x$fun, x$args) + as.vector(eval_array_expression(x)) } #' @export -print.array_expression <- function(x, ...) print(as.vector(x)) +print.array_expression <- function(x, ...) { + cat(.format_array_expression(x), "\n", sep = "") + invisible(x) +} + +.format_array_expression <- function(x) { + printed_args <- map_chr(x$args, function(arg) { + if (inherits(arg, "Scalar")) { + deparse(as.vector(arg)) + } else if (inherits(arg, "ArrowObject")) { + paste0("<", class(arg)[1], ">") + } else if (inherits(arg, "array_expression")) { + .format_array_expression(arg) + } else { + # Should not happen + deparse(arg) + } + }) + # Prune this for readability + function_name <- sub("_kleene", "", x$fun) + paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") +} ########### @@ -130,6 +248,15 @@ make_expression <- function(operator, e1, e2) { # In doesn't take Scalar, it takes Array return(Expression$in_(e1, e2)) } + + # Handle unary functions before touching e2 + if (operator == "is.na") { + return(is.na(e1)) + } + if (operator == "!") { + return(Expression$not(e1)) + } + # Check for non-expressions and convert to Expressions if (!inherits(e1, "Expression")) { e1 <- Expression$scalar(e1) diff --git a/r/R/record-batch.R b/r/R/record-batch.R index cc683480b1e..712000a97ab 100644 --- a/r/R/record-batch.R +++ b/r/R/record-batch.R @@ -120,6 +120,7 @@ RecordBatch <- R6Class("RecordBatch", inherit = ArrowObject, if (is.logical(i)) { i <- Array$create(i) } + assert_that(is.Array(i, "bool")) shared_ptr(RecordBatch, call_function("filter", self, i, options = list(keep_na = keep_na))) }, serialize = function() ipc___SerializeRecordBatch__Raw(self), diff --git a/r/src/array.cpp b/r/src/array.cpp index 1ebb1578355..5879dc91675 100644 --- a/r/src/array.cpp +++ b/r/src/array.cpp @@ -149,22 +149,6 @@ std::shared_ptr Array__View(const std::shared_ptr& a return ValueOrStop(array->View(type)); } -// [[arrow::export]] -LogicalVector Array__Mask(const std::shared_ptr& array) { - if (array->null_count() == 0) { - return LogicalVector(array->length(), true); - } - - auto n = array->length(); - LogicalVector res(no_init(n)); - arrow::internal::BitmapReader bitmap_reader(array->null_bitmap()->data(), - array->offset(), n); - for (int64_t i = 0; i < n; i++, bitmap_reader.Next()) { - res[i] = bitmap_reader.IsSet(); - } - return res; -} - // [[arrow::export]] void Array__Validate(const std::shared_ptr& array) { StopIfNotOk(array->Validate()); diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 570f1268642..9d0058bb9a3 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -241,21 +241,6 @@ RcppExport SEXP _arrow_Array__View(SEXP array_sexp, SEXP type_sexp){ } #endif -// array.cpp -#if defined(ARROW_R_WITH_ARROW) -LogicalVector Array__Mask(const std::shared_ptr& array); -RcppExport SEXP _arrow_Array__Mask(SEXP array_sexp){ -BEGIN_RCPP - Rcpp::traits::input_parameter&>::type array(array_sexp); - return Rcpp::wrap(Array__Mask(array)); -END_RCPP -} -#else -RcppExport SEXP _arrow_Array__Mask(SEXP array_sexp){ - Rf_error("Cannot call Array__Mask(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - // array.cpp #if defined(ARROW_R_WITH_ARROW) void Array__Validate(const std::shared_ptr& array); @@ -5940,7 +5925,6 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Array__data", (DL_FUNC) &_arrow_Array__data, 1}, { "_arrow_Array__RangeEquals", (DL_FUNC) &_arrow_Array__RangeEquals, 5}, { "_arrow_Array__View", (DL_FUNC) &_arrow_Array__View, 2}, - { "_arrow_Array__Mask", (DL_FUNC) &_arrow_Array__Mask, 1}, { "_arrow_Array__Validate", (DL_FUNC) &_arrow_Array__Validate, 1}, { "_arrow_DictionaryArray__indices", (DL_FUNC) &_arrow_DictionaryArray__indices, 1}, { "_arrow_DictionaryArray__dictionary", (DL_FUNC) &_arrow_DictionaryArray__dictionary, 1}, diff --git a/r/tests/testthat/test-Array.R b/r/tests/testthat/test-Array.R index d60fea4abe6..ce1b5bef176 100644 --- a/r/tests/testthat/test-Array.R +++ b/r/tests/testthat/test-Array.R @@ -25,7 +25,7 @@ expect_array_roundtrip <- function(x, type, as = NULL) { # TODO: revisit how missingness works with ListArrays # R list objects don't handle missingness the same way as other vectors. # Is there some vctrs thing we should do on the roundtrip back to R? - expect_identical(is.na(a), is.na(x)) + expect_equal(as.vector(is.na(a)), is.na(x)) } expect_equivalent(as.vector(a), x) # Make sure the storage mode is the same on roundtrip (esp. integer vs. numeric) @@ -37,7 +37,7 @@ expect_array_roundtrip <- function(x, type, as = NULL) { expect_type_equal(a_sliced$type, type) expect_identical(length(a_sliced), length(x_sliced)) if (!inherits(type, c("ListType", "LargeListType"))) { - expect_identical(is.na(a_sliced), is.na(x_sliced)) + expect_equal(as.vector(is.na(a_sliced)), is.na(x_sliced)) } expect_equivalent(as.vector(a_sliced), x_sliced) } @@ -182,8 +182,8 @@ test_that("Array supports NA", { expect_true(x_int$IsNull(10L)) expect_true(x_dbl$IsNull(10)) - expect_equal(is.na(x_int), c(rep(FALSE, 10), TRUE)) - expect_equal(is.na(x_dbl), c(rep(FALSE, 10), TRUE)) + expect_equal(as.vector(is.na(x_int)), c(rep(FALSE, 10), TRUE)) + expect_equal(as.vector(is.na(x_dbl)), c(rep(FALSE, 10), TRUE)) # Input validation expect_error(x_int$IsValid("ten"), class = "Rcpp::not_compatible") @@ -354,7 +354,7 @@ test_that("integer types casts (ARROW-3741)", { for (type in c(int_types, uint_types)) { casted <- a$cast(type) expect_equal(casted$type, type) - expect_identical(is.na(casted), c(rep(FALSE, 10), TRUE)) + expect_identical(as.vector(is.na(casted)), c(rep(FALSE, 10), TRUE)) } }) @@ -372,7 +372,7 @@ test_that("float types casts (ARROW-3741)", { for (type in float_types) { casted <- a$cast(type) expect_equal(casted$type, type) - expect_identical(is.na(casted), c(rep(FALSE, 3), TRUE)) + expect_identical(as.vector(is.na(casted)), c(rep(FALSE, 3), TRUE)) expect_identical(as.vector(casted), x) } }) diff --git a/r/tests/testthat/test-chunked-array.R b/r/tests/testthat/test-chunked-array.R index 75f27aa93b4..b4695e28eed 100644 --- a/r/tests/testthat/test-chunked-array.R +++ b/r/tests/testthat/test-chunked-array.R @@ -28,7 +28,7 @@ expect_chunked_roundtrip <- function(x, type) { # TODO: revisit how missingness works with ListArrays # R list objects don't handle missingness the same way as other vectors. # Is there some vctrs thing we should do on the roundtrip back to R? - expect_identical(is.na(a), is.na(flat_x)) + expect_identical(as.vector(is.na(a)), is.na(flat_x)) } expect_equal(as.vector(a), flat_x) expect_equal(as.vector(a$chunk(0)), x[[1]]) @@ -39,7 +39,7 @@ expect_chunked_roundtrip <- function(x, type) { expect_type_equal(a_sliced$type, type) expect_identical(length(a_sliced), length(x_sliced)) if (!inherits(type, "ListType")) { - expect_identical(is.na(a_sliced), is.na(x_sliced)) + expect_identical(as.vector(is.na(a_sliced)), is.na(x_sliced)) } expect_equal(as.vector(a_sliced), x_sliced) } @@ -117,10 +117,8 @@ test_that("ChunkedArray handles NA", { expect_equal(as.vector(x), c(1:10, c(NA, 2:10), c(1:3, NA, 5))) chunks <- x$chunks - expect_equal(is.na(chunks[[1]]), is.na(data[[1]])) - expect_equal(is.na(chunks[[2]]), is.na(data[[2]])) - expect_equal(is.na(chunks[[3]]), is.na(data[[3]])) - expect_equal(is.na(x), c(is.na(data[[1]]), is.na(data[[2]]), is.na(data[[3]]))) + expect_equal(as.vector(is.na(chunks[[2]])), is.na(data[[2]])) + expect_equal(as.vector(is.na(x)), c(is.na(data[[1]]), is.na(data[[2]]), is.na(data[[3]]))) }) test_that("ChunkedArray supports logical vectors (ARROW-3341)", { diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute-aggregate.R similarity index 94% rename from r/tests/testthat/test-compute.R rename to r/tests/testthat/test-compute-aggregate.R index 811c27d05e5..1e5f9a46b33 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute-aggregate.R @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -context("compute") +context("compute: aggregation") test_that("sum.Array", { ints <- 1:5 @@ -94,7 +94,10 @@ test_that("mean.Scalar", { }) test_that("Bad input handling of call_function", { - expect_error(call_function("sum", 2, 3), "to_datum: Not implemented for type double") + expect_error( + call_function("sum", 2, 3), + 'Argument 1 is of class numeric but it must be one of "Array", "ChunkedArray", "RecordBatch", "Table", or "Scalar"' + ) }) test_that("min/max.Array", { diff --git a/r/tests/testthat/test-compute-vector.R b/r/tests/testthat/test-compute-vector.R new file mode 100644 index 00000000000..b9097b6e1b6 --- /dev/null +++ b/r/tests/testthat/test-compute-vector.R @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +context("compute: vector operations") + +expect_bool_function_equal <- function(array_exp, r_exp, class = "Array") { + # Assert that the Array operation returns a boolean array + # and that its contents are equal to expected + expect_is(array_exp, class) + expect_type_equal(array_exp, bool()) + expect_identical(as.vector(array_exp), r_exp) +} + +expect_array_compares <- function(r_values, compared_to, Class = Array) { + a <- Class$create(r_values) + # Iterate over all comparison functions + expect_bool_function_equal(a == compared_to, r_values == compared_to, class(a)) + expect_bool_function_equal(a != compared_to, r_values != compared_to, class(a)) + expect_bool_function_equal(a > compared_to, r_values > compared_to, class(a)) + expect_bool_function_equal(a >= compared_to, r_values >= compared_to, class(a)) + expect_bool_function_equal(a < compared_to, r_values < compared_to, class(a)) + expect_bool_function_equal(a <= compared_to, r_values <= compared_to, class(a)) +} + +expect_chunked_array_compares <- function(...) expect_array_compares(..., Class = ChunkedArray) + +test_that("compare ops with Array", { + expect_array_compares(1:5, 4L) + expect_array_compares(1:5, 4) # implicit casting + expect_array_compares(c(NA, 1:5), 4) + expect_array_compares(as.numeric(c(NA, 1:5)), 4) +}) + +test_that("compare ops with ChunkedArray", { + expect_chunked_array_compares(1:5, 4L) + expect_chunked_array_compares(1:5, 4) # implicit casting + expect_chunked_array_compares(c(NA, 1:5), 4) + expect_chunked_array_compares(as.numeric(c(NA, 1:5)), 4) +}) + +test_that("logic ops with Array", { + truth <- expand.grid(left = c(TRUE, FALSE, NA), right = c(TRUE, FALSE, NA)) + a_left <- Array$create(truth$left) + a_right <- Array$create(truth$right) + expect_bool_function_equal(a_left & a_right, truth$left & truth$right) + expect_bool_function_equal(a_left | a_right, truth$left | truth$right) + expect_bool_function_equal(a_left == a_right, truth$left == truth$right) + expect_bool_function_equal(a_left != a_right, truth$left != truth$right) + expect_bool_function_equal(!a_left, !truth$left) + + # More complexity + isEqualTo <- function(x, y) x == y & !is.na(x) + expect_bool_function_equal( + isEqualTo(a_left, a_right), + isEqualTo(truth$left, truth$right) + ) +}) + +test_that("logic ops with ChunkedArray", { + truth <- expand.grid(left = c(TRUE, FALSE, NA), right = c(TRUE, FALSE, NA)) + a_left <- ChunkedArray$create(truth$left) + a_right <- ChunkedArray$create(truth$right) + expect_bool_function_equal(a_left & a_right, truth$left & truth$right, "ChunkedArray") + expect_bool_function_equal(a_left | a_right, truth$left | truth$right, "ChunkedArray") + expect_bool_function_equal(a_left == a_right, truth$left == truth$right, "ChunkedArray") + expect_bool_function_equal(a_left != a_right, truth$left != truth$right, "ChunkedArray") + expect_bool_function_equal(!a_left, !truth$left, "ChunkedArray") + + # More complexity + isEqualTo <- function(x, y) x == y & !is.na(x) + expect_bool_function_equal( + isEqualTo(a_left, a_right), + isEqualTo(truth$left, truth$right), + "ChunkedArray" + ) +}) + +test_that("call_function validation", { + expect_error( + call_function("filter", 4), + 'Argument 1 is of class numeric but it must be one of "Array", "ChunkedArray", "RecordBatch", "Table", or "Scalar"' + ) + expect_error( + call_function("filter", Array$create(1:4), 3), + 'Argument 2 is of class numeric' + ) + expect_error( + call_function("filter", + Array$create(1:4), + Array$create(c(TRUE, FALSE, TRUE)), + options = list(keep_na = TRUE) + ), + "Array arguments must all be the same length" + ) + expect_error( + call_function("filter", + record_batch(a = 1:3), + Array$create(c(TRUE, FALSE, TRUE)), + options = list(keep_na = TRUE) + ), + NA + ) + expect_error( + call_function("filter", options = list(keep_na = TRUE)), + "accepts 2 arguments" + ) +}) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 7b4afda335a..995ba8adabb 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -145,6 +145,24 @@ test_that("More complex select/filter", { ) }) +test_that("Print method", { + expect_output( + record_batch(tbl) %>% + filter(dbl > 2, chr == "d" | chr == "f") %>% + select(chr, int, lgl) %>% + filter(int < 5) %>% + select(int, chr) %>% + print(), +'RecordBatch (query) +int: int32 +chr: string + +* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5L)) +See $.data for the source Arrow object', + fixed = TRUE + ) +}) + test_that("filter() with %in%", { expect_dplyr_equal( input %>% @@ -169,6 +187,10 @@ test_that("filter environment scope", { # 'could not find function "isEqualTo"' expect_dplyr_error(filter(batch, isEqualTo(int, 4))) + # TODO: fix this: this isEqualTo function is eagerly evaluating; it should + # instead yield array_expressions. Probably bc the parent env of the function + # has the Ops.Array methods defined; we need to move it so that the parent + # env is the data mask we use in the dplyr eval isEqualTo <- function(x, y) x == y & !is.na(x) expect_dplyr_equal( input %>% diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index f75926eb8d3..1bf08595758 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -18,25 +18,18 @@ context("Expressions") test_that("Can create an expression", { - expect_is(Array$create(1:5) + 4, "array_expression") -}) - -test_that("Recursive expression generation", { - a <- Array$create(1:5) - expect_is(a == 4 | a == 3, "array_expression") + expect_is(build_array_expression(">", Array$create(1:5), 4), "array_expression") }) test_that("as.vector(array_expression)", { - a <- Array$create(1:5) - expect_equal(as.vector(a + 4), 5:9) - expect_equal(as.vector(a == 4 | a == 3), c(FALSE, FALSE, TRUE, TRUE, FALSE)) + expect_equal(as.vector(build_array_expression(">", Array$create(1:5), 4)), c(FALSE, FALSE, FALSE, FALSE, TRUE)) }) test_that("array_expression print method", { - a <- Array$create(1:5) expect_output( - print(a == 4 | a == 3), - capture.output(print(c(FALSE, FALSE, TRUE, TRUE, FALSE))), + print(build_array_expression(">", Array$create(1:5), 4)), + # Not ideal but it is informative + "greater(, 4L)", fixed = TRUE ) })