diff --git a/r/NAMESPACE b/r/NAMESPACE index 96c09615896..725d441b3e1 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -76,6 +76,8 @@ S3method(read_message,InputStream) S3method(read_message,MessageReader) S3method(read_message,default) S3method(row.names,ArrowTabular) +S3method(sort,ArrowDatum) +S3method(sort,Scalar) S3method(sum,ArrowDatum) S3method(tail,ArrowDatum) S3method(tail,ArrowTabular) @@ -291,7 +293,9 @@ importFrom(rlang,is_integerish) importFrom(rlang,list2) importFrom(rlang,new_data_mask) importFrom(rlang,new_environment) +importFrom(rlang,quo_get_expr) importFrom(rlang,quo_is_null) +importFrom(rlang,quo_set_expr) importFrom(rlang,quos) importFrom(rlang,set_names) importFrom(rlang,syms) diff --git a/r/R/array.R b/r/R/array.R index aa164eaaf91..1d63c5735a7 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -73,6 +73,8 @@ #' (R vector or Array Array) `i`. #' - `$Filter(i, keep_na = TRUE)`: return an `Array` with values at positions where logical #' vector (or Arrow boolean Array) `i` is `TRUE`. +#' - `$SortIndices(descending = FALSE)`: return an `Array` of integer positions that can be +#' used to rearrange the `Array` in ascending or descending order #' - `$RangeEquals(other, start_idx, end_idx, other_start_idx)` : #' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the #' data in the array to change its type. @@ -131,6 +133,12 @@ Array <- R6Class("Array", assert_is(i, "Array") call_function("filter", self, i, options = list(keep_na = keep_na)) }, + SortIndices = function(descending = FALSE) { + assert_that(is.logical(descending)) + assert_that(length(descending) == 1L) + assert_that(!is.na(descending)) + call_function("array_sort_indices", self, options = list(order = descending)) + }, RangeEquals = function(other, start_idx, end_idx, other_start_idx = 0L) { assert_is(other, "Array") Array__RangeEquals(self, other, start_idx, end_idx, other_start_idx) diff --git a/r/R/arrow-datum.R b/r/R/arrow-datum.R index f4d9ad346aa..99940e74cbd 100644 --- a/r/R/arrow-datum.R +++ b/r/R/arrow-datum.R @@ -138,3 +138,23 @@ as.integer.ArrowDatum <- function(x, ...) as.integer(as.vector(x), ...) #' @export as.character.ArrowDatum <- function(x, ...) as.character(as.vector(x), ...) + +#' @export +sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = NA, ...) { + # Arrow always sorts nulls at the end of the array. This corresponds to + # sort(na.last = TRUE). For the other two cases (na.last = NA and + # na.last = FALSE) we need to use workarounds. + # TODO: Implement this more cleanly after ARROW-12063 + if (is.na(na.last)) { + # Filter out NAs before sorting + x <- x$Filter(!is.na(x)) + x$Take(x$SortIndices(descending = decreasing)) + } else if (na.last) { + x$Take(x$SortIndices(descending = decreasing)) + } else { + # Create a new array that encodes missing values as 1 and non-missing values + # as 0. Sort descending by that array first to get the NAs at the beginning + tbl <- Table$create(x = x, `is_na` = as.integer(is.na(x))) + tbl$x$Take(tbl$SortIndices(names = c("is_na", "x"), descending = c(TRUE, decreasing))) + } +} diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index c1d76abfd71..9ced0e88449 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -18,7 +18,7 @@ #' @importFrom R6 R6Class #' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep #' @importFrom assertthat assert_that is.string -#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character +#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr #' @importFrom tidyselect vars_select #' @useDynLib arrow, .registration = TRUE #' @keywords internal diff --git a/r/R/arrow-tabular.R b/r/R/arrow-tabular.R index a41586f26b3..157b799f3b6 100644 --- a/r/R/arrow-tabular.R +++ b/r/R/arrow-tabular.R @@ -38,6 +38,23 @@ ArrowTabular <- R6Class("ArrowTabular", inherit = ArrowObject, } assert_that(is.Array(i, "bool")) call_function("filter", self, i, options = list(keep_na = keep_na)) + }, + SortIndices = function(names, descending = FALSE) { + assert_that(is.character(names)) + assert_that(length(names) > 0) + assert_that(!any(is.na(names))) + if (length(descending) == 1L) { + descending <- rep_len(descending, length(names)) + } + assert_that(is.logical(descending)) + assert_that(identical(length(names), length(descending))) + assert_that(!any(is.na(descending))) + call_function( + "sort_indices", + self, + # cpp11 does not support logical vectors so convert to integer + options = list(names = names, orders = as.integer(descending)) + ) } ) ) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 86ee1303d1c..2c7bf5c19f6 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1476,6 +1476,14 @@ Scalar__type <- function(s){ .Call(`_arrow_Scalar__type`, s) } +Scalar__Equals <- function(lhs, rhs){ + .Call(`_arrow_Scalar__Equals`, lhs, rhs) +} + +Scalar__ApproxEquals <- function(lhs, rhs){ + .Call(`_arrow_Scalar__ApproxEquals`, lhs, rhs) +} + schema_ <- function(fields){ .Call(`_arrow_schema_`, fields) } diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index d639b235f3f..a7f9c8f790c 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -41,6 +41,8 @@ #' coerced to an R vector before taking. #' - `$Filter(i, keep_na = TRUE)`: return a `ChunkedArray` with values at positions where #' logical vector or Arrow boolean-type `(Chunked)Array` `i` is `TRUE`. +#' - `$SortIndices(descending = FALSE)`: return an `Array` of integer positions that can be +#' used to rearrange the `ChunkedArray` in ascending or descending order #' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the #' data in the array to change its type. #' - `$null_count()`: The number of null entries in the array @@ -83,6 +85,18 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = ArrowDatum, } call_function("filter", self, i, options = list(keep_na = keep_na)) }, + SortIndices = function(descending = FALSE) { + assert_that(is.logical(descending)) + assert_that(length(descending) == 1L) + assert_that(!is.na(descending)) + # TODO: after ARROW-12042 is closed, review whether this and the + # Array$SortIndices definition can be consolidated + call_function( + "sort_indices", + self, + options = list(names = "", orders = as.integer(descending)) + ) + }, View = function(type) { ChunkedArray__View(self, as_type(type)) }, diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 7e148863226..f7ede663c7f 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -76,7 +76,7 @@ Scanner$create <- function(dataset, } return(Scanner$create( dataset$.data, - dataset$selected_columns, + c(dataset$selected_columns, dataset$temp_columns), dataset$filtered_rows, use_threads, batch_size, @@ -148,7 +148,8 @@ map_batches <- function(X, FUN, ..., .data.frame = TRUE) { lapply(scan_task$Execute(), function(batch) { # message("Processing Batch") # This inner lapply cannot be parallelized - # TODO: wrap batch in arrow_dplyr_query with X$selected_columns and X$group_by_vars + # TODO: wrap batch in arrow_dplyr_query with X$selected_columns, + # X$temp_columns, and X$group_by_vars # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE FUN(batch, ...) }) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 2745f69d90c..e26d0bf2fb1 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -46,7 +46,13 @@ arrow_dplyr_query <- function(.data) { # drop_empty_groups is a logical value indicating whether to drop # groups formed by factor levels that don't appear in the data. It # should be non-null only when the data is grouped. - drop_empty_groups = NULL + drop_empty_groups = NULL, + # arrange_vars will be a list of expressions named by their associated + # column names + arrange_vars = list(), + # arrange_desc will be a logical vector indicating the sort order for each + # expression in arrange_vars (FALSE for ascending, TRUE for descending) + arrange_desc = logical() ), class = "arrow_dplyr_query" ) @@ -80,6 +86,25 @@ print.arrow_dplyr_query <- function(x, ...) { if (length(x$group_by_vars)) { cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "") } + if (length(x$arrange_vars)) { + if (query_on_dataset(x)) { + arrange_strings <- map_chr(x$arrange_vars, function(x) x$ToString()) + } else { + arrange_strings <- map_chr(x$arrange_vars, .format_array_expression) + } + cat( + "* Sorted by ", + paste( + paste0( + arrange_strings, + " [", ifelse(x$arrange_desc, "desc", "asc"), "]" + ), + collapse = ", " + ), + "\n", + sep = "" + ) + } cat("See $.data for the source Arrow object\n") invisible(x) } @@ -378,6 +403,7 @@ set_filters <- function(.data, expressions) { collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { x <- ensure_group_vars(x) + x <- ensure_arrange_vars(x) # this sets x$temp_columns # Pull only the selected rows and cols into R if (query_on_dataset(x)) { # See dataset.R for Dataset and Scanner(Builder) classes @@ -391,10 +417,14 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { } else { filter <- eval_array_expression(x$filtered_rows, x$.data) } - # TODO: shortcut if identical(names(x$.data), find_array_refs(x$selected_columns))? - tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE] + # TODO: shortcut if identical(names(x$.data), find_array_refs(c(x$selected_columns, x$temp_columns)))? + tab <- x$.data[ + filter, + find_array_refs(c(x$selected_columns, x$temp_columns)), + keep_na = FALSE + ] # Now evaluate those expressions on the filtered table - cols <- lapply(x$selected_columns, eval_array_expression, data = tab) + cols <- lapply(c(x$selected_columns, x$temp_columns), eval_array_expression, data = tab) if (length(cols) == 0) { tab <- tab[, integer(0)] } else { @@ -405,6 +435,14 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { } } } + # Arrange rows + if (length(x$arrange_vars) > 0) { + tab <- tab[ + tab$SortIndices(names(x$arrange_vars), x$arrange_desc), + names(x$selected_columns), # this omits x$temp_columns from the result + drop = FALSE + ] + } if (as_data_frame) { df <- as.data.frame(tab) tab$invalidate() @@ -432,6 +470,20 @@ ensure_group_vars <- function(x) { x } +ensure_arrange_vars <- function(x) { + # The arrange() operation is not performed until later, because: + # - It must be performed after mutate(), to enable sorting by new columns. + # - It should be performed after filter() and select(), for efficiency. + # However, we need users to be able to arrange() by columns and expressions + # that are *not* returned in the query result. To enable this, we must + # *temporarily* include these columns and expressions in the projection. We + # use x$temp_columns to store these. Later, after the arrange() operation has + # been performed, these are omitted from the result. This differs from the + # columns in x$group_by_vars which *are* returned in the result. + x$temp_columns <- x$arrange_vars[!names(x$arrange_vars) %in% names(x$selected_columns)] + x +} + restore_dplyr_features <- function(df, query) { # An arrow_dplyr_query holds some attributes that Arrow doesn't know about # After calling collect(), make sure these features are carried over @@ -689,17 +741,80 @@ abandon_ship <- function(call, .data, msg = NULL) { eval.parent(call, 2) } -arrange.arrow_dplyr_query <- function(.data, ...) { +arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { + call <- match.call() + exprs <- quos(...) + if (.by_group) { + # when the data is is grouped and .by_group is TRUE, order the result by + # the grouping columns first + exprs <- c(quos(!!!dplyr::groups(.data)), exprs) + } + if (length(exprs) == 0) { + # Nothing to do + return(.data) + } .data <- arrow_dplyr_query(.data) - if (query_on_dataset(.data)) { - not_implemented_for_dataset("arrange()") + # find and remove any dplyr::desc() and tidy-eval + # the arrange expressions inside an Arrow data_mask + sorts <- vector("list", length(exprs)) + descs <- logical(0) + mask <- arrow_mask(.data) + for (i in seq_along(exprs)) { + x <- find_and_remove_desc(exprs[[i]]) + exprs[[i]] <- x[["quos"]] + sorts[[i]] <- arrow_eval(exprs[[i]], mask) + if (inherits(sorts[[i]], "try-error")) { + msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') + return(abandon_ship(call, .data, msg)) + } + names(sorts)[i] <- as_label(exprs[[i]]) + descs[i] <- x[["desc"]] } - # TODO(ARROW-11703) move this to Arrow - call <- match.call() - abandon_ship(call, .data) + .data$arrange_vars <- c(sorts, .data$arrange_vars) + .data$arrange_desc <- c(descs, .data$arrange_desc) + .data } arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query +# Helper to handle desc() in arrange() +# * Takes a quosure as input +# * Returns a list with two elements: +# 1. The quosure with any wrapping parentheses and desc() removed +# 2. A logical value indicating whether desc() was found +# * Performs some other validation +find_and_remove_desc <- function(quosure) { + expr <- quo_get_expr(quosure) + descending <- FALSE + if (length(all.vars(expr)) < 1L) { + stop( + "Expression in arrange() does not contain any field names: ", + deparse(expr), + call. = FALSE + ) + } + # Use a while loop to remove any number of nested pairs of enclosing + # parentheses and any number of nested desc() calls. In the case of multiple + # nested desc() calls, each one toggles the sort order. + while (identical(typeof(expr), "language") && is.call(expr)) { + if (identical(expr[[1]], quote(`(`))) { + # remove enclosing parentheses + expr <- expr[[2]] + } else if (identical(expr[[1]], quote(desc))) { + # remove desc() and toggle descending + expr <- expr[[2]] + descending <- !descending + } else { + break + } + } + return( + list( + quos = quo_set_expr(quosure, expr), + desc = descending + ) + ) +} + query_on_dataset <- function(x) inherits(x$.data, "Dataset") not_implemented_for_dataset <- function(method) { diff --git a/r/R/record-batch.R b/r/R/record-batch.R index fdb54deb881..8e7382a70b4 100644 --- a/r/R/record-batch.R +++ b/r/R/record-batch.R @@ -57,6 +57,11 @@ #' integers (R vector or Array Array) `i`. #' - `$Filter(i, keep_na = TRUE)`: return an `RecordBatch` with rows at positions where logical #' vector (or Arrow boolean Array) `i` is `TRUE`. +#' - `$SortIndices(names, descending = FALSE)`: return an `Array` of integer row +#' positions that can be used to rearrange the `RecordBatch` in ascending or +#' descending order by the first named column, breaking ties with further named +#' columns. `descending` can be a logical vector of length one or of the same +#' length as `names`. #' - `$serialize()`: Returns a raw vector suitable for interprocess communication #' - `$cast(target_schema, safe = TRUE, options = cast_options(safe))`: Alter #' the schema of the record batch. @@ -99,7 +104,7 @@ RecordBatch <- R6Class("RecordBatch", inherit = ArrowTabular, RecordBatch__Slice2(self, offset, length) } }, - # Take and Filter are methods on ArrowTabular + # Take, Filter, and SortIndices are methods on ArrowTabular serialize = function() ipc___SerializeRecordBatch__Raw(self), to_data_frame = function() { RecordBatch__to_dataframe(self, use_threads = option_use_threads()) diff --git a/r/R/scalar.R b/r/R/scalar.R index d6955423b53..cbda5964a2c 100644 --- a/r/R/scalar.R +++ b/r/R/scalar.R @@ -33,7 +33,13 @@ Scalar <- R6Class("Scalar", public = list( ToString = function() Scalar__ToString(self), as_vector = function() Scalar__as_vector(self), - as_array = function() MakeArrayFromScalar(self) + as_array = function() MakeArrayFromScalar(self), + Equals = function(other, ...) { + inherits(other, "Scalar") && Scalar__Equals(self, other) + }, + ApproxEquals = function(other, ...) { + inherits(other, "Scalar") && Scalar__ApproxEquals(self, other) + } ), active = list( is_valid = function() Scalar__is_valid(self), @@ -68,3 +74,6 @@ length.Scalar <- function(x) 1L #' @export is.na.Scalar <- function(x) !x$is_valid + +#' @export +sort.Scalar <- function(x, decreasing = FALSE, ...) x diff --git a/r/R/table.R b/r/R/table.R index d2c9960e6d2..fdf3f5cc20d 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -65,6 +65,11 @@ #' coerced to an R vector before taking. #' - `$Filter(i, keep_na = TRUE)`: return an `Table` with rows at positions where logical #' vector or Arrow boolean-type `(Chunked)Array` `i` is `TRUE`. +#' - `$SortIndices(names, descending = FALSE)`: return an `Array` of integer row +#' positions that can be used to rearrange the `Table` in ascending or descending +#' order by the first named column, breaking ties with further named columns. +#' `descending` can be a logical vector of length one or of the same length as +#' `names`. #' - `$serialize(output_stream, ...)`: Write the table to the given #' [OutputStream] #' - `$cast(target_schema, safe = TRUE, options = cast_options(safe))`: Alter @@ -122,7 +127,7 @@ Table <- R6Class("Table", inherit = ArrowTabular, Table__Slice2(self, offset, length) } }, - # Take and Filter are methods on ArrowTabular + # Take, Filter, and SortIndices are methods on ArrowTabular Equals = function(other, check_metadata = FALSE, ...) { inherits(other, "Table") && Table__Equals(self, other, isTRUE(check_metadata)) }, diff --git a/r/man/ChunkedArray.Rd b/r/man/ChunkedArray.Rd index 533931ae972..90dd2e39e40 100644 --- a/r/man/ChunkedArray.Rd +++ b/r/man/ChunkedArray.Rd @@ -38,6 +38,8 @@ integers \code{i}. If \code{i} is an Arrow \code{Array} or \code{ChunkedArray}, coerced to an R vector before taking. \item \verb{$Filter(i, keep_na = TRUE)}: return a \code{ChunkedArray} with values at positions where logical vector or Arrow boolean-type \verb{(Chunked)Array} \code{i} is \code{TRUE}. +\item \verb{$SortIndices(descending = FALSE)}: return an \code{Array} of integer positions that can be +used to rearrange the \code{ChunkedArray} in ascending or descending order \item \verb{$cast(target_type, safe = TRUE, options = cast_options(safe))}: Alter the data in the array to change its type. \item \verb{$null_count()}: The number of null entries in the array diff --git a/r/man/RecordBatch.Rd b/r/man/RecordBatch.Rd index 06f9f67abe2..184fea99c7f 100644 --- a/r/man/RecordBatch.Rd +++ b/r/man/RecordBatch.Rd @@ -57,6 +57,11 @@ of the table if \code{NULL}, the default. integers (R vector or Array Array) \code{i}. \item \verb{$Filter(i, keep_na = TRUE)}: return an \code{RecordBatch} with rows at positions where logical vector (or Arrow boolean Array) \code{i} is \code{TRUE}. +\item \verb{$SortIndices(names, descending = FALSE)}: return an \code{Array} of integer row +positions that can be used to rearrange the \code{RecordBatch} in ascending or +descending order by the first named column, breaking ties with further named +columns. \code{descending} can be a logical vector of length one or of the same +length as \code{names}. \item \verb{$serialize()}: Returns a raw vector suitable for interprocess communication \item \verb{$cast(target_schema, safe = TRUE, options = cast_options(safe))}: Alter the schema of the record batch. diff --git a/r/man/Table.Rd b/r/man/Table.Rd index 14c0b0bf260..98a5c354ced 100644 --- a/r/man/Table.Rd +++ b/r/man/Table.Rd @@ -56,6 +56,11 @@ integers \code{i}. If \code{i} is an Arrow \code{Array} or \code{ChunkedArray}, coerced to an R vector before taking. \item \verb{$Filter(i, keep_na = TRUE)}: return an \code{Table} with rows at positions where logical vector or Arrow boolean-type \verb{(Chunked)Array} \code{i} is \code{TRUE}. +\item \verb{$SortIndices(names, descending = FALSE)}: return an \code{Array} of integer row +positions that can be used to rearrange the \code{Table} in ascending or descending +order by the first named column, breaking ties with further named columns. +\code{descending} can be a logical vector of length one or of the same length as +\code{names}. \item \verb{$serialize(output_stream, ...)}: Write the table to the given \link{OutputStream} \item \verb{$cast(target_schema, safe = TRUE, options = cast_options(safe))}: Alter diff --git a/r/man/array.Rd b/r/man/array.Rd index fbc91e4dc35..f65afe9fbc3 100644 --- a/r/man/array.Rd +++ b/r/man/array.Rd @@ -71,6 +71,8 @@ until the end of the array. (R vector or Array Array) \code{i}. \item \verb{$Filter(i, keep_na = TRUE)}: return an \code{Array} with values at positions where logical vector (or Arrow boolean Array) \code{i} is \code{TRUE}. +\item \verb{$SortIndices(descending = FALSE)}: return an \code{Array} of integer positions that can be +used to rearrange the \code{Array} in ascending or descending order \item \verb{$RangeEquals(other, start_idx, end_idx, other_start_idx)} : \item \verb{$cast(target_type, safe = TRUE, options = cast_options(safe))}: Alter the data in the array to change its type. diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 35dab553fbd..b06a2696e50 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -3793,6 +3793,24 @@ BEGIN_CPP11 return cpp11::as_sexp(Scalar__type(s)); END_CPP11 } +// scalar.cpp +bool Scalar__Equals(const std::shared_ptr& lhs, const std::shared_ptr& rhs); +extern "C" SEXP _arrow_Scalar__Equals(SEXP lhs_sexp, SEXP rhs_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type lhs(lhs_sexp); + arrow::r::Input&>::type rhs(rhs_sexp); + return cpp11::as_sexp(Scalar__Equals(lhs, rhs)); +END_CPP11 +} +// scalar.cpp +bool Scalar__ApproxEquals(const std::shared_ptr& lhs, const std::shared_ptr& rhs); +extern "C" SEXP _arrow_Scalar__ApproxEquals(SEXP lhs_sexp, SEXP rhs_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type lhs(lhs_sexp); + arrow::r::Input&>::type rhs(rhs_sexp); + return cpp11::as_sexp(Scalar__ApproxEquals(lhs, rhs)); +END_CPP11 +} // schema.cpp std::shared_ptr schema_(const std::vector>& fields); extern "C" SEXP _arrow_schema_(SEXP fields_sexp){ @@ -4544,6 +4562,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_MakeArrayFromScalar", (DL_FUNC) &_arrow_MakeArrayFromScalar, 1}, { "_arrow_Scalar__is_valid", (DL_FUNC) &_arrow_Scalar__is_valid, 1}, { "_arrow_Scalar__type", (DL_FUNC) &_arrow_Scalar__type, 1}, + { "_arrow_Scalar__Equals", (DL_FUNC) &_arrow_Scalar__Equals, 2}, + { "_arrow_Scalar__ApproxEquals", (DL_FUNC) &_arrow_Scalar__ApproxEquals, 2}, { "_arrow_schema_", (DL_FUNC) &_arrow_schema_, 1}, { "_arrow_Schema__ToString", (DL_FUNC) &_arrow_Schema__ToString, 1}, { "_arrow_Schema__num_fields", (DL_FUNC) &_arrow_Schema__num_fields, 1}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 07380354b12..5cf8c7c37d2 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -144,6 +144,33 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "array_sort_indices") { + using Order = arrow::compute::SortOrder; + using Options = arrow::compute::ArraySortOptions; + // false means descending, true means ascending + auto order = cpp11::as_cpp(options["order"]); + auto out = + std::make_shared(Options(order ? Order::Descending : Order::Ascending)); + return out; + } + + if (func_name == "sort_indices") { + using Key = arrow::compute::SortKey; + using Order = arrow::compute::SortOrder; + using Options = arrow::compute::SortOptions; + auto names = cpp11::as_cpp>(options["names"]); + // false means descending, true means ascending + // cpp11 does not support bool here so use int + auto orders = cpp11::as_cpp>(options["orders"]); + std::vector keys; + for (size_t i = 0; i < names.size(); i++) { + keys.push_back( + Key(names[i], (orders[i] > 0) ? Order::Descending : Order::Ascending)); + } + auto out = std::make_shared(Options(keys)); + return out; + } + if (func_name == "min_max") { using Options = arrow::compute::MinMaxOptions; auto out = std::make_shared(Options::Defaults()); diff --git a/r/src/scalar.cpp b/r/src/scalar.cpp index c7e2a716bde..057e587e7eb 100644 --- a/r/src/scalar.cpp +++ b/r/src/scalar.cpp @@ -82,4 +82,16 @@ std::shared_ptr Scalar__type(const std::shared_ptrtype; } +// [[arrow::export]] +bool Scalar__Equals(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return lhs->Equals(rhs); +} + +// [[arrow::export]] +bool Scalar__ApproxEquals(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return lhs->ApproxEquals(*rhs); +} + #endif diff --git a/r/tests/testthat/helper-arrow.R b/r/tests/testthat/helper-arrow.R index 64a4991b827..a20ec23961e 100644 --- a/r/tests/testthat/helper-arrow.R +++ b/r/tests/testthat/helper-arrow.R @@ -30,6 +30,10 @@ MAX_INT <- 2147483647L # Make sure this is unset Sys.setenv(ARROW_PRE_0_15_IPC_FORMAT = "") + +# use the C locale for string collation (ARROW-12046) +Sys.setlocale("LC_COLLATE", "C") + # Set English language so that error messages aren't internationalized # (R CMD check does this, but in case you're running outside of check) Sys.setenv(LANGUAGE = "en") diff --git a/r/tests/testthat/helper-data.R b/r/tests/testthat/helper-data.R index 1dd3a6a79d3..43b5bf0354f 100644 --- a/r/tests/testthat/helper-data.R +++ b/r/tests/testthat/helper-data.R @@ -134,3 +134,36 @@ example_with_logical_factors <- tibble::tibble( "hey buddy" ) ) + +# The values in each column of this tibble are in ascending order. There are +# some ties, so tests should use two or more columns to ensure deterministic +# sort order. The Arrow C++ library orders strings lexicographically as byte +# strings. The order of a string array sorted by Arrow will not match the order +# of an equivalent character vector sorted by R unless you set the R collation +# locale to "C" by running: +# Sys.setlocale("LC_COLLATE", "C") +# These test scripts set that, but if you are running individual tests you might +# need to set it manually. When finished, you can restore the default +# collation locale by running: +# Sys.setlocale("LC_COLLATE") +# In the future, the string collation locale used by the Arrow C++ library might +# be configurable (ARROW-12046). +example_data_for_sorting <- tibble::tibble( + int = c(-.Machine$integer.max, -101L, -100L, 0L, 0L, 1L, 100L, 1000L, .Machine$integer.max, NA_integer_), + dbl = c(-Inf, -.Machine$double.xmax, -.Machine$double.xmin, 0, .Machine$double.xmin, pi, .Machine$double.xmax, Inf, NaN, NA_real_), + chr = c("", "", "\"", "&", "ABC", "NULL", "a", "abc", "zzz", NA_character_), + lgl = c(rep(FALSE, 4L), rep(TRUE, 5L), NA), # bool is not supported (ARROW-12016) + dttm = lubridate::ymd_hms(c( + "0000-01-01 00:00:00", + "1919-05-29 13:08:55", + "1955-06-20 04:10:42", + "1973-06-30 11:38:41", + "1987-03-29 12:49:47", + "1991-06-11 19:07:01", + NA_character_, + "2017-08-21 18:26:40", + "2017-08-21 18:26:40", + "9999-12-31 23:59:59" + )), + grp = c(rep("A", 5), rep("B", 5)) +) diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index 76edea61f57..39cc9e0597a 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -121,4 +121,43 @@ expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its star msg, ... ) -} \ No newline at end of file +} + +expect_vector_equal <- function(expr, # A vectorized R expression containing `input` as its input + vec, # A vector as reference, will make Array/ChunkedArray with + skip_array = NULL, # Msg, if should skip Array test + skip_chunked_array = NULL, # Msg, if should skip ChunkedArray test + ...) { + expr <- rlang::enquo(expr) + expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = vec))) + + skip_msg <- NULL + + if (is.null(skip_array)) { + via_array <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Array$create(vec))) + ) + expect_vector(via_array, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_array) + } + + if (is.null(skip_chunked_array)) { + # split input vector into two to exercise ChunkedArray with >1 chunk + vec_split <- length(vec) %/% 2 + vec1 <- vec[seq(from = min(1, length(vec) - 1), to = min(length(vec) - 1, vec_split), by = 1)] + vec2 <- vec[seq(from = min(length(vec), vec_split + 1), to = length(vec), by = 1)] + via_chunked <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = ChunkedArray$create(vec1, vec2))) + ) + expect_vector(via_chunked, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_chunked_array) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) + } +} diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R new file mode 100644 index 00000000000..ba38d4ce37e --- /dev/null +++ b/r/tests/testthat/test-compute-sort.R @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +context("compute: sorting") + +library(dplyr) + +# randomize order of rows in test data +tbl <- slice_sample(example_data_for_sorting, prop = 1L) + +test_that("sort(Scalar) is identity function", { + int <- Scalar$create(42L) + expect_equal(sort(int), int) + dbl <- Scalar$create(3.14) + expect_equal(sort(dbl), dbl) + chr <- Scalar$create("foo") + expect_equal(sort(chr), chr) +}) + +test_that("Array$SortIndices()", { + int <- tbl$int + # Remove ties because they could give non-deterministic sort indices, and this + # test compares sort indices. Other tests compare sorted values, which are + # deterministic in the case of ties. + int <- int[!duplicated(int)] + expect_equal( + Array$create(int)$SortIndices(), + Array$create(order(int) - 1L, type = uint64()) + ) + # Need to remove NAs because ARROW-12063 + int <- na.omit(int) + expect_equal( + Array$create(int)$SortIndices(descending = TRUE), + Array$create(rev(order(int)) - 1, type = uint64()) + ) +}) + +test_that("ChunkedArray$SortIndices()", { + int <- tbl$int + # Remove ties because they could give non-deterministic sort indices, and this + # test compares sort indices. Other tests compare sorted values, which are + # deterministic in the case of ties. + int <- int[!duplicated(int)] + expect_equal( + ChunkedArray$create(int[1:4], int[5:length(int)])$SortIndices(), + Array$create(order(int) - 1L, type = uint64()) + ) + # Need to remove NAs because ARROW-12063 + int <- na.omit(int) + expect_equal( + ChunkedArray$create(int[1:4], int[5:length(int)])$SortIndices(descending = TRUE), + Array$create(rev(order(int)) - 1, type = uint64()) + ) +}) + +test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on integers", { + expect_vector_equal( + sort(input), + tbl$int + ) + expect_vector_equal( + sort(input, na.last = NA), + tbl$int + ) + expect_vector_equal( + sort(input, na.last = TRUE), + tbl$int + ) + expect_vector_equal( + sort(input, na.last = FALSE), + tbl$int + ) + expect_vector_equal( + sort(input, decreasing = TRUE), + tbl$int, + ) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = TRUE), + tbl$int, + ) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$int, + ) +}) + +test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on strings", { + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$chr + ) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$chr + ) +}) + +test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results on floats", { + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = TRUE), + tbl$dbl + ) + expect_vector_equal( + sort(input, decreasing = FALSE, na.last = TRUE), + tbl$dbl + ) + skip("is.na() evaluates to FALSE on Arrow NaN values (ARROW-12055)") + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = NA), + tbl$dbl + ) + expect_vector_equal( + sort(input, decreasing = TRUE, na.last = FALSE), + tbl$dbl, + ) + expect_vector_equal( + sort(input, decreasing = FALSE, na.last = NA), + tbl$dbl + ) + expect_vector_equal( + sort(input, decreasing = FALSE, na.last = FALSE), + tbl$dbl, + ) +}) + +test_that("Table$SortIndices()", { + expect_identical( + { + x <- tbl %>% Table$create() + x$Take(x$SortIndices("chr")) %>% pull(chr) + }, + sort(tbl$chr, na.last = TRUE) + ) + expect_identical( + { + x <- tbl %>% Table$create() + x$Take(x$SortIndices(c("int", "dbl"), c(FALSE, FALSE))) %>% collect() + }, + tbl %>% arrange(int, dbl) + ) +}) + +test_that("RecordBatch$SortIndices()", { + expect_identical( + { + x <- tbl %>% record_batch() + x$Take(x$SortIndices(c("chr", "int", "dbl"), TRUE)) %>% collect() + }, + tbl %>% arrange(desc(chr), desc(int), desc(dbl)) + ) +}) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index e6db0bcea17..73057221a24 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -1030,6 +1030,42 @@ test_that("count()", { ) }) +test_that("arrange()", { + ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) + arranged <- ds %>% + select(chr, dbl, int) %>% + filter(dbl * 2 > 14 & dbl - 50 < 3L) %>% + mutate(twice = int * 2) %>% + arrange(chr, desc(twice), dbl + int) + expect_output( + print(arranged), + "FileSystemDataset (query) +chr: string +dbl: double +int: int32 +twice: expr + +* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) +* Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc] +See $.data for the source Arrow object", + fixed = TRUE + ) + expect_equivalent( + arranged %>% + collect(), + rbind( + df1[8, c("chr", "dbl", "int")], + df2[2, c("chr", "dbl", "int")], + df1[9, c("chr", "dbl", "int")], + df2[1, c("chr", "dbl", "int")], + df1[10, c("chr", "dbl", "int")] + ) %>% + mutate( + twice = int * 2 + ) + ) +}) + test_that("head/tail", { skip_if_not_available("parquet") ds <- open_dataset(dataset_dir) @@ -1117,7 +1153,6 @@ test_that("dplyr method not implemented messages", { expect_not_implemented <- function(x) { expect_error(x, "is not currently implemented for Arrow Datasets") } - expect_not_implemented(ds %>% arrange(int)) expect_not_implemented(ds %>% filter(int == 1) %>% summarize(n())) }) diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R new file mode 100644 index 00000000000..b476c032945 --- /dev/null +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) + +# randomize order of rows in test data +tbl <- slice_sample(example_data_for_sorting, prop = 1L) + +test_that("arrange() on integer, double, and character columns", { + expect_dplyr_equal( + input %>% + arrange(int, chr) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + arrange(int, desc(dbl)) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + arrange(int, desc(desc(dbl))) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + arrange(int) %>% + arrange(desc(dbl)) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + arrange(int + dbl, chr) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + mutate(zzz = int + dbl,) %>% + arrange(zzz, chr) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + mutate(zzz = int + dbl) %>% + arrange(int + dbl, chr) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + mutate(int + dbl) %>% + arrange(int + dbl, chr) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange(int, dbl) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange(int, dbl, .by_group = TRUE) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(grp, grp2) %>% + arrange(int, dbl, .by_group = TRUE) %>% + collect(), + tbl %>% + mutate(grp2 = ifelse(is.na(lgl), 1L, as.integer(lgl))) + ) + expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange(.by_group = TRUE) %>% + pull(grp), + tbl + ) + expect_dplyr_equal( + input %>% + arrange() %>% + collect(), + tbl %>% + group_by(grp) + ) + expect_dplyr_equal( + input %>% + group_by(grp) %>% + arrange() %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + arrange() %>% + collect(), + tbl + ) + test_sort_col <- "chr" + expect_dplyr_equal( + input %>% + arrange(!!sym(test_sort_col)) %>% + collect(), + tbl %>% + select(chr, lgl) + ) + test_sort_cols <- c("int", "dbl") + expect_dplyr_equal( + input %>% + arrange(!!!syms(test_sort_cols)) %>% + collect(), + tbl + ) + expect_warning( + expect_equal( + tbl %>% + Table$create() %>% + arrange(abs(int), dbl) %>% + collect(), + tbl %>% + arrange(abs(int), dbl) %>% + collect() + ), + "not supported in Arrow", + fixed = TRUE + ) +}) + +test_that("arrange() on datetime columns", { + expect_dplyr_equal( + input %>% + arrange(dttm, int) %>% + collect(), + tbl + ) + skip("Sorting by only a single timestamp column fails (ARROW-12087)") + expect_dplyr_equal( + input %>% + arrange(dttm) %>% + collect(), + tbl %>% + select(dttm, grp) + ) +}) + +test_that("arrange() on logical columns", { + skip("Sorting by bool columns is not supported (ARROW-12016)") + expect_dplyr_equal( + input %>% + arrange(lgl, int) %>% + collect(), + tbl + ) +}) + +test_that("arrange() with bad inputs", { + expect_error( + tbl %>% + Table$create() %>% + arrange(1), + "does not contain any field names", + fixed = TRUE + ) + expect_error( + tbl %>% + Table$create() %>% + arrange(2 + 2), + "does not contain any field names", + fixed = TRUE + ) + expect_error( + tbl %>% + Table$create() %>% + arrange(aertidjfgjksertyj), + "not found", + fixed = TRUE + ) + expect_error( + tbl %>% + Table$create() %>% + arrange(desc(aertidjfgjksertyj + iaermxiwerksxsdqq)), + "not found", + fixed = TRUE + ) +}) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 460d4bbdba5..c1d9e457464 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -134,17 +134,6 @@ test_that("Empty select still includes the group_by columns", { ) }) -test_that("arrange", { - expect_dplyr_equal( - input %>% - group_by(chr) %>% - select(int, chr) %>% - arrange(desc(int)) %>% - collect(), - tbl - ) -}) - test_that("select/rename", { expect_dplyr_equal( input %>% diff --git a/r/tests/testthat/test-scalar.R b/r/tests/testthat/test-scalar.R index ce0f53d2be3..e9ef893bbd9 100644 --- a/r/tests/testthat/test-scalar.R +++ b/r/tests/testthat/test-scalar.R @@ -53,4 +53,26 @@ test_that("Scalar to Array", { a <- Scalar$create(42) expect_equal(a$as_array(), Array$create(42)) expect_equal(Array$create(a), Array$create(42)) -}) \ No newline at end of file +}) + +test_that("Scalar$Equals", { + a <- Scalar$create(42) + aa <- Array$create(42) + b <- Scalar$create(42) + d <- Scalar$create(43) + expect_equal(a, b) + expect_true(a$Equals(b)) + expect_false(a$Equals(d)) + expect_false(a$Equals(aa)) +}) + +test_that("Scalar$ApproxEquals", { + a <- Scalar$create(1.0000000000001) + aa <- Array$create(1.0000000000001) + b <- Scalar$create(1.0) + d <- 2.400000000000001 + expect_false(a$Equals(b)) + expect_true(a$ApproxEquals(b)) + expect_false(a$ApproxEquals(d)) + expect_false(a$ApproxEquals(aa)) +})