From 5a9cb90ae516be5896f4b6a532910aa94a166c3b Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 17 Mar 2021 23:44:45 -0400 Subject: [PATCH 01/49] Implement SortIndices method for Table, RecordBatch, Array, ChunkedArray --- r/R/array.R | 8 ++++++++ r/R/arrow-tabular.R | 17 +++++++++++++++++ r/R/chunked-array.R | 8 ++++++++ r/R/record-batch.R | 7 ++++++- r/R/table.R | 7 ++++++- r/man/ChunkedArray.Rd | 2 ++ r/man/RecordBatch.Rd | 5 +++++ r/man/Table.Rd | 5 +++++ r/man/array.Rd | 2 ++ r/src/compute.cpp | 27 +++++++++++++++++++++++++++ 10 files changed, 86 insertions(+), 2 deletions(-) diff --git a/r/R/array.R b/r/R/array.R index aa164eaaf91..1d63c5735a7 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -73,6 +73,8 @@ #' (R vector or Array Array) `i`. #' - `$Filter(i, keep_na = TRUE)`: return an `Array` with values at positions where logical #' vector (or Arrow boolean Array) `i` is `TRUE`. +#' - `$SortIndices(descending = FALSE)`: return an `Array` of integer positions that can be +#' used to rearrange the `Array` in ascending or descending order #' - `$RangeEquals(other, start_idx, end_idx, other_start_idx)` : #' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the #' data in the array to change its type. 
@@ -131,6 +133,12 @@ Array <- R6Class("Array", assert_is(i, "Array") call_function("filter", self, i, options = list(keep_na = keep_na)) }, + SortIndices = function(descending = FALSE) { + assert_that(is.logical(descending)) + assert_that(length(descending) == 1L) + assert_that(!is.na(descending)) + call_function("array_sort_indices", self, options = list(order = descending)) + }, RangeEquals = function(other, start_idx, end_idx, other_start_idx = 0L) { assert_is(other, "Array") Array__RangeEquals(self, other, start_idx, end_idx, other_start_idx) diff --git a/r/R/arrow-tabular.R b/r/R/arrow-tabular.R index a41586f26b3..157b799f3b6 100644 --- a/r/R/arrow-tabular.R +++ b/r/R/arrow-tabular.R @@ -38,6 +38,23 @@ ArrowTabular <- R6Class("ArrowTabular", inherit = ArrowObject, } assert_that(is.Array(i, "bool")) call_function("filter", self, i, options = list(keep_na = keep_na)) + }, + SortIndices = function(names, descending = FALSE) { + assert_that(is.character(names)) + assert_that(length(names) > 0) + assert_that(!any(is.na(names))) + if (length(descending) == 1L) { + descending <- rep_len(descending, length(names)) + } + assert_that(is.logical(descending)) + assert_that(identical(length(names), length(descending))) + assert_that(!any(is.na(descending))) + call_function( + "sort_indices", + self, + # cpp11 does not support logical vectors so convert to integer + options = list(names = names, orders = as.integer(descending)) + ) } ) ) diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index d639b235f3f..64710d4743e 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -41,6 +41,8 @@ #' coerced to an R vector before taking. #' - `$Filter(i, keep_na = TRUE)`: return a `ChunkedArray` with values at positions where #' logical vector or Arrow boolean-type `(Chunked)Array` `i` is `TRUE`. 
+#' - `$SortIndices(descending = FALSE)`: return an `Array` of integer positions that can be +#' used to rearrange the `ChunkedArray` in ascending or descending order #' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the #' data in the array to change its type. #' - `$null_count()`: The number of null entries in the array @@ -83,6 +85,12 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = ArrowDatum, } call_function("filter", self, i, options = list(keep_na = keep_na)) }, + SortIndices = function(descending = FALSE) { + assert_that(is.logical(descending)) + assert_that(length(descending) == 1L) + assert_that(!is.na(descending)) + call_function("array_sort_indices", self, options = list(order = descending)) + }, View = function(type) { ChunkedArray__View(self, as_type(type)) }, diff --git a/r/R/record-batch.R b/r/R/record-batch.R index fdb54deb881..8e7382a70b4 100644 --- a/r/R/record-batch.R +++ b/r/R/record-batch.R @@ -57,6 +57,11 @@ #' integers (R vector or Array Array) `i`. #' - `$Filter(i, keep_na = TRUE)`: return an `RecordBatch` with rows at positions where logical #' vector (or Arrow boolean Array) `i` is `TRUE`. +#' - `$SortIndices(names, descending = FALSE)`: return an `Array` of integer row +#' positions that can be used to rearrange the `RecordBatch` in ascending or +#' descending order by the first named column, breaking ties with further named +#' columns. `descending` can be a logical vector of length one or of the same +#' length as `names`. #' - `$serialize()`: Returns a raw vector suitable for interprocess communication #' - `$cast(target_schema, safe = TRUE, options = cast_options(safe))`: Alter #' the schema of the record batch. 
@@ -99,7 +104,7 @@ RecordBatch <- R6Class("RecordBatch", inherit = ArrowTabular, RecordBatch__Slice2(self, offset, length) } }, - # Take and Filter are methods on ArrowTabular + # Take, Filter, and SortIndices are methods on ArrowTabular serialize = function() ipc___SerializeRecordBatch__Raw(self), to_data_frame = function() { RecordBatch__to_dataframe(self, use_threads = option_use_threads()) diff --git a/r/R/table.R b/r/R/table.R index d2c9960e6d2..fdf3f5cc20d 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -65,6 +65,11 @@ #' coerced to an R vector before taking. #' - `$Filter(i, keep_na = TRUE)`: return an `Table` with rows at positions where logical #' vector or Arrow boolean-type `(Chunked)Array` `i` is `TRUE`. +#' - `$SortIndices(names, descending = FALSE)`: return an `Array` of integer row +#' positions that can be used to rearrange the `Table` in ascending or descending +#' order by the first named column, breaking ties with further named columns. +#' `descending` can be a logical vector of length one or of the same length as +#' `names`. #' - `$serialize(output_stream, ...)`: Write the table to the given #' [OutputStream] #' - `$cast(target_schema, safe = TRUE, options = cast_options(safe))`: Alter @@ -122,7 +127,7 @@ Table <- R6Class("Table", inherit = ArrowTabular, Table__Slice2(self, offset, length) } }, - # Take and Filter are methods on ArrowTabular + # Take, Filter, and SortIndices are methods on ArrowTabular Equals = function(other, check_metadata = FALSE, ...) { inherits(other, "Table") && Table__Equals(self, other, isTRUE(check_metadata)) }, diff --git a/r/man/ChunkedArray.Rd b/r/man/ChunkedArray.Rd index 533931ae972..90dd2e39e40 100644 --- a/r/man/ChunkedArray.Rd +++ b/r/man/ChunkedArray.Rd @@ -38,6 +38,8 @@ integers \code{i}. If \code{i} is an Arrow \code{Array} or \code{ChunkedArray}, coerced to an R vector before taking. 
\item \verb{$Filter(i, keep_na = TRUE)}: return a \code{ChunkedArray} with values at positions where logical vector or Arrow boolean-type \verb{(Chunked)Array} \code{i} is \code{TRUE}. +\item \verb{$SortIndices(descending = FALSE)}: return an \code{Array} of integer positions that can be +used to rearrange the \code{ChunkedArray} in ascending or descending order \item \verb{$cast(target_type, safe = TRUE, options = cast_options(safe))}: Alter the data in the array to change its type. \item \verb{$null_count()}: The number of null entries in the array diff --git a/r/man/RecordBatch.Rd b/r/man/RecordBatch.Rd index 06f9f67abe2..184fea99c7f 100644 --- a/r/man/RecordBatch.Rd +++ b/r/man/RecordBatch.Rd @@ -57,6 +57,11 @@ of the table if \code{NULL}, the default. integers (R vector or Array Array) \code{i}. \item \verb{$Filter(i, keep_na = TRUE)}: return an \code{RecordBatch} with rows at positions where logical vector (or Arrow boolean Array) \code{i} is \code{TRUE}. +\item \verb{$SortIndices(names, descending = FALSE)}: return an \code{Array} of integer row +positions that can be used to rearrange the \code{RecordBatch} in ascending or +descending order by the first named column, breaking ties with further named +columns. \code{descending} can be a logical vector of length one or of the same +length as \code{names}. \item \verb{$serialize()}: Returns a raw vector suitable for interprocess communication \item \verb{$cast(target_schema, safe = TRUE, options = cast_options(safe))}: Alter the schema of the record batch. diff --git a/r/man/Table.Rd b/r/man/Table.Rd index 14c0b0bf260..98a5c354ced 100644 --- a/r/man/Table.Rd +++ b/r/man/Table.Rd @@ -56,6 +56,11 @@ integers \code{i}. If \code{i} is an Arrow \code{Array} or \code{ChunkedArray}, coerced to an R vector before taking. \item \verb{$Filter(i, keep_na = TRUE)}: return an \code{Table} with rows at positions where logical vector or Arrow boolean-type \verb{(Chunked)Array} \code{i} is \code{TRUE}. 
+\item \verb{$SortIndices(names, descending = FALSE)}: return an \code{Array} of integer row +positions that can be used to rearrange the \code{Table} in ascending or descending +order by the first named column, breaking ties with further named columns. +\code{descending} can be a logical vector of length one or of the same length as +\code{names}. \item \verb{$serialize(output_stream, ...)}: Write the table to the given \link{OutputStream} \item \verb{$cast(target_schema, safe = TRUE, options = cast_options(safe))}: Alter diff --git a/r/man/array.Rd b/r/man/array.Rd index fbc91e4dc35..f65afe9fbc3 100644 --- a/r/man/array.Rd +++ b/r/man/array.Rd @@ -71,6 +71,8 @@ until the end of the array. (R vector or Array Array) \code{i}. \item \verb{$Filter(i, keep_na = TRUE)}: return an \code{Array} with values at positions where logical vector (or Arrow boolean Array) \code{i} is \code{TRUE}. +\item \verb{$SortIndices(descending = FALSE)}: return an \code{Array} of integer positions that can be +used to rearrange the \code{Array} in ascending or descending order \item \verb{$RangeEquals(other, start_idx, end_idx, other_start_idx)} : \item \verb{$cast(target_type, safe = TRUE, options = cast_options(safe))}: Alter the data in the array to change its type. diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 07380354b12..29f0cd7691a 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -144,6 +144,33 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "array_sort_indices") { + using Order = arrow::compute::SortOrder; + using Options = arrow::compute::ArraySortOptions; + // false means descending, true means ascending + auto order = cpp11::as_cpp(options["order"]); + auto out = + std::make_shared(Options(order ? 
Order::Descending : Order::Ascending)); + return out; + } + + if (func_name == "sort_indices") { + using Key = arrow::compute::SortKey; + using Order = arrow::compute::SortOrder; + using Options = arrow::compute::SortOptions; + auto names = cpp11::as_cpp>(options["names"]); + // false means descending, true means ascending + // cpp11 does not support bool here so use int + auto orders = cpp11::as_cpp>(options["orders"]); + std::vector keys; + for (int i = 0; i < names.size(); i++) { + keys.push_back( + Key(names[i], (orders[i] > 0) ? Order::Descending : Order::Ascending)); + } + auto out = std::make_shared(Options(keys)); + return out; + } + if (func_name == "min_max") { using Options = arrow::compute::MinMaxOptions; auto out = std::make_shared(Options::Defaults()); From d47cfd0e1efc7bb5efd2c51d2156e0f4562038cc Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 18 Mar 2021 18:45:42 -0400 Subject: [PATCH 02/49] Fix SortIndices for ChunkedArray --- r/R/chunked-array.R | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index 64710d4743e..a7fd8d3143e 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -89,7 +89,11 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = ArrowDatum, assert_that(is.logical(descending)) assert_that(length(descending) == 1L) assert_that(!is.na(descending)) - call_function("array_sort_indices", self, options = list(order = descending)) + call_function( + "sort_indices", + self, + options = list(names = "", orders = as.integer(descending)) + ) }, View = function(type) { ChunkedArray__View(self, as_type(type)) From 67d21990079a70752f4c4b7daa789c509ea92a2c Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 18 Mar 2021 18:52:31 -0400 Subject: [PATCH 03/49] Implement sort() method for ArrowDatum types --- r/NAMESPACE | 3 +++ r/R/array.R | 5 +++++ r/R/chunked-array.R | 3 +++ r/R/scalar.R | 3 +++ 4 files changed, 14 insertions(+) diff --git a/r/NAMESPACE b/r/NAMESPACE index 
96c09615896..26cf15e6cb0 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -76,6 +76,9 @@ S3method(read_message,InputStream) S3method(read_message,MessageReader) S3method(read_message,default) S3method(row.names,ArrowTabular) +S3method(sort,Array) +S3method(sort,ChunkedArray) +S3method(sort,Scalar) S3method(sum,ArrowDatum) S3method(tail,ArrowDatum) S3method(tail,ArrowTabular) diff --git a/r/R/array.R b/r/R/array.R index 1d63c5735a7..43cfffebd5e 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -299,3 +299,8 @@ is.Array <- function(x, type = NULL) { } is_it } + +#' @export +sort.Array <- function(x, decreasing = FALSE, ...) { + x$Take(x$SortIndices(descending = decreasing)) +} diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index a7fd8d3143e..e7c03993c24 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -128,3 +128,6 @@ ChunkedArray$create <- function(..., type = NULL) { #' @rdname ChunkedArray #' @export chunked_array <- ChunkedArray$create + +#' @export +sort.ChunkedArray <- sort.Array diff --git a/r/R/scalar.R b/r/R/scalar.R index d6955423b53..d2dd5db5d8e 100644 --- a/r/R/scalar.R +++ b/r/R/scalar.R @@ -68,3 +68,6 @@ length.Scalar <- function(x) 1L #' @export is.na.Scalar <- function(x) !x$is_valid + +#' @export +sort.Scalar <- function(x, decreasing = FALSE, ...) x From deef9a3048052cddae57ee272591ff499222318b Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 18 Mar 2021 21:35:40 -0400 Subject: [PATCH 04/49] Disallow unsupported na.last options in Array and ChunkedArray sort() --- r/R/array.R | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/r/R/array.R b/r/R/array.R index 43cfffebd5e..bc46d9577af 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -301,6 +301,9 @@ is.Array <- function(x, type = NULL) { } #' @export -sort.Array <- function(x, decreasing = FALSE, ...) { +sort.Array <- function(x, decreasing = FALSE, na.last = TRUE, ...) { + if (!identical(na.last, TRUE)) { + stop("Arrow only supports sort() with na.last = TRUE", call. 
= FALSE) + } x$Take(x$SortIndices(descending = decreasing)) } From e26966967abc213477b9257bbd8fdf272a2a1829 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 18 Mar 2021 22:18:39 -0400 Subject: [PATCH 05/49] Add test data for sort tests --- r/tests/testthat/helper-data.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index 1dd3a6a79d3..bc0276c031c 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -134,3 +134,17 @@ example_with_logical_factors <- tibble::tibble( "hey buddy" ) ) + +# the values in each column of this tibble are in ascending order in the C locale. +# there are some ties, but sorting by any two columns will give a deterministic order. +example_data_for_sorting <- tibble::tibble( + int = c(-.Machine$integer.max, -101L, -100L, 0L, 0L, 1L, 100L, 1000L, .Machine$integer.max, NA_integer_), + dbl = c(-Inf, -.Machine$double.xmax, -.Machine$double.xmin, 0, .Machine$double.xmin, pi, .Machine$double.xmax, Inf, NaN, NA_real_), + # R string collation varies by locale, while libarrow always uses the C locale for string collation + # (in other words: string values in libarrow are ordered lexicographically as bytestrings) + # to make R sort functions use the C locale, run Sys.setlocale("LC_COLLATE", "C") + chr = c("", "", "\u0001", "&", "ABC", "NULL", "a", "abc", "\uFFFF", NA_character_), + # bool is not supported (ARROW-12016) + lgl = c(rep(FALSE, 4L), rep(TRUE, 4L), rep(NA, 2)), + # TODO: add more types +) From 70019270a141ce89a85db280af2676ccadd7c465 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 18 Mar 2021 22:21:28 -0400 Subject: [PATCH 06/49] Add compute sorting tests --- r/tests/testthat/test-compute-sort.R | 97 ++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 r/tests/testthat/test-compute-sort.R diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R new file mode 100644 
index 00000000000..b652f2e7fe9 --- /dev/null +++ b/r/tests/testthat/test-compute-sort.R @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +context("compute: sorting") + +library(dplyr) + +tbl <- example_data_for_sorting + +test_that("Array sort", { + expect_equal( + Array$create(tbl$int)$SortIndices(), + Array$create(0L:9L, type = uint64()) + ) + expect_equal( + Array$create(rev(tbl$int))$SortIndices(descending = TRUE), + Array$create(c(1L:9L, 0L), type = uint64()) + ) + expect_equal( + as.vector(sort(Array$create(tbl$int))), + sort(tbl$int, na.last = TRUE) + ) + expect_equal( + as.vector(sort(Array$create(tbl$int), decreasing = TRUE)), + sort(tbl$int, decreasing = TRUE, na.last = TRUE) + ) + expect_error( + sort(Array$create(tbl$int), decreasing = TRUE, na.last = NA), + "na.last", + fixed = TRUE + ) +}) + +test_that("ChunkedArray sort", { + expect_equal( + ChunkedArray$create(tbl$int[1:5], tbl$int[6:10])$SortIndices(), + Array$create(0L:9L, type = uint64()) + ) + expect_equal( + ChunkedArray$create(rev(tbl$int)[1:5], rev(tbl$int)[6:10])$SortIndices(descending = TRUE), + Array$create(c(1L:9L, 0L), type = uint64()) + ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]))), + 
sort(tbl$int, na.last = TRUE) + ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE)), + sort(tbl$int, decreasing = TRUE, na.last = TRUE) + ) + expect_error( + sort(ChunkedArray$create(tbl$int), decreasing = TRUE, na.last = NA), + "na.last", + fixed = TRUE + ) +}) + +test_that("Table/RecordBatch sort", { + expect_identical( + { + x <- tbl %>% slice_sample(prop = 1L) %>% Table$create() + x$Take(x$SortIndices("chr")) %>% pull(chr) + }, + tbl$chr + ) + expect_identical( + { + x <- tbl %>% slice_sample(prop = 1L) %>% Table$create() + x$Take(x$SortIndices(c("int", "dbl"), c(FALSE, FALSE))) %>% collect() + }, + tbl + ) + expect_identical( + { + x <- tbl %>% slice_sample(prop = 1L) %>% record_batch() + x$Take(x$SortIndices(c("chr", "int", "dbl"), TRUE)) %>% collect() + }, + rbind( + tbl %>% head(-1) %>% arrange(-row_number()), + tbl %>% tail(1) + ) + ) +}) From c5bf536716e12e4b9a299777df571fbae59f3700 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Fri, 19 Mar 2021 00:42:55 -0400 Subject: [PATCH 07/49] Fix unsigned/signed compare error --- r/src/compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 29f0cd7691a..5cf8c7c37d2 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -163,7 +163,7 @@ std::shared_ptr make_compute_options( // cpp11 does not support bool here so use int auto orders = cpp11::as_cpp>(options["orders"]); std::vector keys; - for (int i = 0; i < names.size(); i++) { + for (size_t i = 0; i < names.size(); i++) { keys.push_back( Key(names[i], (orders[i] > 0) ? 
Order::Descending : Order::Ascending)); } From 3fa4125b6fce13af9aef7061fe1288970c6ce3f3 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Fri, 19 Mar 2021 10:59:06 -0400 Subject: [PATCH 08/49] Implement dplyr::arrange() for ArrowTabular --- r/R/dplyr.R | 129 +++++++++++++++++++++++++++++++++++++++++++++-- r/R/expression.R | 14 +++++ 2 files changed, 138 insertions(+), 5 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 2745f69d90c..041d78b9f3a 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -46,7 +46,12 @@ arrow_dplyr_query <- function(.data) { # drop_empty_groups is a logical value indicating whether to drop # groups formed by factor levels that don't appear in the data. It # should be non-null only when the data is grouped. - drop_empty_groups = NULL + drop_empty_groups = NULL, + # arrange_vars will be a list of expressions + arrange_vars = list(), + # arrange_desc will be a logical vector indicating the sort order for each + # expression in arrange_vars (FALSE for ascending, TRUE for descending) + arrange_desc = logical() ), class = "arrow_dplyr_query" ) @@ -80,6 +85,20 @@ print.arrow_dplyr_query <- function(x, ...) { if (length(x$group_by_vars)) { cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "") } + if (length(x$arrange_vars)) { + cat( + "* Sorted by ", + paste( + paste0( + map_chr(x$arrange_vars, .format_array_expression), + " [", ifelse(x$arrange_desc, "desc", "asc"), "]" + ), + collapse = ", " + ), + "\n", + sep = "" + ) + } cat("See $.data for the source Arrow object\n") invisible(x) } @@ -404,6 +423,25 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) 
{ tab <- RecordBatch$create(!!!cols) } } + # Arrange rows + # TODO: support sorting by expressions, not just field names + # TODO: support sorting by names and expressions that are valid *before* the + # selected_columns code above but are *not* valid after it + if (length(x$arrange_vars) > 0) { + x$arrange_vars <- vapply( + x$arrange_vars, + get_field_name, + character(1), + msg = function(x) { + paste( + .format_array_expression(x), + "is not a field name. Only bare field names are supported in arrange()" + ) + }, + USE.NAMES = FALSE + ) + tab <- tab[tab$SortIndices(x$arrange_vars, x$arrange_desc), ] + } } if (as_data_frame) { df <- as.data.frame(tab) @@ -689,17 +727,98 @@ abandon_ship <- function(call, .data, msg = NULL) { eval.parent(call, 2) } -arrange.arrow_dplyr_query <- function(.data, ...) { +arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { + # TODO: handle .by_group argument + if (!isFALSE(.by_group)) { + stop(".by_group argument not supported for Arrow objects", call. = FALSE) + } + + call <- match.call() + exprs <- quos(...) 
+ + if (length(exprs) == 0) { + # Nothing to do + return(.data) + } + .data <- arrow_dplyr_query(.data) + + # Restrict the cases we support for now if (query_on_dataset(.data)) { not_implemented_for_dataset("arrange()") } - # TODO(ARROW-11703) move this to Arrow - call <- match.call() - abandon_ship(call, .data) + is_grouped <- length(dplyr::group_vars(.data)) > 0 + if (is_grouped) { + return(abandon_ship(call, .data, 'arrange() on grouped data not supported in Arrow')) + } + is_dataset <- query_on_dataset(.data) + if (is_dataset) { + return(abandon_ship(call, .data)) + } + + # find and remove any dplyr::desc() and tidy-eval + # the arrange expressions inside an Arrow data_mask + sorts <- vector("list", length(exprs)) + descs <- logical(0) + mask <- arrow_mask(.data) + for (i in seq_along(exprs)) { + x <- find_and_remove_desc(exprs[[i]]) + exprs[[i]] <- x[["quos"]] + sorts[[i]] <- arrow_eval(exprs[[i]], mask) + descs[i] <- x[["desc"]] + } + bad_sorts <- map_lgl(sorts, ~inherits(., "try-error")) + if (any(bad_sorts)) { + bads <- oxford_paste(map_chr(exprs, as_label)[bad_sorts], quote = FALSE) + stop( + "Invalid or unsupported arrange ", + ngettext(length(bads), "expression: ", "expressions: "), + bads, + call. = FALSE + ) + } + .data$arrange_vars <- c(sorts, .data$arrange_vars) + .data$arrange_desc <- c(descs, .data$arrange_desc) + .data } arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query +# Helper to handle desc() in arrange() +# * Takes a quosure as input +# * Returns a list with two elements: +# 1. The quosure with any wrapping parentheses and desc() removed +# 2. A logical value indicating whether desc() was found +# * Performs some other validation +find_and_remove_desc <- function(quosure) { + expr <- quo_get_expr(quosure) + descending <- FALSE + if (length(all.vars(expr)) < 1L) { + stop( + "Expression in arrange() does not contain any field names: ", + deparse(expr), + call. 
= FALSE + ) + } + while (identical(typeof(expr), "language") && is.call(expr)) { + if (identical(expr[[1]], quote(`(`))) { + # remove enclosing parentheses + expr <- expr[[2]] + } else if (identical(expr[[1]], quote(desc))) { + # remove desc() and toggle descending + expr <- expr[[2]] + descending <- !descending + } else { + break + } + } + return( + list( + quos = quo_set_expr(quosure, expr), + desc = descending + ) + ) +} + query_on_dataset <- function(x) inherits(x$.data, "Dataset") not_implemented_for_dataset <- function(method) { diff --git a/r/R/expression.R b/r/R/expression.R index 3f79c92dd46..f1909ed6b39 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -181,6 +181,20 @@ find_array_refs <- function(x) { unlist(out) } +get_field_name <- function(x, msg = NULL) { + if (!identical(x$fun, "array_ref")) { + if (is.null(msg)) { + # Default message + stop(paste(.format_array_expression(x), "is not a bare field name"), call. = FALSE) + } else if (is.function(msg)) { + stop(msg(x), call. = FALSE) + } else { + stop(msg, call. = FALSE) + } + } + x$args$field_name +} + # Take an array_expression and replace array_refs with arrays/chunkedarrays from data bind_array_refs <- function(x, data) { if (inherits(x, "array_expression")) { From 5c8b70e26ba5cd04470b4ec2cba116e205ff38e8 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Fri, 19 Mar 2021 10:59:31 -0400 Subject: [PATCH 09/49] Add arrange() tests --- r/tests/testthat/test-dplyr-arrange.R | 61 +++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 r/tests/testthat/test-dplyr-arrange.R diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R new file mode 100644 index 00000000000..6500b814126 --- /dev/null +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) + +tbl <- example_data_for_sorting + +test_that("arrange", { + expect_dplyr_equal( + input %>% + arrange(int, chr) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ) + expect_dplyr_equal( + input %>% + arrange(int, desc(dbl)) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ) + expect_dplyr_equal( + input %>% + arrange(int) %>% + arrange(desc(dbl)) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ) + expect_error( + tbl %>% + Table$create() %>% + arrange(int + dbl) %>% + collect(), + "Only bare field names are supported in arrange", + fixed = TRUE + ) + expect_error( + tbl %>% + Table$create() %>% + arrange(1), + "does not contain any field names", + fixed = TRUE + ) + # TODO: test the other unsupported cases +}) From ecd5bf76a0f787021afc0bf96dc8ab6ce4629806 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Mon, 22 Mar 2021 09:12:00 -0400 Subject: [PATCH 10/49] Define sort method for ArrowDatum instead of Array and ChunkedArray --- r/NAMESPACE | 3 +-- r/R/array.R | 8 -------- r/R/arrow-datum.R | 8 ++++++++ r/R/chunked-array.R | 3 --- r/tests/testthat/test-compute-sort.R | 7 +++++++ 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index 26cf15e6cb0..e0116d2f5e4 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -76,8 +76,7 
@@ S3method(read_message,InputStream) S3method(read_message,MessageReader) S3method(read_message,default) S3method(row.names,ArrowTabular) -S3method(sort,Array) -S3method(sort,ChunkedArray) +S3method(sort,ArrowDatum) S3method(sort,Scalar) S3method(sum,ArrowDatum) S3method(tail,ArrowDatum) diff --git a/r/R/array.R b/r/R/array.R index bc46d9577af..1d63c5735a7 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -299,11 +299,3 @@ is.Array <- function(x, type = NULL) { } is_it } - -#' @export -sort.Array <- function(x, decreasing = FALSE, na.last = TRUE, ...) { - if (!identical(na.last, TRUE)) { - stop("Arrow only supports sort() with na.last = TRUE", call. = FALSE) - } - x$Take(x$SortIndices(descending = decreasing)) -} diff --git a/r/R/arrow-datum.R b/r/R/arrow-datum.R index f4d9ad346aa..dc50f8f8316 100644 --- a/r/R/arrow-datum.R +++ b/r/R/arrow-datum.R @@ -138,3 +138,11 @@ as.integer.ArrowDatum <- function(x, ...) as.integer(as.vector(x), ...) #' @export as.character.ArrowDatum <- function(x, ...) as.character(as.vector(x), ...) + +#' @export +sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = TRUE, ...) { + if (!identical(na.last, TRUE)) { + stop("Arrow only supports sort() with na.last = TRUE", call. 
= FALSE) + } + x$Take(x$SortIndices(descending = decreasing)) +} diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index e7c03993c24..a7fd8d3143e 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -128,6 +128,3 @@ ChunkedArray$create <- function(..., type = NULL) { #' @rdname ChunkedArray #' @export chunked_array <- ChunkedArray$create - -#' @export -sort.ChunkedArray <- sort.Array diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index b652f2e7fe9..83d246c484a 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -21,6 +21,13 @@ library(dplyr) tbl <- example_data_for_sorting +test_that("Scalar sort", { + expect_identical( + as.vector(sort(Scalar$create(42L))), + 42L + ) +}) + test_that("Array sort", { expect_equal( Array$create(tbl$int)$SortIndices(), From aad31cb77f4ed8bade0265bc109960fb1b31ec11 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Mon, 22 Mar 2021 10:46:01 -0400 Subject: [PATCH 11/49] Improve tests --- r/tests/testthat/helper-data.R | 29 +++++++++++++++++++-------- r/tests/testthat/test-dplyr-arrange.R | 7 +++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index bc0276c031c..22f3b4263fd 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -135,16 +135,29 @@ example_with_logical_factors <- tibble::tibble( ) ) -# the values in each column of this tibble are in ascending order in the C locale. -# there are some ties, but sorting by any two columns will give a deterministic order. +# The values in each column of this tibble are in ascending order in the C locale. +# There are some ties, so tests should use two or more columns to ensure +# deterministic order. libarrow uses the C locale for string collation. testthat +# uses the C locale for string collation inside calls to test_that(). 
To run test +# code outside of test_that() calls, set the collation locale to "C" by running: +# Sys.setlocale("LC_COLLATE", "C") +# When finished, restore the default collation locale by running: +# Sys.setlocale("LC_COLLATE") example_data_for_sorting <- tibble::tibble( int = c(-.Machine$integer.max, -101L, -100L, 0L, 0L, 1L, 100L, 1000L, .Machine$integer.max, NA_integer_), dbl = c(-Inf, -.Machine$double.xmax, -.Machine$double.xmin, 0, .Machine$double.xmin, pi, .Machine$double.xmax, Inf, NaN, NA_real_), - # R string collation varies by locale, while libarrow always uses the C locale for string collation - # (in other words: string values in libarrow are ordered lexicographically as bytestrings) - # to make R sort functions use the C locale, run Sys.setlocale("LC_COLLATE", "C") chr = c("", "", "\u0001", "&", "ABC", "NULL", "a", "abc", "\uFFFF", NA_character_), - # bool is not supported (ARROW-12016) - lgl = c(rep(FALSE, 4L), rep(TRUE, 4L), rep(NA, 2)), - # TODO: add more types + lgl = c(rep(FALSE, 4L), rep(TRUE, 5L), NA), # bool is not supported (ARROW-12016) + dttm = lubridate::ymd_hms(c( + "0000-01-01 00:00:00", + "1919-05-29 13:08:55", + "1955-06-20 04:10:42", + "1973-06-30 11:38:41", + "1987-03-29 12:49:47", + "1991-06-11 19:07:01", + NA_character_, + "2017-08-21 18:26:40", + "2017-08-21 18:26:40", + "9999-12-31 23:59:59" + )) ) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 6500b814126..4c72969958a 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -27,6 +27,13 @@ test_that("arrange", { tbl %>% slice_sample(prop = 1L) ) + expect_dplyr_equal( + input %>% + arrange(dttm, int) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ) expect_dplyr_equal( input %>% arrange(int, desc(dbl)) %>% From 3fe0c51fd15b5cb88f881bbd58ee1520b012f113 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Mon, 22 Mar 2021 11:56:55 -0400 Subject: [PATCH 12/49] Rename get_field_name -> 
get_field_name_of_array_ref --- r/R/dplyr.R | 2 +- r/R/expression.R | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 041d78b9f3a..ad63731faf4 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -430,7 +430,7 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { if (length(x$arrange_vars) > 0) { x$arrange_vars <- vapply( x$arrange_vars, - get_field_name, + get_field_name_of_array_ref, character(1), msg = function(x) { paste( diff --git a/r/R/expression.R b/r/R/expression.R index f1909ed6b39..f081506f0cd 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -181,7 +181,10 @@ find_array_refs <- function(x) { unlist(out) } -get_field_name <- function(x, msg = NULL) { +# This function takes an array reference as input and returns its field name. +# If the input is not an array reference, then it throws an error, optionally +# using the specified message function or message string. +get_field_name_of_array_ref <- function(x, msg = NULL) { if (!identical(x$fun, "array_ref")) { if (is.null(msg)) { # Default message From 1bc180620e70c4314f2ac63ab23326048dbaf079 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 10:15:10 -0400 Subject: [PATCH 13/49] Support na.last in sort.ArrowDatum --- r/R/arrow-datum.R | 13 ++++-- r/tests/testthat/test-compute-sort.R | 62 ++++++++++++++++++++++++---- 2 files changed, 63 insertions(+), 12 deletions(-) diff --git a/r/R/arrow-datum.R b/r/R/arrow-datum.R index dc50f8f8316..12d608be9a2 100644 --- a/r/R/arrow-datum.R +++ b/r/R/arrow-datum.R @@ -140,9 +140,14 @@ as.integer.ArrowDatum <- function(x, ...) as.integer(as.vector(x), ...) as.character.ArrowDatum <- function(x, ...) as.character(as.vector(x), ...) #' @export -sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = TRUE, ...) { - if (!identical(na.last, TRUE)) { - stop("Arrow only supports sort() with na.last = TRUE", call. 
= FALSE) +sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = NA, ...) { + if (is.na(na.last)) { + x <- x$Filter(!is.na(x)) + x$Take(x$SortIndices(descending = decreasing)) + } else if (na.last) { + x$Take(x$SortIndices(descending = decreasing)) + } else { + x <- Table$create(x = x, isnax = as.integer(is.na(x))) + x$x$Take(x$SortIndices(names = c("isnax", "x"), descending = c(TRUE, decreasing))) } - x$Take(x$SortIndices(descending = decreasing)) } diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index 83d246c484a..9c5496e617c 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -39,16 +39,43 @@ test_that("Array sort", { ) expect_equal( as.vector(sort(Array$create(tbl$int))), + sort(tbl$int) + ) + expect_equal( + as.vector(sort(Array$create(tbl$int), na.last = NA)), + sort(tbl$int, na.last = NA) + ) + expect_equal( + as.vector(sort(Array$create(tbl$int), na.last = TRUE)), sort(tbl$int, na.last = TRUE) ) + expect_equal( + as.vector(sort(Array$create(tbl$int), na.last = FALSE)), + sort(tbl$int, na.last = FALSE) + ) expect_equal( as.vector(sort(Array$create(tbl$int), decreasing = TRUE)), + sort(tbl$int, decreasing = TRUE) + ) + expect_equal( + as.vector(sort(Array$create(tbl$int), decreasing = TRUE, na.last = TRUE)), sort(tbl$int, decreasing = TRUE, na.last = TRUE) ) - expect_error( - sort(Array$create(tbl$int), decreasing = TRUE, na.last = NA), - "na.last", - fixed = TRUE + expect_equal( + as.vector(sort(Array$create(tbl$int), decreasing = TRUE, na.last = FALSE)), + sort(tbl$int, decreasing = TRUE, na.last = FALSE) + ) + expect_equal( + as.vector(sort(Array$create(tbl$chr), decreasing = TRUE, na.last = FALSE)), + sort(tbl$chr, decreasing = TRUE, na.last = FALSE) + ) +}) + +test_that("Array sort treats NaN as NA", { + skip("ARROW-12055") + expect_equal( + as.vector(sort(Array$create(tbl$dbl), decreasing = TRUE, na.last = FALSE)), + sort(tbl$dbl, decreasing = TRUE, na.last = 
FALSE) ) }) @@ -63,16 +90,35 @@ test_that("ChunkedArray sort", { ) expect_equal( as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]))), + sort(tbl$int) + ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), na.last = NA)), + sort(tbl$int, na.last = NA) + ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), na.last = TRUE)), sort(tbl$int, na.last = TRUE) ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), na.last = FALSE)), + sort(tbl$int, na.last = FALSE) + ) expect_equal( as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE)), + sort(tbl$int, decreasing = TRUE) + ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE, na.last = TRUE)), sort(tbl$int, decreasing = TRUE, na.last = TRUE) ) - expect_error( - sort(ChunkedArray$create(tbl$int), decreasing = TRUE, na.last = NA), - "na.last", - fixed = TRUE + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE, na.last = FALSE)), + sort(tbl$int, decreasing = TRUE, na.last = FALSE) + ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$chr[1:5], tbl$chr[6:10]), decreasing = TRUE, na.last = FALSE)), + sort(tbl$chr, decreasing = TRUE, na.last = FALSE) ) }) From 154cac6c0f00286b7fb08e59f6bf9d6f79715cf1 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 14:33:04 -0400 Subject: [PATCH 14/49] Fix failing tests --- r/NAMESPACE | 2 ++ r/R/arrow-package.R | 2 +- r/tests/testthat/test-compute-sort.R | 32 +++++++++++++++++----------- r/tests/testthat/test-dplyr.R | 11 ---------- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index e0116d2f5e4..ce0d24445a6 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -280,6 +280,7 @@ importFrom(rlang,"%||%") importFrom(rlang,.data) importFrom(rlang,abort) importFrom(rlang,as_label) +importFrom(rlang,as_name) 
importFrom(rlang,dots_n) importFrom(rlang,enquo) importFrom(rlang,enquos) @@ -293,6 +294,7 @@ importFrom(rlang,is_integerish) importFrom(rlang,list2) importFrom(rlang,new_data_mask) importFrom(rlang,new_environment) +importFrom(rlang,quo_get_expr) importFrom(rlang,quo_is_null) importFrom(rlang,quos) importFrom(rlang,set_names) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index c1d76abfd71..387bb2db5f3 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -18,7 +18,7 @@ #' @importFrom R6 R6Class #' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep #' @importFrom assertthat assert_that is.string -#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character +#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr #' @importFrom tidyselect vars_select #' @useDynLib arrow, .registration = TRUE #' @keywords internal diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index 9c5496e617c..f46e671840f 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -65,18 +65,6 @@ test_that("Array sort", { as.vector(sort(Array$create(tbl$int), decreasing = TRUE, na.last = FALSE)), sort(tbl$int, decreasing = TRUE, na.last = FALSE) ) - expect_equal( - as.vector(sort(Array$create(tbl$chr), decreasing = TRUE, na.last = FALSE)), - sort(tbl$chr, decreasing = TRUE, na.last = FALSE) - ) -}) - -test_that("Array sort treats NaN as NA", { - skip("ARROW-12055") - expect_equal( - as.vector(sort(Array$create(tbl$dbl), decreasing = TRUE, na.last = FALSE)), - sort(tbl$dbl, decreasing = TRUE, na.last = FALSE) - ) }) test_that("ChunkedArray sort", { @@ -116,12 +104,32 @@ 
test_that("ChunkedArray sort", { as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE, na.last = FALSE)), sort(tbl$int, decreasing = TRUE, na.last = FALSE) ) +}) + +test_that("Array/ChunkedArray sort on strings", { + skip_on_os("windows") # multibyte string errors + expect_equal( + as.vector(sort(Array$create(tbl$chr), decreasing = TRUE, na.last = FALSE)), + sort(tbl$chr, decreasing = TRUE, na.last = FALSE) + ) expect_equal( as.vector(sort(ChunkedArray$create(tbl$chr[1:5], tbl$chr[6:10]), decreasing = TRUE, na.last = FALSE)), sort(tbl$chr, decreasing = TRUE, na.last = FALSE) ) }) +test_that("Array/ChunkedArray sort treats NaN as NA", { + skip("ARROW-12055") + expect_equal( + as.vector(sort(Array$create(tbl$dbl), decreasing = TRUE, na.last = FALSE)), + sort(tbl$dbl, decreasing = TRUE, na.last = FALSE) + ) + expect_equal( + as.vector(sort(ChunkedArray$create(tbl$dbl[1:5], tbl$dbl[6:10]), decreasing = TRUE, na.last = FALSE)), + sort(tbl$dbl, decreasing = TRUE, na.last = FALSE) + ) +}) + test_that("Table/RecordBatch sort", { expect_identical( { diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 460d4bbdba5..c1d9e457464 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -134,17 +134,6 @@ test_that("Empty select still includes the group_by columns", { ) }) -test_that("arrange", { - expect_dplyr_equal( - input %>% - group_by(chr) %>% - select(int, chr) %>% - arrange(desc(int)) %>% - collect(), - tbl - ) -}) - test_that("select/rename", { expect_dplyr_equal( input %>% From 4e880236ecbe418b87763ca417295ec87638984a Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 14:48:40 -0400 Subject: [PATCH 15/49] Fix failing tests --- r/NAMESPACE | 2 +- r/R/arrow-package.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index ce0d24445a6..725d441b3e1 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -280,7 +280,6 @@ importFrom(rlang,"%||%") 
importFrom(rlang,.data) importFrom(rlang,abort) importFrom(rlang,as_label) -importFrom(rlang,as_name) importFrom(rlang,dots_n) importFrom(rlang,enquo) importFrom(rlang,enquos) @@ -296,6 +295,7 @@ importFrom(rlang,new_data_mask) importFrom(rlang,new_environment) importFrom(rlang,quo_get_expr) importFrom(rlang,quo_is_null) +importFrom(rlang,quo_set_expr) importFrom(rlang,quos) importFrom(rlang,set_names) importFrom(rlang,syms) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 387bb2db5f3..9ced0e88449 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -18,7 +18,7 @@ #' @importFrom R6 R6Class #' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep #' @importFrom assertthat assert_that is.string -#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr +#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr #' @importFrom tidyselect vars_select #' @useDynLib arrow, .registration = TRUE #' @keywords internal From eeb1db06d7f12796535f358ff85c52ac56041a6f Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 15:51:31 -0400 Subject: [PATCH 16/49] Fix failing tests --- r/tests/testthat/helper-data.R | 2 +- r/tests/testthat/test-compute-sort.R | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index 22f3b4263fd..79d91dac1ce 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -146,7 +146,7 @@ example_with_logical_factors <- tibble::tibble( example_data_for_sorting <- tibble::tibble( int = c(-.Machine$integer.max, -101L, -100L, 0L, 0L, 1L, 100L, 1000L, .Machine$integer.max, 
NA_integer_), dbl = c(-Inf, -.Machine$double.xmax, -.Machine$double.xmin, 0, .Machine$double.xmin, pi, .Machine$double.xmax, Inf, NaN, NA_real_), - chr = c("", "", "\u0001", "&", "ABC", "NULL", "a", "abc", "\uFFFF", NA_character_), + chr = c("", "", "\u0001", "&", "ABC", "NULL", "a", "abc", "ZZZZZ", NA_character_), lgl = c(rep(FALSE, 4L), rep(TRUE, 5L), NA), # bool is not supported (ARROW-12016) dttm = lubridate::ymd_hms(c( "0000-01-01 00:00:00", diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index f46e671840f..6abb11659bc 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -28,7 +28,7 @@ test_that("Scalar sort", { ) }) -test_that("Array sort", { +test_that("Array sort on integers", { expect_equal( Array$create(tbl$int)$SortIndices(), Array$create(0L:9L, type = uint64()) @@ -67,7 +67,7 @@ test_that("Array sort", { ) }) -test_that("ChunkedArray sort", { +test_that("ChunkedArray sort on integers", { expect_equal( ChunkedArray$create(tbl$int[1:5], tbl$int[6:10])$SortIndices(), Array$create(0L:9L, type = uint64()) @@ -107,7 +107,6 @@ test_that("ChunkedArray sort", { }) test_that("Array/ChunkedArray sort on strings", { - skip_on_os("windows") # multibyte string errors expect_equal( as.vector(sort(Array$create(tbl$chr), decreasing = TRUE, na.last = FALSE)), sort(tbl$chr, decreasing = TRUE, na.last = FALSE) @@ -118,7 +117,7 @@ test_that("Array/ChunkedArray sort on strings", { ) }) -test_that("Array/ChunkedArray sort treats NaN as NA", { +test_that("Array/ChunkedArray sort on floats", { skip("ARROW-12055") expect_equal( as.vector(sort(Array$create(tbl$dbl), decreasing = TRUE, na.last = FALSE)), From 82dc4ff4461e6d3abb6e4465ca063967b628813f Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 16:17:54 -0400 Subject: [PATCH 17/49] Support expressions in arrange() --- r/NAMESPACE | 1 + r/R/arrow-package.R | 2 +- r/R/dplyr.R | 44 +++++++++++++++++++++++++------------------- 
r/R/expression.R | 17 ----------------- 4 files changed, 27 insertions(+), 37 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index 725d441b3e1..d7d6d6ce825 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -280,6 +280,7 @@ importFrom(rlang,"%||%") importFrom(rlang,.data) importFrom(rlang,abort) importFrom(rlang,as_label) +importFrom(rlang,as_name) importFrom(rlang,dots_n) importFrom(rlang,enquo) importFrom(rlang,enquos) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 9ced0e88449..dfb6392334d 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -18,7 +18,7 @@ #' @importFrom R6 R6Class #' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep #' @importFrom assertthat assert_that is.string -#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr +#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr as_name #' @importFrom tidyselect vars_select #' @useDynLib arrow, .registration = TRUE #' @keywords internal diff --git a/r/R/dplyr.R b/r/R/dplyr.R index ad63731faf4..a479443c832 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -397,6 +397,7 @@ set_filters <- function(.data, expressions) { collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { x <- ensure_group_vars(x) + x <- ensure_arrange_vars(x) # this sets x$temp_columns # Pull only the selected rows and cols into R if (query_on_dataset(x)) { # See dataset.R for Dataset and Scanner(Builder) classes @@ -410,10 +411,14 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) 
{ } else { filter <- eval_array_expression(x$filtered_rows, x$.data) } - # TODO: shortcut if identical(names(x$.data), find_array_refs(x$selected_columns))? - tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE] + # TODO: shortcut if identical(names(x$.data), find_array_refs(c(x$selected_columns, x$temp_columns)))? + tab <- x$.data[ + filter, + find_array_refs(c(x$selected_columns, x$temp_columns)), + keep_na = FALSE + ] # Now evaluate those expressions on the filtered table - cols <- lapply(x$selected_columns, eval_array_expression, data = tab) + cols <- lapply(c(x$selected_columns, x$temp_columns), eval_array_expression, data = tab) if (length(cols) == 0) { tab <- tab[, integer(0)] } else { @@ -424,23 +429,14 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { } } # Arrange rows - # TODO: support sorting by expressions, not just field names - # TODO: support sorting by names and expressions that are valid *before* the - # selected_columns code above but are *not* valid after it if (length(x$arrange_vars) > 0) { - x$arrange_vars <- vapply( - x$arrange_vars, - get_field_name_of_array_ref, - character(1), - msg = function(x) { - paste( - .format_array_expression(x), - "is not a field name. 
Only bare field names are supported in arrange()" - ) - }, - USE.NAMES = FALSE - ) - tab <- tab[tab$SortIndices(x$arrange_vars, x$arrange_desc), ] + x$arrange_vars <- get_field_names(x$arrange_vars) + tab <- tab[ + tab$SortIndices(names(x$arrange_vars), x$arrange_desc), + names(x$selected_columns), # need this here to remove x$temp_columns + drop = FALSE + ] + x$temp_columns <- NULL } } if (as_data_frame) { @@ -470,6 +466,12 @@ ensure_group_vars <- function(x) { x } +ensure_arrange_vars <- function(x) { + # Make sure all arrange vars are temporarily in the projection + x$temp_columns <- x$arrange_vars[!names(x$arrange_vars) %in% names(x$selected_columns)] + x +} + restore_dplyr_features <- function(df, query) { # An arrow_dplyr_query holds some attributes that Arrow doesn't know about # After calling collect(), make sure these features are carried over @@ -765,6 +767,10 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { x <- find_and_remove_desc(exprs[[i]]) exprs[[i]] <- x[["quos"]] sorts[[i]] <- arrow_eval(exprs[[i]], mask) + names(sorts)[i] <- tryCatch( + expr = as_name(exprs[[i]]), + error = function(x) as_label(exprs[[i]]) + ) descs[i] <- x[["desc"]] } bad_sorts <- map_lgl(sorts, ~inherits(., "try-error")) diff --git a/r/R/expression.R b/r/R/expression.R index f081506f0cd..3f79c92dd46 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -181,23 +181,6 @@ find_array_refs <- function(x) { unlist(out) } -# This function takes an array reference as input and returns its field name. -# If the input is not an array reference, then it throws an error, optionally -# using the specified message function or message string. -get_field_name_of_array_ref <- function(x, msg = NULL) { - if (!identical(x$fun, "array_ref")) { - if (is.null(msg)) { - # Default message - stop(paste(.format_array_expression(x), "is not a bare field name"), call. = FALSE) - } else if (is.function(msg)) { - stop(msg(x), call. = FALSE) - } else { - stop(msg, call. 
= FALSE) - } - } - x$args$field_name -} - # Take an array_expression and replace array_refs with arrays/chunkedarrays from data bind_array_refs <- function(x, data) { if (inherits(x, "array_expression")) { From 27025e01a48d0eb614ce46cd37320d7719019e16 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 16:19:17 -0400 Subject: [PATCH 18/49] Do arrange() locally when unsupported expression --- r/R/dplyr.R | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index a479443c832..e8aa04ebccc 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -767,22 +767,16 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { x <- find_and_remove_desc(exprs[[i]]) exprs[[i]] <- x[["quos"]] sorts[[i]] <- arrow_eval(exprs[[i]], mask) + if (inherits(sorts[[i]], "try-error")) { + msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') + return(abandon_ship(call, .data, msg)) + } names(sorts)[i] <- tryCatch( expr = as_name(exprs[[i]]), error = function(x) as_label(exprs[[i]]) ) descs[i] <- x[["desc"]] } - bad_sorts <- map_lgl(sorts, ~inherits(., "try-error")) - if (any(bad_sorts)) { - bads <- oxford_paste(map_chr(exprs, as_label)[bad_sorts], quote = FALSE) - stop( - "Invalid or unsupported arrange ", - ngettext(length(bads), "expression: ", "expressions: "), - bads, - call. 
= FALSE - ) - } .data$arrange_vars <- c(sorts, .data$arrange_vars) .data$arrange_desc <- c(descs, .data$arrange_desc) .data From 791064ee9a2bda1c5efb0e34a6205b9aba421f22 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 16:20:56 -0400 Subject: [PATCH 19/49] Update tests --- r/tests/testthat/test-dplyr-arrange.R | 47 ++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 4c72969958a..eb4cb0618c4 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -49,13 +49,40 @@ test_that("arrange", { tbl %>% slice_sample(prop = 1L) ) - expect_error( + expect_dplyr_equal( tbl %>% Table$create() %>% arrange(int + dbl) %>% collect(), - "Only bare field names are supported in arrange", - fixed = TRUE + tbl %>% + slice_sample(prop = 1L) + ) + expect_dplyr_equal( + tbl %>% + Table$create() %>% + mutate(zzz = int + dbl) %>% + arrange(zzz) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ) + expect_dplyr_equal( + tbl %>% + Table$create() %>% + mutate(zzz = int + dbl) %>% + arrange(int + dbl) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ) + expect_dplyr_equal( + tbl %>% + Table$create() %>% + mutate(int + dbl) %>% + arrange(int + dbl) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) ) expect_error( tbl %>% @@ -64,5 +91,17 @@ test_that("arrange", { "does not contain any field names", fixed = TRUE ) - # TODO: test the other unsupported cases + expect_warning( + expect_dplyr_equal( + tbl %>% + Table$create() %>% + arrange(abs(int)) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ), + "not supported in Arrow", + fixed = TRUE + ) + # TODO: test the other unsupported cases and error conditions }) From 9593e8941c40132a131c67cf33959ab65527e506 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 16:21:17 -0400 Subject: [PATCH 20/49] Add TODO --- r/R/chunked-array.R | 2 ++ 1 file 
changed, 2 insertions(+) diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index a7fd8d3143e..a7f9c8f790c 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -89,6 +89,8 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = ArrowDatum, assert_that(is.logical(descending)) assert_that(length(descending) == 1L) assert_that(!is.na(descending)) + # TODO: after ARROW-12042 is closed, review whether this and the + # Array$SortIndices definition can be consolidated call_function( "sort_indices", self, From 660414438fee11e209a27679e32591ba9b341ed6 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 17:32:30 -0400 Subject: [PATCH 21/49] Support arrange(.by_group = TRUE) --- r/R/dplyr.R | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index e8aa04ebccc..7f793cc8bfd 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -730,34 +730,23 @@ abandon_ship <- function(call, .data, msg = NULL) { } arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { - # TODO: handle .by_group argument - if (!isFALSE(.by_group)) { - stop(".by_group argument not supported for Arrow objects", call. = FALSE) - } - call <- match.call() exprs <- quos(...) 
- if (length(exprs) == 0) { # Nothing to do return(.data) } - .data <- arrow_dplyr_query(.data) - - # Restrict the cases we support for now if (query_on_dataset(.data)) { not_implemented_for_dataset("arrange()") } - is_grouped <- length(dplyr::group_vars(.data)) > 0 - if (is_grouped) { - return(abandon_ship(call, .data, 'arrange() on grouped data not supported in Arrow')) - } is_dataset <- query_on_dataset(.data) if (is_dataset) { return(abandon_ship(call, .data)) } - + if (.by_group) { + exprs <- c(quos(!!!dplyr::groups(.data)), exprs) + } # find and remove any dplyr::desc() and tidy-eval # the arrange expressions inside an Arrow data_mask sorts <- vector("list", length(exprs)) From 50c6992de998d853e3739e1132afbdbfd2f9260a Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 17:33:02 -0400 Subject: [PATCH 22/49] Fix and add tests --- r/tests/testthat/helper-data.R | 3 +- r/tests/testthat/test-dplyr-arrange.R | 72 +++++++++++++++++++-------- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index 79d91dac1ce..0762226f6b6 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -159,5 +159,6 @@ example_data_for_sorting <- tibble::tibble( "2017-08-21 18:26:40", "2017-08-21 18:26:40", "9999-12-31 23:59:59" - )) + )), + grp = c(rep("A", 5), rep("B", 5)) ) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index eb4cb0618c4..592cb0d631c 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -41,6 +41,13 @@ test_that("arrange", { tbl %>% slice_sample(prop = 1L) ) + expect_dplyr_equal( + input %>% + arrange(int, desc(desc(dbl))) %>% + collect(), + tbl %>% + slice_sample(prop = 1L) + ) expect_dplyr_equal( input %>% arrange(int) %>% @@ -50,58 +57,81 @@ test_that("arrange", { slice_sample(prop = 1L) ) expect_dplyr_equal( - tbl %>% - Table$create() %>% - arrange(int + dbl) %>% + 
input %>% + arrange(int + dbl, chr) %>% collect(), tbl %>% slice_sample(prop = 1L) ) expect_dplyr_equal( + input %>% + mutate(zzz = int + dbl,) %>% + arrange(zzz, chr) %>% + collect(), tbl %>% - Table$create() %>% + slice_sample(prop = 1L) + ) + expect_dplyr_equal( + input %>% mutate(zzz = int + dbl) %>% - arrange(zzz) %>% + arrange(int + dbl, chr) %>% collect(), tbl %>% slice_sample(prop = 1L) ) expect_dplyr_equal( - tbl %>% - Table$create() %>% - mutate(zzz = int + dbl) %>% - arrange(int + dbl) %>% + input %>% + mutate(int + dbl) %>% + arrange(int + dbl, chr) %>% collect(), tbl %>% slice_sample(prop = 1L) ) expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange(int, dbl) %>% + collect(), tbl %>% - Table$create() %>% - mutate(int + dbl) %>% - arrange(int + dbl) %>% + slice_sample(prop = 1L) + ) + expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange(int, dbl, .by_group = TRUE) %>% collect(), tbl %>% slice_sample(prop = 1L) ) - expect_error( + expect_dplyr_equal( + input %>% + group_by(grp, grp2) %>% + arrange(int, dbl, .by_group = TRUE) %>% + collect(), tbl %>% - Table$create() %>% - arrange(1), - "does not contain any field names", - fixed = TRUE + mutate(grp2 = ifelse(is.na(lgl), 1L, as.integer(lgl))) %>% + slice_sample(prop = 1L) ) expect_warning( - expect_dplyr_equal( + expect_equal( tbl %>% + slice_sample(prop = 1L) %>% Table$create() %>% - arrange(abs(int)) %>% + arrange(abs(int), dbl) %>% collect(), tbl %>% - slice_sample(prop = 1L) + slice_sample(prop = 1L) %>% + arrange(abs(int), dbl) %>% + collect() ), "not supported in Arrow", fixed = TRUE ) - # TODO: test the other unsupported cases and error conditions + expect_error( + tbl %>% + Table$create() %>% + arrange(1), + "does not contain any field names", + fixed = TRUE + ) }) From 9e87b265d7ab89845239a68caff8e34635809761 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 17:46:01 -0400 Subject: [PATCH 23/49] Handle case when empty dots and .group_by = TRUE --- r/R/dplyr.R | 6 
+++--- r/tests/testthat/test-dplyr-arrange.R | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 7f793cc8bfd..1737724969b 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -732,6 +732,9 @@ abandon_ship <- function(call, .data, msg = NULL) { arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { call <- match.call() exprs <- quos(...) + if (.by_group) { + exprs <- c(quos(!!!dplyr::groups(.data)), exprs) + } if (length(exprs) == 0) { # Nothing to do return(.data) @@ -744,9 +747,6 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { if (is_dataset) { return(abandon_ship(call, .data)) } - if (.by_group) { - exprs <- c(quos(!!!dplyr::groups(.data)), exprs) - } # find and remove any dplyr::desc() and tidy-eval # the arrange expressions inside an Arrow data_mask sorts <- vector("list", length(exprs)) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 592cb0d631c..8d64a8505e2 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -112,6 +112,14 @@ test_that("arrange", { mutate(grp2 = ifelse(is.na(lgl), 1L, as.integer(lgl))) %>% slice_sample(prop = 1L) ) + expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange(.by_group = TRUE) %>% + pull(grp), + tbl %>% + slice_sample(prop = 1L) + ) expect_warning( expect_equal( tbl %>% From d7dba8e6476b5090987fd6e8df524f1171378788 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 17:53:36 -0400 Subject: [PATCH 24/49] More tests for edge cases --- r/tests/testthat/test-dplyr-arrange.R | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 8d64a8505e2..9e7c0fb762c 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -120,6 +120,19 @@ test_that("arrange", { tbl %>% slice_sample(prop = 1L) ) + 
expect_dplyr_equal( + input %>% + arrange() %>% + collect(), + tbl %>% + group_by(grp) + ) + expect_dplyr_equal( + input %>% + arrange() %>% + collect(), + tbl + ) expect_warning( expect_equal( tbl %>% From 07324dc00c4702fcb2623615d00e3a40efe174dc Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 21:24:59 -0400 Subject: [PATCH 25/49] Fix and add tests --- r/tests/testthat/helper-data.R | 2 +- r/tests/testthat/test-dplyr-arrange.R | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index 0762226f6b6..1d922212e95 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -146,7 +146,7 @@ example_with_logical_factors <- tibble::tibble( example_data_for_sorting <- tibble::tibble( int = c(-.Machine$integer.max, -101L, -100L, 0L, 0L, 1L, 100L, 1000L, .Machine$integer.max, NA_integer_), dbl = c(-Inf, -.Machine$double.xmax, -.Machine$double.xmin, 0, .Machine$double.xmin, pi, .Machine$double.xmax, Inf, NaN, NA_real_), - chr = c("", "", "\u0001", "&", "ABC", "NULL", "a", "abc", "ZZZZZ", NA_character_), + chr = c("", "", "\u0001", "&", "ABC", "NULL", "a", "abc", "zzz", NA_character_), lgl = c(rep(FALSE, 4L), rep(TRUE, 5L), NA), # bool is not supported (ARROW-12016) dttm = lubridate::ymd_hms(c( "0000-01-01 00:00:00", diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 9e7c0fb762c..5f6cd70c63d 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -127,6 +127,13 @@ test_that("arrange", { tbl %>% group_by(grp) ) + expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange() %>% + collect(), + tbl + ) expect_dplyr_equal( input %>% arrange() %>% From d84a65a038eed4c2844d162bb5e03e4718c7fb15 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 22:11:48 -0400 Subject: [PATCH 26/49] Add comments --- r/R/dplyr.R | 15 ++++++++++++--- 1 file changed, 12 
insertions(+), 3 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 1737724969b..d861381b0b9 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -47,7 +47,8 @@ arrow_dplyr_query <- function(.data) { # groups formed by factor levels that don't appear in the data. It # should be non-null only when the data is grouped. drop_empty_groups = NULL, - # arrange_vars will be a list of expressions + # arrange_vars will be a list of expressions named by their associated + # column names arrange_vars = list(), # arrange_desc will be a logical vector indicating the sort order for each # expression in arrange_vars (FALSE for ascending, TRUE for descending) @@ -433,7 +434,7 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { x$arrange_vars <- get_field_names(x$arrange_vars) tab <- tab[ tab$SortIndices(names(x$arrange_vars), x$arrange_desc), - names(x$selected_columns), # need this here to remove x$temp_columns + names(x$selected_columns), # this omits x$temp_columns from the result drop = FALSE ] x$temp_columns <- NULL @@ -467,7 +468,15 @@ ensure_group_vars <- function(x) { } ensure_arrange_vars <- function(x) { - # Make sure all arrange vars are temporarily in the projection + # The arrange() operation is not performed until later, because: + # - It must be performed after mutate(), to enable sorting by new columns. + # - It should be performed after filter() and select(), for efficiency. + # However, we need users to be able to arrange() by columns and expressions + # that are *not* returned in the query result. To enable this, we must + # *temporarily* include these columns and expressions in the projection. We + # use x$temp_columns to store these. Later, after the arrange() operation has + # been performed, these are omitted from the result. This differs from the + # columns in x$group_by_vars which *are* returned in the result. 
x$temp_columns <- x$arrange_vars[!names(x$arrange_vars) %in% names(x$selected_columns)] x } From 6f8e8bf9d41ff135a94a75d5f6483eb6ac6ebb68 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 22:14:40 -0400 Subject: [PATCH 27/49] Remove usage of rlang::as_name --- r/NAMESPACE | 1 - r/R/arrow-package.R | 2 +- r/R/dplyr.R | 5 +---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index d7d6d6ce825..725d441b3e1 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -280,7 +280,6 @@ importFrom(rlang,"%||%") importFrom(rlang,.data) importFrom(rlang,abort) importFrom(rlang,as_label) -importFrom(rlang,as_name) importFrom(rlang,dots_n) importFrom(rlang,enquo) importFrom(rlang,enquos) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index dfb6392334d..9ced0e88449 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -18,7 +18,7 @@ #' @importFrom R6 R6Class #' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep #' @importFrom assertthat assert_that is.string -#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr as_name +#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr #' @importFrom tidyselect vars_select #' @useDynLib arrow, .registration = TRUE #' @keywords internal diff --git a/r/R/dplyr.R b/r/R/dplyr.R index d861381b0b9..763d6373d8b 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -769,10 +769,7 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') return(abandon_ship(call, .data, msg)) } - names(sorts)[i] <- tryCatch( - expr = as_name(exprs[[i]]), - error = 
function(x) as_label(exprs[[i]]) - ) + names(sorts)[i] <- as_label(exprs[[i]]) descs[i] <- x[["desc"]] } .data$arrange_vars <- c(sorts, .data$arrange_vars) From 948fa5af3bcf15b167155877b72e0164489be1b5 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 22:14:53 -0400 Subject: [PATCH 28/49] Add comment --- r/R/dplyr.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 763d6373d8b..f8d2e1b2465 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -742,6 +742,8 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { call <- match.call() exprs <- quos(...) if (.by_group) { + # when the data is grouped and .by_group is TRUE, order the result by + # the grouping columns first exprs <- c(quos(!!!dplyr::groups(.data)), exprs) } if (length(exprs) == 0) { From f94d630fbb3bc61c81e4a6864aabd5adab7a7dc9 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 22:15:27 -0400 Subject: [PATCH 29/49] Remove unnecessary code --- r/R/dplyr.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index f8d2e1b2465..33a34e6bc14 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -431,13 +431,11 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...)
{ } # Arrange rows if (length(x$arrange_vars) > 0) { - x$arrange_vars <- get_field_names(x$arrange_vars) tab <- tab[ tab$SortIndices(names(x$arrange_vars), x$arrange_desc), names(x$selected_columns), # this omits x$temp_columns from the result drop = FALSE ] - x$temp_columns <- NULL } } if (as_data_frame) { From cb675058c6e6f466ec2617dd33e1d6a7ae3abf57 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 22:16:00 -0400 Subject: [PATCH 30/49] Add test of arrange(!!!syms()) --- r/tests/testthat/test-dplyr-arrange.R | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 5f6cd70c63d..86387a74980 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -140,6 +140,13 @@ test_that("arrange", { collect(), tbl ) + test_sort_cols <- c("int", "dbl") + expect_dplyr_equal( + input %>% + arrange(!!!syms(test_sort_cols)) %>% + collect(), + tbl + ) expect_warning( expect_equal( tbl %>% From b21fcf985e9fdb71e6e62b33ba4d8d5046627ea7 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 22:21:52 -0400 Subject: [PATCH 31/49] Remove dup dataset check --- r/R/dplyr.R | 4 ---- 1 file changed, 4 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 33a34e6bc14..692a31803f7 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -752,10 +752,6 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { if (query_on_dataset(.data)) { not_implemented_for_dataset("arrange()") } - is_dataset <- query_on_dataset(.data) - if (is_dataset) { - return(abandon_ship(call, .data)) - } # find and remove any dplyr::desc() and tidy-eval # the arrange expressions inside an Arrow data_mask sorts <- vector("list", length(exprs)) From d2ab74b007016ac51ba43010671f80710bf23a64 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 23 Mar 2021 22:42:56 -0400 Subject: [PATCH 32/49] Add comments and rename vars for clarity in sort.ArrowDatum --- r/R/arrow-datum.R | 
11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/r/R/arrow-datum.R b/r/R/arrow-datum.R index 12d608be9a2..99940e74cbd 100644 --- a/r/R/arrow-datum.R +++ b/r/R/arrow-datum.R @@ -141,13 +141,20 @@ as.character.ArrowDatum <- function(x, ...) as.character(as.vector(x), ...) #' @export sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = NA, ...) { + # Arrow always sorts nulls at the end of the array. This corresponds to + # sort(na.last = TRUE). For the other two cases (na.last = NA and + # na.last = FALSE) we need to use workarounds. + # TODO: Implement this more cleanly after ARROW-12063 if (is.na(na.last)) { + # Filter out NAs before sorting x <- x$Filter(!is.na(x)) x$Take(x$SortIndices(descending = decreasing)) } else if (na.last) { x$Take(x$SortIndices(descending = decreasing)) } else { - x <- Table$create(x = x, isnax = as.integer(is.na(x))) - x$x$Take(x$SortIndices(names = c("isnax", "x"), descending = c(TRUE, decreasing))) + # Create a new array that encodes missing values as 1 and non-missing values + # as 0. 
Sort descending by that array first to get the NAs at the beginning + tbl <- Table$create(x = x, `is_na` = as.integer(is.na(x))) + tbl$x$Take(tbl$SortIndices(names = c("is_na", "x"), descending = c(TRUE, decreasing))) } } From 33bbc0d2fb907bedbae64f784b13a3feb258c19e Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 00:06:32 -0400 Subject: [PATCH 33/49] Implement arrange() for Datasets --- r/R/dataset-scan.R | 5 +++-- r/R/dplyr.R | 26 ++++++++++++++------------ 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 7e148863226..f7ede663c7f 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -76,7 +76,7 @@ Scanner$create <- function(dataset, } return(Scanner$create( dataset$.data, - dataset$selected_columns, + c(dataset$selected_columns, dataset$temp_columns), dataset$filtered_rows, use_threads, batch_size, @@ -148,7 +148,8 @@ map_batches <- function(X, FUN, ..., .data.frame = TRUE) { lapply(scan_task$Execute(), function(batch) { # message("Processing Batch") # This inner lapply cannot be parallelized - # TODO: wrap batch in arrow_dplyr_query with X$selected_columns and X$group_by_vars + # TODO: wrap batch in arrow_dplyr_query with X$selected_columns, + # X$temp_columns, and X$group_by_vars # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE FUN(batch, ...) }) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 692a31803f7..1b507822f08 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -87,11 +87,16 @@ print.arrow_dplyr_query <- function(x, ...) 
{ cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "") } if (length(x$arrange_vars)) { + if (query_on_dataset(x)) { + arrange_strings <- map_chr(x$arrange_vars, function(x) x$ToString()) + } else { + arrange_strings <- map_chr(x$arrange_vars, .format_array_expression) + } cat( "* Sorted by ", paste( paste0( - map_chr(x$arrange_vars, .format_array_expression), + arrange_strings, " [", ifelse(x$arrange_desc, "desc", "asc"), "]" ), collapse = ", " @@ -429,14 +434,14 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { tab <- RecordBatch$create(!!!cols) } } - # Arrange rows - if (length(x$arrange_vars) > 0) { - tab <- tab[ - tab$SortIndices(names(x$arrange_vars), x$arrange_desc), - names(x$selected_columns), # this omits x$temp_columns from the result - drop = FALSE - ] - } + } + # Arrange rows + if (length(x$arrange_vars) > 0) { + tab <- tab[ + tab$SortIndices(names(x$arrange_vars), x$arrange_desc), + names(x$selected_columns), # this omits x$temp_columns from the result + drop = FALSE + ] } if (as_data_frame) { df <- as.data.frame(tab) @@ -749,9 +754,6 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { return(.data) } .data <- arrow_dplyr_query(.data) - if (query_on_dataset(.data)) { - not_implemented_for_dataset("arrange()") - } # find and remove any dplyr::desc() and tidy-eval # the arrange expressions inside an Arrow data_mask sorts <- vector("list", length(exprs)) From 0646affdfcdb0b839a61191375c61ec5f1d86e66 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 00:07:11 -0400 Subject: [PATCH 34/49] Add tests for arrange() on Dataset --- r/tests/testthat/test-dataset.R | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index e6db0bcea17..1dba1b2da90 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -1030,6 +1030,42 @@ test_that("count()", { ) }) 
+test_that("arrange()", { + ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) + arranged <- ds %>% + select(chr, dbl, int) %>% + filter(dbl * 2 > 14 & dbl - 50 < 3L) %>% + mutate(twice = int * 2) %>% + arrange(chr, desc(twice), dbl + int) + expect_output( + print(arranged), + "FileSystemDataset (query) +chr: string +dbl: double +int: int32 +twice: expr + +* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) +* Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc] +See $.data for the source Arrow object", + fixed = TRUE + ) + expect_equivalent( + arranged %>% + collect(), + rbind( + df1[8, c("chr", "dbl", "int")], + df2[2, c("chr", "dbl", "int")], + df1[9, c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")], + df1[10, c("chr", "dbl", "int")] + ) %>% + mutate( + twice = int * 2 + ) + ) +}) + test_that("head/tail", { skip_if_not_available("parquet") ds <- open_dataset(dataset_dir) From 95b310f122a444828baa670ff87afaeb05792a65 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 00:07:28 -0400 Subject: [PATCH 35/49] Improve skip message --- r/tests/testthat/test-compute-sort.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index 6abb11659bc..17fc0458ec1 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -118,7 +118,7 @@ test_that("Array/ChunkedArray sort on strings", { }) test_that("Array/ChunkedArray sort on floats", { - skip("ARROW-12055") + skip("is.na() evaluates to FALSE on Arrow NaN values (ARROW-12055)") expect_equal( as.vector(sort(Array$create(tbl$dbl), decreasing = TRUE, na.last = FALSE)), sort(tbl$dbl, decreasing = TRUE, na.last = FALSE) From deba9881a964afd52593f9efede8fb38affacc38 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 09:07:04 -0400 Subject: [PATCH 36/49] Fix test failures --- 
r/tests/testthat/helper-data.R | 2 +- r/tests/testthat/test-dataset.R | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index 1d922212e95..eaa2556d2a0 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -146,7 +146,7 @@ example_with_logical_factors <- tibble::tibble( example_data_for_sorting <- tibble::tibble( int = c(-.Machine$integer.max, -101L, -100L, 0L, 0L, 1L, 100L, 1000L, .Machine$integer.max, NA_integer_), dbl = c(-Inf, -.Machine$double.xmax, -.Machine$double.xmin, 0, .Machine$double.xmin, pi, .Machine$double.xmax, Inf, NaN, NA_real_), - chr = c("", "", "\u0001", "&", "ABC", "NULL", "a", "abc", "zzz", NA_character_), + chr = c("", "", "\"", "&", "ABC", "NULL", "a", "abc", "zzz", NA_character_), lgl = c(rep(FALSE, 4L), rep(TRUE, 5L), NA), # bool is not supported (ARROW-12016) dttm = lubridate::ymd_hms(c( "0000-01-01 00:00:00", diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 1dba1b2da90..73057221a24 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -1153,7 +1153,6 @@ test_that("dplyr method not implemented messages", { expect_not_implemented <- function(x) { expect_error(x, "is not currently implemented for Arrow Datasets") } - expect_not_implemented(ds %>% arrange(int)) expect_not_implemented(ds %>% filter(int == 1) %>% summarize(n())) }) From 38a8aec12b3c414c46878e5ce101a23482737d22 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 10:12:39 -0400 Subject: [PATCH 37/49] Add expect_vector_equal() helper --- r/tests/testthat/helper-expectation.R | 41 ++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index 76edea61f57..39cc9e0597a 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -121,4 +121,43 @@ 
expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its star msg, ... ) -} \ No newline at end of file +} + +expect_vector_equal <- function(expr, # A vectorized R expression containing `input` as its input + vec, # A vector as reference, will make Array/ChunkedArray with + skip_array = NULL, # Msg, if should skip Array test + skip_chunked_array = NULL, # Msg, if should skip ChunkedArray test + ...) { + expr <- rlang::enquo(expr) + expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = vec))) + + skip_msg <- NULL + + if (is.null(skip_array)) { + via_array <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Array$create(vec))) + ) + expect_vector(via_array, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_array) + } + + if (is.null(skip_chunked_array)) { + # split input vector into two to exercise ChunkedArray with >1 chunk + vec_split <- length(vec) %/% 2 + vec1 <- vec[seq(from = min(1, length(vec) - 1), to = min(length(vec) - 1, vec_split), by = 1)] + vec2 <- vec[seq(from = min(length(vec), vec_split + 1), to = length(vec), by = 1)] + via_chunked <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = ChunkedArray$create(vec1, vec2))) + ) + expect_vector(via_chunked, expected, ...) 
+ } else { + skip_msg <- c(skip_msg, skip_chunked_array) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) + } +} From f98b92e751c089802d79d9a7207acb77da5f85ab Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 10:13:27 -0400 Subject: [PATCH 38/49] Clean up tests --- r/tests/testthat/test-compute-sort.R | 116 +++++++++++---------------- 1 file changed, 49 insertions(+), 67 deletions(-) diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index 17fc0458ec1..a1d4f4ea682 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -21,14 +21,18 @@ library(dplyr) tbl <- example_data_for_sorting -test_that("Scalar sort", { +test_that("sort(Scalar) is identity function", { expect_identical( as.vector(sort(Scalar$create(42L))), 42L ) + expect_identical( + as.vector(sort(Scalar$create("foo"))), + "foo" + ) }) -test_that("Array sort on integers", { +test_that("Array$SortIndices()", { expect_equal( Array$create(tbl$int)$SortIndices(), Array$create(0L:9L, type = uint64()) @@ -37,37 +41,9 @@ test_that("Array sort on integers", { Array$create(rev(tbl$int))$SortIndices(descending = TRUE), Array$create(c(1L:9L, 0L), type = uint64()) ) - expect_equal( - as.vector(sort(Array$create(tbl$int))), - sort(tbl$int) - ) - expect_equal( - as.vector(sort(Array$create(tbl$int), na.last = NA)), - sort(tbl$int, na.last = NA) - ) - expect_equal( - as.vector(sort(Array$create(tbl$int), na.last = TRUE)), - sort(tbl$int, na.last = TRUE) - ) - expect_equal( - as.vector(sort(Array$create(tbl$int), na.last = FALSE)), - sort(tbl$int, na.last = FALSE) - ) - expect_equal( - as.vector(sort(Array$create(tbl$int), decreasing = TRUE)), - sort(tbl$int, decreasing = TRUE) - ) - expect_equal( - as.vector(sort(Array$create(tbl$int), decreasing = TRUE, na.last = TRUE)), - sort(tbl$int, decreasing = TRUE, na.last = TRUE) - ) - expect_equal( - as.vector(sort(Array$create(tbl$int), decreasing = TRUE, na.last 
= FALSE)), - sort(tbl$int, decreasing = TRUE, na.last = FALSE) - ) }) -test_that("ChunkedArray sort on integers", { +test_that("ChunkedArray$SortIndices()", { expect_equal( ChunkedArray$create(tbl$int[1:5], tbl$int[6:10])$SortIndices(), Array$create(0L:9L, type = uint64()) @@ -76,60 +52,63 @@ test_that("ChunkedArray sort on integers", { ChunkedArray$create(rev(tbl$int)[1:5], rev(tbl$int)[6:10])$SortIndices(descending = TRUE), Array$create(c(1L:9L, 0L), type = uint64()) ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]))), - sort(tbl$int) +}) + +test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on integers", { + expect_vector_equal( + sort(input), + tbl$int ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), na.last = NA)), - sort(tbl$int, na.last = NA) + expect_vector_equal( + sort(input, na.last = NA), + tbl$int ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), na.last = TRUE)), - sort(tbl$int, na.last = TRUE) + expect_vector_equal( + sort(input, na.last = TRUE), + tbl$int ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), na.last = FALSE)), - sort(tbl$int, na.last = FALSE) + expect_vector_equal( + sort(input, na.last = FALSE), + tbl$int ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE)), - sort(tbl$int, decreasing = TRUE) + expect_vector_equal( + sort(input, decreasing = TRUE), + tbl$int, ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE, na.last = TRUE)), - sort(tbl$int, decreasing = TRUE, na.last = TRUE) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = TRUE), + tbl$int, ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$int[1:5], tbl$int[6:10]), decreasing = TRUE, na.last = FALSE)), - sort(tbl$int, decreasing = TRUE, na.last = FALSE) + expect_vector_equal( + 
sort(input, decreasing = TRUE, na.last = FALSE), + tbl$int, ) }) -test_that("Array/ChunkedArray sort on strings", { - expect_equal( - as.vector(sort(Array$create(tbl$chr), decreasing = TRUE, na.last = FALSE)), - sort(tbl$chr, decreasing = TRUE, na.last = FALSE) +test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on strings", { + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$chr ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$chr[1:5], tbl$chr[6:10]), decreasing = TRUE, na.last = FALSE)), - sort(tbl$chr, decreasing = TRUE, na.last = FALSE) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$chr ) }) -test_that("Array/ChunkedArray sort on floats", { +test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on floats", { skip("is.na() evaluates to FALSE on Arrow NaN values (ARROW-12055)") - expect_equal( - as.vector(sort(Array$create(tbl$dbl), decreasing = TRUE, na.last = FALSE)), - sort(tbl$dbl, decreasing = TRUE, na.last = FALSE) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$dbl ) - expect_equal( - as.vector(sort(ChunkedArray$create(tbl$dbl[1:5], tbl$dbl[6:10]), decreasing = TRUE, na.last = FALSE)), - sort(tbl$dbl, decreasing = TRUE, na.last = FALSE) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$dbl, ) }) -test_that("Table/RecordBatch sort", { +test_that("Table$SortIndices()", { expect_identical( { x <- tbl %>% slice_sample(prop = 1L) %>% Table$create() @@ -144,6 +123,9 @@ test_that("Table/RecordBatch sort", { }, tbl ) +}) + +test_that("RecordBatch$SortIndices()", { expect_identical( { x <- tbl %>% slice_sample(prop = 1L) %>% record_batch() From 470904f7e5bd554dcac0c780f6d49b31d50135cf Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 10:27:12 -0400 Subject: [PATCH 39/49] Fix failing tests --- r/tests/testthat/test-compute-sort.R | 4 ++++ 1 file changed, 4 
insertions(+) diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index a1d4f4ea682..2e050ee47c0 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -86,6 +86,10 @@ test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results }) test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on strings", { + skip_if_not( + identical(Sys.getlocale("LC_COLLATE"), "C"), + "Unexpected LC_COLLATE" + ) expect_vector_equal( sort(input, decreasing = TRUE, na.last = FALSE), tbl$chr From 6b4a58683aecb70f84183f76a8ed05811e889991 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Wed, 24 Mar 2021 11:08:06 -0400 Subject: [PATCH 40/49] Trigger CI From e44b7a1a56473c126892d5c0d6afe2b6980e3e33 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 10:06:46 -0400 Subject: [PATCH 41/49] Implement Scalar$Equals(), $ApproxEquals() and simplify sort(Scalar) test --- r/R/arrowExports.R | 8 ++++++++ r/R/scalar.R | 8 +++++++- r/src/arrowExports.cpp | 20 ++++++++++++++++++++ r/src/scalar.cpp | 12 ++++++++++++ r/tests/testthat/test-compute-sort.R | 14 ++++++-------- r/tests/testthat/test-scalar.R | 24 +++++++++++++++++++++++- 6 files changed, 76 insertions(+), 10 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 86ee1303d1c..2c7bf5c19f6 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1476,6 +1476,14 @@ Scalar__type <- function(s){ .Call(`_arrow_Scalar__type`, s) } +Scalar__Equals <- function(lhs, rhs){ + .Call(`_arrow_Scalar__Equals`, lhs, rhs) +} + +Scalar__ApproxEquals <- function(lhs, rhs){ + .Call(`_arrow_Scalar__ApproxEquals`, lhs, rhs) +} + schema_ <- function(fields){ .Call(`_arrow_schema_`, fields) } diff --git a/r/R/scalar.R b/r/R/scalar.R index d2dd5db5d8e..cbda5964a2c 100644 --- a/r/R/scalar.R +++ b/r/R/scalar.R @@ -33,7 +33,13 @@ Scalar <- R6Class("Scalar", public = list( ToString = function() Scalar__ToString(self), 
as_vector = function() Scalar__as_vector(self), - as_array = function() MakeArrayFromScalar(self) + as_array = function() MakeArrayFromScalar(self), + Equals = function(other, ...) { + inherits(other, "Scalar") && Scalar__Equals(self, other) + }, + ApproxEquals = function(other, ...) { + inherits(other, "Scalar") && Scalar__ApproxEquals(self, other) + } ), active = list( is_valid = function() Scalar__is_valid(self), diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 35dab553fbd..b06a2696e50 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -3793,6 +3793,24 @@ BEGIN_CPP11 return cpp11::as_sexp(Scalar__type(s)); END_CPP11 } +// scalar.cpp +bool Scalar__Equals(const std::shared_ptr& lhs, const std::shared_ptr& rhs); +extern "C" SEXP _arrow_Scalar__Equals(SEXP lhs_sexp, SEXP rhs_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type lhs(lhs_sexp); + arrow::r::Input&>::type rhs(rhs_sexp); + return cpp11::as_sexp(Scalar__Equals(lhs, rhs)); +END_CPP11 +} +// scalar.cpp +bool Scalar__ApproxEquals(const std::shared_ptr& lhs, const std::shared_ptr& rhs); +extern "C" SEXP _arrow_Scalar__ApproxEquals(SEXP lhs_sexp, SEXP rhs_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type lhs(lhs_sexp); + arrow::r::Input&>::type rhs(rhs_sexp); + return cpp11::as_sexp(Scalar__ApproxEquals(lhs, rhs)); +END_CPP11 +} // schema.cpp std::shared_ptr schema_(const std::vector>& fields); extern "C" SEXP _arrow_schema_(SEXP fields_sexp){ @@ -4544,6 +4562,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_MakeArrayFromScalar", (DL_FUNC) &_arrow_MakeArrayFromScalar, 1}, { "_arrow_Scalar__is_valid", (DL_FUNC) &_arrow_Scalar__is_valid, 1}, { "_arrow_Scalar__type", (DL_FUNC) &_arrow_Scalar__type, 1}, + { "_arrow_Scalar__Equals", (DL_FUNC) &_arrow_Scalar__Equals, 2}, + { "_arrow_Scalar__ApproxEquals", (DL_FUNC) &_arrow_Scalar__ApproxEquals, 2}, { "_arrow_schema_", (DL_FUNC) &_arrow_schema_, 1}, { "_arrow_Schema__ToString", (DL_FUNC) &_arrow_Schema__ToString, 1}, { 
"_arrow_Schema__num_fields", (DL_FUNC) &_arrow_Schema__num_fields, 1}, diff --git a/r/src/scalar.cpp b/r/src/scalar.cpp index c7e2a716bde..e7c35064514 100644 --- a/r/src/scalar.cpp +++ b/r/src/scalar.cpp @@ -82,4 +82,16 @@ std::shared_ptr Scalar__type(const std::shared_ptrtype; } +// [[arrow::export]] +bool Scalar__Equals(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return lhs->Equals(rhs); +} + +// [[arrow::export]] +bool Scalar__ApproxEquals(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return lhs->ApproxEquals(*rhs); +} + #endif diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index 2e050ee47c0..03c06136187 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -22,14 +22,12 @@ library(dplyr) tbl <- example_data_for_sorting test_that("sort(Scalar) is identity function", { - expect_identical( - as.vector(sort(Scalar$create(42L))), - 42L - ) - expect_identical( - as.vector(sort(Scalar$create("foo"))), - "foo" - ) + int <- Scalar$create(42L) + expect_equal(sort(int), int) + dbl <- Scalar$create(3.14) + expect_equal(sort(dbl), dbl) + chr <- Scalar$create("foo") + expect_equal(sort(chr), chr) }) test_that("Array$SortIndices()", { diff --git a/r/tests/testthat/test-scalar.R b/r/tests/testthat/test-scalar.R index ce0f53d2be3..e9ef893bbd9 100644 --- a/r/tests/testthat/test-scalar.R +++ b/r/tests/testthat/test-scalar.R @@ -53,4 +53,26 @@ test_that("Scalar to Array", { a <- Scalar$create(42) expect_equal(a$as_array(), Array$create(42)) expect_equal(Array$create(a), Array$create(42)) -}) \ No newline at end of file +}) + +test_that("Scalar$Equals", { + a <- Scalar$create(42) + aa <- Array$create(42) + b <- Scalar$create(42) + d <- Scalar$create(43) + expect_equal(a, b) + expect_true(a$Equals(b)) + expect_false(a$Equals(d)) + expect_false(a$Equals(aa)) +}) + +test_that("Scalar$ApproxEquals", { + a <- Scalar$create(1.0000000000001) + aa <- 
Array$create(1.0000000000001) + b <- Scalar$create(1.0) + d <- 2.400000000000001 + expect_false(a$Equals(b)) + expect_true(a$ApproxEquals(b)) + expect_false(a$ApproxEquals(d)) + expect_false(a$ApproxEquals(aa)) +}) From 2d0777eba0628b8d4286d3125ce4ef5a2504141b Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 10:22:45 -0400 Subject: [PATCH 42/49] Add some non-skipped float sort tests --- r/tests/testthat/test-compute-sort.R | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index 03c06136187..c958e44c28c 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -33,22 +33,22 @@ test_that("sort(Scalar) is identity function", { test_that("Array$SortIndices()", { expect_equal( Array$create(tbl$int)$SortIndices(), - Array$create(0L:9L, type = uint64()) + Array$create(0:9, type = uint64()) ) expect_equal( Array$create(rev(tbl$int))$SortIndices(descending = TRUE), - Array$create(c(1L:9L, 0L), type = uint64()) + Array$create(c(1:9, 0L), type = uint64()) ) }) test_that("ChunkedArray$SortIndices()", { expect_equal( ChunkedArray$create(tbl$int[1:5], tbl$int[6:10])$SortIndices(), - Array$create(0L:9L, type = uint64()) + Array$create(0:9, type = uint64()) ) expect_equal( ChunkedArray$create(rev(tbl$int)[1:5], rev(tbl$int)[6:10])$SortIndices(descending = TRUE), - Array$create(c(1L:9L, 0L), type = uint64()) + Array$create(c(1:9, 0L), type = uint64()) ) }) @@ -99,15 +99,31 @@ test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results }) test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on floats", { + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = TRUE), + tbl$dbl + ) + expect_vector_equal( + sort(input, decreasing = FALSE, na.last = TRUE), + tbl$dbl + ) skip("is.na() evaluates to FALSE on Arrow NaN values (ARROW-12055)") expect_vector_equal( - 
sort(input, decreasing = TRUE, na.last = FALSE), + sort(input, decreasing = TRUE, na.last = NA), tbl$dbl ) expect_vector_equal( sort(input, decreasing = TRUE, na.last = FALSE), tbl$dbl, ) + expect_vector_equal( + sort(input, decreasing = FALSE, na.last = NA), + tbl$dbl + ) + expect_vector_equal( + sort(input, decreasing = FALSE, na.last = FALSE), + tbl$dbl, + ) }) test_that("Table$SortIndices()", { From 20c9233efbf1b65dbeb785135c44f824a6b3eb49 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 10:58:41 -0400 Subject: [PATCH 43/49] Lint --- r/src/scalar.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/src/scalar.cpp b/r/src/scalar.cpp index e7c35064514..057e587e7eb 100644 --- a/r/src/scalar.cpp +++ b/r/src/scalar.cpp @@ -84,13 +84,13 @@ std::shared_ptr Scalar__type(const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { + const std::shared_ptr& rhs) { return lhs->Equals(rhs); } // [[arrow::export]] bool Scalar__ApproxEquals(const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { + const std::shared_ptr& rhs) { return lhs->ApproxEquals(*rhs); } From d9a5e09fb2b4507745e9c1033d5b9ff093869dfe Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 10:59:19 -0400 Subject: [PATCH 44/49] Improve tests --- r/tests/testthat/test-compute-sort.R | 50 +++++++++++++++------------ r/tests/testthat/test-dplyr-arrange.R | 50 +++++++++++---------------- 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index c958e44c28c..f3f4979a3fb 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -19,7 +19,11 @@ context("compute: sorting") library(dplyr) -tbl <- example_data_for_sorting +# randomize order of rows in test data +tbl <- slice_sample(example_data_for_sorting, prop = 1L) + +# use the C locale for string collation in R (ARROW-12046) +Sys.setlocale("LC_COLLATE", "C") test_that("sort(Scalar) is 
identity function", { int <- Scalar$create(42L) @@ -31,24 +35,30 @@ test_that("sort(Scalar) is identity function", { }) test_that("Array$SortIndices()", { + int <- tbl$int + int <- int[!duplicated(int)] # needed because ties in int expect_equal( - Array$create(tbl$int)$SortIndices(), - Array$create(0:9, type = uint64()) + Array$create(int)$SortIndices(), + Array$create(order(int) - 1L, type = uint64()) ) + int <- na.omit(int) # needed because ARROW-12063 expect_equal( - Array$create(rev(tbl$int))$SortIndices(descending = TRUE), - Array$create(c(1:9, 0L), type = uint64()) + Array$create(int)$SortIndices(descending = TRUE), + Array$create(rev(order(int)) - 1, type = uint64()) ) }) test_that("ChunkedArray$SortIndices()", { + int <- tbl$int + int <- int[!duplicated(int)] # needed because ties in int expect_equal( - ChunkedArray$create(tbl$int[1:5], tbl$int[6:10])$SortIndices(), - Array$create(0:9, type = uint64()) + ChunkedArray$create(int[1:4], int[5:length(int)])$SortIndices(), + Array$create(order(int) - 1L, type = uint64()) ) + int <- na.omit(int) # needed because ARROW-12063 expect_equal( - ChunkedArray$create(rev(tbl$int)[1:5], rev(tbl$int)[6:10])$SortIndices(descending = TRUE), - Array$create(c(1:9, 0L), type = uint64()) + ChunkedArray$create(int[1:4], int[5:length(int)])$SortIndices(descending = TRUE), + Array$create(rev(order(int)) - 1, type = uint64()) ) }) @@ -84,10 +94,6 @@ test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results }) test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on strings", { - skip_if_not( - identical(Sys.getlocale("LC_COLLATE"), "C"), - "Unexpected LC_COLLATE" - ) expect_vector_equal( sort(input, decreasing = TRUE, na.last = FALSE), tbl$chr @@ -129,29 +135,29 @@ test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results test_that("Table$SortIndices()", { expect_identical( { - x <- tbl %>% slice_sample(prop = 1L) %>% Table$create() + x <- tbl %>% 
Table$create() x$Take(x$SortIndices("chr")) %>% pull(chr) }, - tbl$chr + sort(tbl$chr, na.last = TRUE) ) expect_identical( { - x <- tbl %>% slice_sample(prop = 1L) %>% Table$create() + x <- tbl %>% Table$create() x$Take(x$SortIndices(c("int", "dbl"), c(FALSE, FALSE))) %>% collect() }, - tbl + tbl %>% arrange(int, dbl) ) }) test_that("RecordBatch$SortIndices()", { expect_identical( { - x <- tbl %>% slice_sample(prop = 1L) %>% record_batch() + x <- tbl %>% record_batch() x$Take(x$SortIndices(c("chr", "int", "dbl"), TRUE)) %>% collect() }, - rbind( - tbl %>% head(-1) %>% arrange(-row_number()), - tbl %>% tail(1) - ) + tbl %>% arrange(desc(chr), desc(int), desc(dbl)) ) }) + +# restore previous collation locale setting +Sys.setlocale("LC_COLLATE") diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 86387a74980..b83dc739bbe 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -17,91 +17,84 @@ library(dplyr) -tbl <- example_data_for_sorting +# randomize order of rows in test data +tbl <- slice_sample(example_data_for_sorting, prop = 1L) + +# use the C locale for string collation in R (ARROW-12046) +Sys.setlocale("LC_COLLATE", "C") test_that("arrange", { expect_dplyr_equal( input %>% arrange(int, chr) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% arrange(dttm, int) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% arrange(int, desc(dbl)) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% arrange(int, desc(desc(dbl))) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% arrange(int) %>% arrange(desc(dbl)) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% arrange(int + dbl, chr) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% mutate(zzz = int + dbl,) %>% 
arrange(zzz, chr) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% mutate(zzz = int + dbl) %>% arrange(int + dbl, chr) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% mutate(int + dbl) %>% arrange(int + dbl, chr) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% group_by(grp) %>% arrange(int, dbl) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% group_by(grp) %>% arrange(int, dbl, .by_group = TRUE) %>% collect(), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% @@ -109,16 +102,14 @@ test_that("arrange", { arrange(int, dbl, .by_group = TRUE) %>% collect(), tbl %>% - mutate(grp2 = ifelse(is.na(lgl), 1L, as.integer(lgl))) %>% - slice_sample(prop = 1L) + mutate(grp2 = ifelse(is.na(lgl), 1L, as.integer(lgl))) ) expect_dplyr_equal( input %>% group_by(grp) %>% arrange(.by_group = TRUE) %>% pull(grp), - tbl %>% - slice_sample(prop = 1L) + tbl ) expect_dplyr_equal( input %>% @@ -150,12 +141,10 @@ test_that("arrange", { expect_warning( expect_equal( tbl %>% - slice_sample(prop = 1L) %>% Table$create() %>% arrange(abs(int), dbl) %>% collect(), tbl %>% - slice_sample(prop = 1L) %>% arrange(abs(int), dbl) %>% collect() ), @@ -170,3 +159,6 @@ test_that("arrange", { fixed = TRUE ) }) + +# restore previous collation locale setting +Sys.setlocale("LC_COLLATE") From da1b7464ae0092e95f77829e0dabc48529074b93 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 11:46:17 -0400 Subject: [PATCH 45/49] Improve comments --- r/R/dplyr.R | 3 +++ r/tests/testthat/test-compute-sort.R | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 1b507822f08..e26d0bf2fb1 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -792,6 +792,9 @@ find_and_remove_desc <- function(quosure) { call. 
= FALSE ) } + # Use a while loop to remove any number of nested pairs of enclosing + # parentheses and any number of nested desc() calls. In the case of multiple + # nested desc() calls, each one toggles the sort order. while (identical(typeof(expr), "language") && is.call(expr)) { if (identical(expr[[1]], quote(`(`))) { # remove enclosing parentheses diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index f3f4979a3fb..08777385fd7 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -36,12 +36,16 @@ test_that("sort(Scalar) is identity function", { test_that("Array$SortIndices()", { int <- tbl$int - int <- int[!duplicated(int)] # needed because ties in int + # Remove ties because they could give non-deterministic sort indices, and this + # test compares sort indices. Other tests compare sorted values, which are + # deterministic in the case of ties. + int <- int[!duplicated(int)] expect_equal( Array$create(int)$SortIndices(), Array$create(order(int) - 1L, type = uint64()) ) - int <- na.omit(int) # needed because ARROW-12063 + # Need to remove NAs because ARROW-12063 + int <- na.omit(int) expect_equal( Array$create(int)$SortIndices(descending = TRUE), Array$create(rev(order(int)) - 1, type = uint64()) @@ -50,12 +54,16 @@ test_that("Array$SortIndices()", { test_that("ChunkedArray$SortIndices()", { int <- tbl$int - int <- int[!duplicated(int)] # needed because ties in int + # Remove ties because they could give non-deterministic sort indices, and this + # test compares sort indices. Other tests compare sorted values, which are + # deterministic in the case of ties. 
+ int <- int[!duplicated(int)] expect_equal( ChunkedArray$create(int[1:4], int[5:length(int)])$SortIndices(), Array$create(order(int) - 1L, type = uint64()) ) - int <- na.omit(int) # needed because ARROW-12063 + # Need to remove NAs because ARROW-12063 + int <- na.omit(int) expect_equal( ChunkedArray$create(int[1:4], int[5:length(int)])$SortIndices(descending = TRUE), Array$create(rev(order(int)) - 1, type = uint64()) From 49aaf1ef8547519fbe780f579a3eab9061a9eef8 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 16:56:00 -0400 Subject: [PATCH 46/49] Add and reorganize tests --- r/tests/testthat/test-dplyr-arrange.R | 53 +++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index b83dc739bbe..ffe8e86fdbc 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -23,19 +23,13 @@ tbl <- slice_sample(example_data_for_sorting, prop = 1L) # use the C locale for string collation in R (ARROW-12046) Sys.setlocale("LC_COLLATE", "C") -test_that("arrange", { +test_that("arrange() on integer, double, and character columns", { expect_dplyr_equal( input %>% arrange(int, chr) %>% collect(), tbl ) - expect_dplyr_equal( - input %>% - arrange(dttm, int) %>% - collect(), - tbl - ) expect_dplyr_equal( input %>% arrange(int, desc(dbl)) %>% @@ -131,6 +125,14 @@ test_that("arrange", { collect(), tbl ) + test_sort_col <- "chr" + expect_dplyr_equal( + input %>% + arrange(!!sym(test_sort_col)) %>% + collect(), + tbl %>% + select(chr, lgl) + ) test_sort_cols <- c("int", "dbl") expect_dplyr_equal( input %>% @@ -151,6 +153,36 @@ test_that("arrange", { "not supported in Arrow", fixed = TRUE ) +}) + +test_that("arrange() on datetime columns", { + expect_dplyr_equal( + input %>% + arrange(dttm, int) %>% + collect(), + tbl + ) + skip("Sorting by only a single timestamp column fails (ARROW-12087)") + expect_dplyr_equal( + input %>% + 
arrange(dttm) %>% + collect(), + tbl %>% + select(dttm, grp) + ) +}) + +test_that("arrange() on logical columns", { + skip("Sorting by bool columns is not supported (ARROW-12016)") + expect_dplyr_equal( + input %>% + arrange(lgl, int) %>% + collect(), + tbl + ) +}) + +test_that("arrange() with bad inputs", { expect_error( tbl %>% Table$create() %>% @@ -158,6 +190,13 @@ test_that("arrange", { "does not contain any field names", fixed = TRUE ) + expect_error( + tbl %>% + Table$create() %>% + arrange(aertidjfgjksertyj), + "not found", + fixed = TRUE + ) }) # restore previous collation locale setting From a4e69c2266465b59a60647776b532cadddd11103 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 17:08:13 -0400 Subject: [PATCH 47/49] More bad input tests --- r/tests/testthat/test-dplyr-arrange.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index ffe8e86fdbc..62e8119f601 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -190,6 +190,13 @@ test_that("arrange() with bad inputs", { "does not contain any field names", fixed = TRUE ) + expect_error( + tbl %>% + Table$create() %>% + arrange(2 + 2), + "does not contain any field names", + fixed = TRUE + ) expect_error( tbl %>% Table$create() %>% @@ -197,6 +204,13 @@ test_that("arrange() with bad inputs", { "not found", fixed = TRUE ) + expect_error( + tbl %>% + Table$create() %>% + arrange(desc(aertidjfgjksertyj + iaermxiwerksxsdqq)), + "not found", + fixed = TRUE + ) }) # restore previous collation locale setting From dd2e0efd05ff46fae184ebb18965a39cdbfcce96 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 20:08:03 -0400 Subject: [PATCH 48/49] Improve tests --- r/tests/testthat/helper-data.R | 17 ++++++++----- r/tests/testthat/test-compute-sort.R | 6 ----- r/tests/testthat/test-dplyr-arrange.R | 36 ++++++++++++--------------- 3 files changed, 27 
insertions(+), 32 deletions(-) diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index eaa2556d2a0..43b5bf0354f 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -135,14 +135,19 @@ example_with_logical_factors <- tibble::tibble( ) ) -# The values in each column of this tibble are in ascending order in the C locale. -# There are some ties, so tests should use two or more columns to ensure -# deterministic order. libarrow uses the C locale for string collation. testthat -# uses the C locale for string collation inside calls to test_that(). To run test -# code outside of test_that() calls, set the collation locale to "C" by running: +# The values in each column of this tibble are in ascending order. There are +# some ties, so tests should use two or more columns to ensure deterministic +# sort order. The Arrow C++ library orders strings lexicographically as byte +# strings. The order of a string array sorted by Arrow will not match the order +# of an equivalent character vector sorted by R unless you set the R collation +# locale to "C" by running: # Sys.setlocale("LC_COLLATE", "C") -# When finished, restore the default collation locale by running: +# These test scripts set that, but if you are running individual tests you might +# need to set it manually. When finished, you can restore the default +# collation locale by running: # Sys.setlocale("LC_COLLATE") +# In the future, the string collation locale used by the Arrow C++ library might +# be configurable (ARROW-12046). 
example_data_for_sorting <- tibble::tibble( int = c(-.Machine$integer.max, -101L, -100L, 0L, 0L, 1L, 100L, 1000L, .Machine$integer.max, NA_integer_), dbl = c(-Inf, -.Machine$double.xmax, -.Machine$double.xmin, 0, .Machine$double.xmin, pi, .Machine$double.xmax, Inf, NaN, NA_real_), diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index 08777385fd7..ba38d4ce37e 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -22,9 +22,6 @@ library(dplyr) # randomize order of rows in test data tbl <- slice_sample(example_data_for_sorting, prop = 1L) -# use the C locale for string collation in R (ARROW-12046) -Sys.setlocale("LC_COLLATE", "C") - test_that("sort(Scalar) is identity function", { int <- Scalar$create(42L) expect_equal(sort(int), int) @@ -166,6 +163,3 @@ test_that("RecordBatch$SortIndices()", { tbl %>% arrange(desc(chr), desc(int), desc(dbl)) ) }) - -# restore previous collation locale setting -Sys.setlocale("LC_COLLATE") diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index 62e8119f601..be1b18c0fff 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -20,9 +20,6 @@ library(dplyr) # randomize order of rows in test data tbl <- slice_sample(example_data_for_sorting, prop = 1L) -# use the C locale for string collation in R (ARROW-12046) -Sys.setlocale("LC_COLLATE", "C") - test_that("arrange() on integer, double, and character columns", { expect_dplyr_equal( input %>% @@ -197,21 +194,20 @@ test_that("arrange() with bad inputs", { "does not contain any field names", fixed = TRUE ) - expect_error( - tbl %>% - Table$create() %>% - arrange(aertidjfgjksertyj), - "not found", - fixed = TRUE - ) - expect_error( - tbl %>% - Table$create() %>% - arrange(desc(aertidjfgjksertyj + iaermxiwerksxsdqq)), - "not found", - fixed = TRUE - ) + with_language("en", { + expect_error( + tbl %>% + Table$create() %>% + 
arrange(aertidjfgjksertyj), + "not found", + fixed = TRUE + ) + expect_error( + tbl %>% + Table$create() %>% + arrange(desc(aertidjfgjksertyj + iaermxiwerksxsdqq)), + "not found", + fixed = TRUE + ) + }) }) - -# restore previous collation locale setting -Sys.setlocale("LC_COLLATE") From 434dcb56892d5a1367d83c6d6bf377f4ab16b62b Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Thu, 25 Mar 2021 20:12:58 -0400 Subject: [PATCH 49/49] Tests fixes --- r/tests/testthat/helper-arrow.R | 4 ++++ r/tests/testthat/test-dplyr-arrange.R | 30 +++++++++++++-------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/r/tests/testthat/helper-arrow.R b/r/tests/testthat/helper-arrow.R index 64a4991b827..a20ec23961e 100644 --- a/r/tests/testthat/helper-arrow.R +++ b/r/tests/testthat/helper-arrow.R @@ -30,6 +30,10 @@ MAX_INT <- 2147483647L # Make sure this is unset Sys.setenv(ARROW_PRE_0_15_IPC_FORMAT = "") + +# use the C locale for string collation (ARROW-12046) +Sys.setlocale("LC_COLLATE", "C") + # Set English language so that error messages aren't internationalized # (R CMD check does this, but in case you're running outside of check) Sys.setenv(LANGUAGE = "en") diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index be1b18c0fff..b476c032945 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -194,20 +194,18 @@ test_that("arrange() with bad inputs", { "does not contain any field names", fixed = TRUE ) - with_language("en", { - expect_error( - tbl %>% - Table$create() %>% - arrange(aertidjfgjksertyj), - "not found", - fixed = TRUE - ) - expect_error( - tbl %>% - Table$create() %>% - arrange(desc(aertidjfgjksertyj + iaermxiwerksxsdqq)), - "not found", - fixed = TRUE - ) - }) + expect_error( + tbl %>% + Table$create() %>% + arrange(aertidjfgjksertyj), + "not found", + fixed = TRUE + ) + expect_error( + tbl %>% + Table$create() %>% + arrange(desc(aertidjfgjksertyj + iaermxiwerksxsdqq)), + 
"not found", + fixed = TRUE + ) })