Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
5a9cb90
Implement SortIndices method for Table, RecordBatch, Array, ChunkedArray
ianmcook Mar 18, 2021
d47cfd0
Fix SortIndices for ChunkedArray
ianmcook Mar 18, 2021
67d2199
Implement sort() method for ArrowDatum types
ianmcook Mar 18, 2021
deef9a3
Disallow unsupported na.last options in Array and ChunkedArray sort()
ianmcook Mar 19, 2021
e269669
Add test data for sort tests
ianmcook Mar 19, 2021
7001927
Add compute sorting tests
ianmcook Mar 19, 2021
c5bf536
Fix unsigned/signed compare error
ianmcook Mar 19, 2021
3fa4125
Implement dplyr::arrange() for ArrowTabular
ianmcook Mar 19, 2021
5c8b70e
Add arrange() tests
ianmcook Mar 19, 2021
ecd5bf7
Define sort method for ArrowDatum instead of Array and ChunkedArray
ianmcook Mar 22, 2021
aad31cb
Improve tests
ianmcook Mar 22, 2021
3fe0c51
Rename get_field_name -> get_field_name_of_array_ref
ianmcook Mar 22, 2021
1bc1806
Support na.last in sort.ArrowDatum
ianmcook Mar 23, 2021
154cac6
Fix failing tests
ianmcook Mar 23, 2021
4e88023
Fix failing tests
ianmcook Mar 23, 2021
eeb1db0
Fix failing tests
ianmcook Mar 23, 2021
82dc4ff
Support expressions in arrange()
ianmcook Mar 23, 2021
27025e0
Do arrange() locally when unsupported expression
ianmcook Mar 23, 2021
791064e
Update tests
ianmcook Mar 23, 2021
9593e89
Add TODO
ianmcook Mar 23, 2021
6604144
Support arrange(.by_group = TRUE)
ianmcook Mar 23, 2021
50c6992
Fix and add tests
ianmcook Mar 23, 2021
9e87b26
Handle case when empty dots and .group_by = TRUE
ianmcook Mar 23, 2021
d7dba8e
More tests for edge cases
ianmcook Mar 23, 2021
07324dc
Fix and add tests
ianmcook Mar 24, 2021
d84a65a
Add comments
ianmcook Mar 24, 2021
6f8e8bf
Remove usage of rlang::as_name
ianmcook Mar 24, 2021
948fa5a
Add comment
ianmcook Mar 24, 2021
f94d630
Remove unnecessary code
ianmcook Mar 24, 2021
cb67505
Add test of arrange(!!!syms())
ianmcook Mar 24, 2021
b21fcf9
Remove dup dataset check
ianmcook Mar 24, 2021
d2ab74b
Add comments and rename vars for clarity in sort.ArrowDatum
ianmcook Mar 24, 2021
33bbc0d
Implement arrange() for Datasets
ianmcook Mar 24, 2021
0646aff
Add tests for arrange() on Dataset
ianmcook Mar 24, 2021
95b310f
Improve skip message
ianmcook Mar 24, 2021
deba988
Fix test failures
ianmcook Mar 24, 2021
38a8aec
Add expect_vector_equal() helper
ianmcook Mar 24, 2021
f98b92e
Clean up tests
ianmcook Mar 24, 2021
470904f
Fix failing tests
ianmcook Mar 24, 2021
6b4a586
Trigger CI
ianmcook Mar 24, 2021
e44b7a1
Implement Scalar$Equals(), $ApproxEquals() and simplify sort(Scalar) …
ianmcook Mar 25, 2021
2d0777e
Add some non-skipped float sort tests
ianmcook Mar 25, 2021
20c9233
Lint
ianmcook Mar 25, 2021
d9a5e09
Improve tests
ianmcook Mar 25, 2021
da1b746
Improve comments
ianmcook Mar 25, 2021
49aaf1e
Add and reorganize tests
ianmcook Mar 25, 2021
a4e69c2
More bad input tests
ianmcook Mar 25, 2021
dd2e0ef
Improve tests
ianmcook Mar 26, 2021
434dcb5
Tests fixes
ianmcook Mar 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ S3method(read_message,InputStream)
S3method(read_message,MessageReader)
S3method(read_message,default)
S3method(row.names,ArrowTabular)
S3method(sort,ArrowDatum)
S3method(sort,Scalar)
S3method(sum,ArrowDatum)
S3method(tail,ArrowDatum)
S3method(tail,ArrowTabular)
Expand Down Expand Up @@ -291,7 +293,9 @@ importFrom(rlang,is_integerish)
importFrom(rlang,list2)
importFrom(rlang,new_data_mask)
importFrom(rlang,new_environment)
importFrom(rlang,quo_get_expr)
importFrom(rlang,quo_is_null)
importFrom(rlang,quo_set_expr)
importFrom(rlang,quos)
importFrom(rlang,set_names)
importFrom(rlang,syms)
Expand Down
8 changes: 8 additions & 0 deletions r/R/array.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
#' (R vector or Arrow Array) `i`.
#' - `$Filter(i, keep_na = TRUE)`: return an `Array` with values at positions where logical
#' vector (or Arrow boolean Array) `i` is `TRUE`.
#' - `$SortIndices(descending = FALSE)`: return an `Array` of integer positions that can be
#' used to rearrange the `Array` in ascending or descending order
#' - `$RangeEquals(other, start_idx, end_idx, other_start_idx)` :
#' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the
#' data in the array to change its type.
Expand Down Expand Up @@ -131,6 +133,12 @@ Array <- R6Class("Array",
assert_is(i, "Array")
call_function("filter", self, i, options = list(keep_na = keep_na))
},
SortIndices = function(descending = FALSE) {
  # `descending` must be a single, non-missing logical flag. The conditions
  # are checked in order and the first one that fails raises the error.
  assert_that(
    is.logical(descending),
    length(descending) == 1L,
    !is.na(descending)
  )
  # Ask the "array_sort_indices" compute function for the positions that put
  # this Array in the requested order
  sort_options <- list(order = descending)
  call_function("array_sort_indices", self, options = sort_options)
},
RangeEquals = function(other, start_idx, end_idx, other_start_idx = 0L) {
assert_is(other, "Array")
Array__RangeEquals(self, other, start_idx, end_idx, other_start_idx)
Expand Down
20 changes: 20 additions & 0 deletions r/R/arrow-datum.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,23 @@ as.integer.ArrowDatum <- function(x, ...) as.integer(as.vector(x), ...)

#' @export
as.character.ArrowDatum <- function(x, ...) as.character(as.vector(x), ...)

#' @export
sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = NA, ...) {
  # Arrow always sorts nulls at the end of the array, which matches
  # sort(na.last = TRUE). The other two settings (na.last = NA and
  # na.last = FALSE) are emulated around that behavior.
  # TODO: Implement this more cleanly after ARROW-12063

  # Sort a datum with Arrow's native null placement (nulls last)
  sort_nulls_last <- function(datum) {
    datum$Take(datum$SortIndices(descending = decreasing))
  }

  if (is.na(na.last)) {
    # na.last = NA: drop the missing values, then sort what remains
    return(sort_nulls_last(x$Filter(!is.na(x))))
  }
  if (na.last) {
    return(sort_nulls_last(x))
  }

  # na.last = FALSE: build a helper column that is 1 for missing values and 0
  # otherwise, and sort descending by it first so the NAs come out at the
  # beginning; ties are then ordered by the values themselves.
  null_flag <- as.integer(is.na(x))
  keyed <- Table$create(x = x, `is_na` = null_flag)
  row_order <- keyed$SortIndices(
    names = c("is_na", "x"),
    descending = c(TRUE, decreasing)
  )
  keyed$x$Take(row_order)
}
2 changes: 1 addition & 1 deletion r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#' @importFrom R6 R6Class
#' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep
#' @importFrom assertthat assert_that is.string
#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character
#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr
#' @importFrom tidyselect vars_select
#' @useDynLib arrow, .registration = TRUE
#' @keywords internal
Expand Down
17 changes: 17 additions & 0 deletions r/R/arrow-tabular.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ ArrowTabular <- R6Class("ArrowTabular", inherit = ArrowObject,
}
assert_that(is.Array(i, "bool"))
call_function("filter", self, i, options = list(keep_na = keep_na))
},
SortIndices = function(names, descending = FALSE) {
  # `names` lists the sort key columns: a non-empty character vector with no
  # missing entries
  assert_that(
    is.character(names),
    length(names) > 0,
    !any(is.na(names))
  )
  # A single flag applies to every sort key
  if (length(descending) == 1L) {
    descending <- rep_len(descending, length(names))
  }
  # After recycling there must be exactly one non-missing logical per key
  assert_that(
    is.logical(descending),
    identical(length(names), length(descending)),
    !any(is.na(descending))
  )
  # cpp11 does not support logical vectors so convert to integer
  sort_options <- list(names = names, orders = as.integer(descending))
  call_function("sort_indices", self, options = sort_options)
}
)
)
Expand Down
8 changes: 8 additions & 0 deletions r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions r/R/chunked-array.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
#' coerced to an R vector before taking.
#' - `$Filter(i, keep_na = TRUE)`: return a `ChunkedArray` with values at positions where
#' logical vector or Arrow boolean-type `(Chunked)Array` `i` is `TRUE`.
#' - `$SortIndices(descending = FALSE)`: return an `Array` of integer positions that can be
#' used to rearrange the `ChunkedArray` in ascending or descending order
#' - `$cast(target_type, safe = TRUE, options = cast_options(safe))`: Alter the
#' data in the array to change its type.
#' - `$null_count()`: The number of null entries in the array
Expand Down Expand Up @@ -83,6 +85,18 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = ArrowDatum,
}
call_function("filter", self, i, options = list(keep_na = keep_na))
},
SortIndices = function(descending = FALSE) {
  # `descending` must be a single, non-missing logical flag. The conditions
  # are checked in order and the first one that fails raises the error.
  assert_that(
    is.logical(descending),
    length(descending) == 1L,
    !is.na(descending)
  )
  # TODO: after ARROW-12042 is closed, review whether this and the
  # Array$SortIndices definition can be consolidated
  # An empty string is passed as the only sort key name and the flag is sent
  # as an integer (0 for ascending, 1 for descending)
  sort_options <- list(names = "", orders = as.integer(descending))
  call_function("sort_indices", self, options = sort_options)
},
View = function(type) {
ChunkedArray__View(self, as_type(type))
},
Expand Down
5 changes: 3 additions & 2 deletions r/R/dataset-scan.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Scanner$create <- function(dataset,
}
return(Scanner$create(
dataset$.data,
dataset$selected_columns,
c(dataset$selected_columns, dataset$temp_columns),
dataset$filtered_rows,
use_threads,
batch_size,
Expand Down Expand Up @@ -148,7 +148,8 @@ map_batches <- function(X, FUN, ..., .data.frame = TRUE) {
lapply(scan_task$Execute(), function(batch) {
# message("Processing Batch")
# This inner lapply cannot be parallelized
# TODO: wrap batch in arrow_dplyr_query with X$selected_columns and X$group_by_vars
# TODO: wrap batch in arrow_dplyr_query with X$selected_columns,
# X$temp_columns, and X$group_by_vars
# if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE
FUN(batch, ...)
})
Expand Down
135 changes: 125 additions & 10 deletions r/R/dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ arrow_dplyr_query <- function(.data) {
# drop_empty_groups is a logical value indicating whether to drop
# groups formed by factor levels that don't appear in the data. It
# should be non-null only when the data is grouped.
drop_empty_groups = NULL
drop_empty_groups = NULL,
# arrange_vars will be a list of expressions named by their associated
# column names
arrange_vars = list(),
# arrange_desc will be a logical vector indicating the sort order for each
# expression in arrange_vars (FALSE for ascending, TRUE for descending)
arrange_desc = logical()
),
class = "arrow_dplyr_query"
)
Expand Down Expand Up @@ -80,6 +86,25 @@ print.arrow_dplyr_query <- function(x, ...) {
if (length(x$group_by_vars)) {
cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "")
}
if (length(x$arrange_vars)) {
if (query_on_dataset(x)) {
arrange_strings <- map_chr(x$arrange_vars, function(x) x$ToString())
} else {
arrange_strings <- map_chr(x$arrange_vars, .format_array_expression)
}
cat(
"* Sorted by ",
paste(
paste0(
arrange_strings,
" [", ifelse(x$arrange_desc, "desc", "asc"), "]"
),
collapse = ", "
),
"\n",
sep = ""
)
}
cat("See $.data for the source Arrow object\n")
invisible(x)
}
Expand Down Expand Up @@ -378,6 +403,7 @@ set_filters <- function(.data, expressions) {

collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
x <- ensure_group_vars(x)
x <- ensure_arrange_vars(x) # this sets x$temp_columns
# Pull only the selected rows and cols into R
if (query_on_dataset(x)) {
# See dataset.R for Dataset and Scanner(Builder) classes
Expand All @@ -391,10 +417,14 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
} else {
filter <- eval_array_expression(x$filtered_rows, x$.data)
}
# TODO: shortcut if identical(names(x$.data), find_array_refs(x$selected_columns))?
tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE]
# TODO: shortcut if identical(names(x$.data), find_array_refs(c(x$selected_columns, x$temp_columns)))?
tab <- x$.data[
filter,
find_array_refs(c(x$selected_columns, x$temp_columns)),
keep_na = FALSE
]
# Now evaluate those expressions on the filtered table
cols <- lapply(x$selected_columns, eval_array_expression, data = tab)
cols <- lapply(c(x$selected_columns, x$temp_columns), eval_array_expression, data = tab)
if (length(cols) == 0) {
tab <- tab[, integer(0)]
} else {
Expand All @@ -405,6 +435,14 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
}
}
}
# Arrange rows
if (length(x$arrange_vars) > 0) {
tab <- tab[
tab$SortIndices(names(x$arrange_vars), x$arrange_desc),
names(x$selected_columns), # this omits x$temp_columns from the result
drop = FALSE
]
}
if (as_data_frame) {
df <- as.data.frame(tab)
tab$invalidate()
Expand Down Expand Up @@ -432,6 +470,20 @@ ensure_group_vars <- function(x) {
x
}

ensure_arrange_vars <- function(x) {
  # arrange() is deferred until the query is evaluated, because:
  # - Sorting has to happen after mutate() so that new columns can be sort keys.
  # - Sorting after filter() and select() is cheaper.
  # Users may still sort by columns or expressions that are *not* part of the
  # query result. Those sort keys are projected *temporarily*: they are stored
  # in x$temp_columns and dropped again once the rows have been arranged. This
  # is unlike x$group_by_vars, whose columns *are* kept in the result.
  sort_keys <- names(x$arrange_vars)
  already_selected <- sort_keys %in% names(x$selected_columns)
  x$temp_columns <- x$arrange_vars[!already_selected]
  x
}

restore_dplyr_features <- function(df, query) {
# An arrow_dplyr_query holds some attributes that Arrow doesn't know about
# After calling collect(), make sure these features are carried over
Expand Down Expand Up @@ -689,17 +741,80 @@ abandon_ship <- function(call, .data, msg = NULL) {
eval.parent(call, 2)
}

arrange.arrow_dplyr_query <- function(.data, ...) {
arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) {
call <- match.call()
exprs <- quos(...)
if (.by_group) {
# when the data is grouped and .by_group is TRUE, order the result by
# the grouping columns first
exprs <- c(quos(!!!dplyr::groups(.data)), exprs)
}
if (length(exprs) == 0) {
# Nothing to do
return(.data)
}
.data <- arrow_dplyr_query(.data)
if (query_on_dataset(.data)) {
not_implemented_for_dataset("arrange()")
# find and remove any dplyr::desc() and tidy-eval
# the arrange expressions inside an Arrow data_mask
sorts <- vector("list", length(exprs))
descs <- logical(0)
mask <- arrow_mask(.data)
for (i in seq_along(exprs)) {
x <- find_and_remove_desc(exprs[[i]])
exprs[[i]] <- x[["quos"]]
sorts[[i]] <- arrow_eval(exprs[[i]], mask)
if (inherits(sorts[[i]], "try-error")) {
msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow')
return(abandon_ship(call, .data, msg))
}
names(sorts)[i] <- as_label(exprs[[i]])
descs[i] <- x[["desc"]]
}
# TODO(ARROW-11703) move this to Arrow
call <- match.call()
abandon_ship(call, .data)
.data$arrange_vars <- c(sorts, .data$arrange_vars)
.data$arrange_desc <- c(descs, .data$arrange_desc)
.data
}
# Dataset and ArrowTabular inputs use the same arrange() implementation as
# arrow_dplyr_query
arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query

# Helper to handle desc() in arrange()
# * Takes a quosure as input
# * Returns a list with two elements:
#   1. `quos`: the quosure with any wrapping parentheses and desc() calls
#      removed from its expression
#   2. `desc`: a logical value, TRUE when the sort order is descending
# * Recognizes both desc() and the namespace-qualified dplyr::desc()
# * Raises an error when the expression left after unwrapping does not refer
#   to any field names
find_and_remove_desc <- function(quosure) {
  original_expr <- quo_get_expr(quosure)
  expr <- original_expr
  descending <- FALSE

  # Callees that mean "sort descending"
  desc_callees <- list(quote(desc), quote(dplyr::desc))
  is_desc <- function(callee) {
    any(vapply(desc_callees, identical, logical(1), callee))
  }

  # Use a while loop to remove any number of nested pairs of enclosing
  # parentheses and any number of nested desc() calls. In the case of multiple
  # nested desc() calls, each one toggles the sort order. Calls without an
  # argument are left alone so that expr[[2]] is never out of bounds.
  while (is.call(expr) && length(expr) >= 2L) {
    callee <- expr[[1]]
    if (identical(callee, quote(`(`))) {
      # remove enclosing parentheses
      expr <- expr[[2]]
    } else if (is_desc(callee)) {
      # remove desc() and toggle descending
      expr <- expr[[2]]
      descending <- !descending
    } else {
      break
    }
  }

  # Validate after unwrapping: the symbols in `dplyr::desc` are reported by
  # all.vars() and must not be mistaken for field names. The message still
  # shows the expression as the user wrote it.
  if (length(all.vars(expr)) < 1L) {
    stop(
      "Expression in arrange() does not contain any field names: ",
      deparse(original_expr),
      call. = FALSE
    )
  }

  list(
    quos = quo_set_expr(quosure, expr),
    desc = descending
  )
}

query_on_dataset <- function(x) inherits(x$.data, "Dataset")

not_implemented_for_dataset <- function(method) {
Expand Down
7 changes: 6 additions & 1 deletion r/R/record-batch.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
#' integers (R vector or Arrow Array) `i`.
#' - `$Filter(i, keep_na = TRUE)`: return a `RecordBatch` with rows at positions where logical
#' vector (or Arrow boolean Array) `i` is `TRUE`.
#' - `$SortIndices(names, descending = FALSE)`: return an `Array` of integer row
#' positions that can be used to rearrange the `RecordBatch` in ascending or
#' descending order by the first named column, breaking ties with further named
#' columns. `descending` can be a logical vector of length one or of the same
#' length as `names`.
#' - `$serialize()`: Returns a raw vector suitable for interprocess communication
#' - `$cast(target_schema, safe = TRUE, options = cast_options(safe))`: Alter
#' the schema of the record batch.
Expand Down Expand Up @@ -99,7 +104,7 @@ RecordBatch <- R6Class("RecordBatch", inherit = ArrowTabular,
RecordBatch__Slice2(self, offset, length)
}
},
# Take and Filter are methods on ArrowTabular
# Take, Filter, and SortIndices are methods on ArrowTabular
serialize = function() ipc___SerializeRecordBatch__Raw(self),
to_data_frame = function() {
RecordBatch__to_dataframe(self, use_threads = option_use_threads())
Expand Down
11 changes: 10 additions & 1 deletion r/R/scalar.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ Scalar <- R6Class("Scalar",
public = list(
ToString = function() Scalar__ToString(self),
as_vector = function() Scalar__as_vector(self),
as_array = function() MakeArrayFromScalar(self)
as_array = function() MakeArrayFromScalar(self),
# Compare this Scalar with `other`. When `other` does not inherit from
# "Scalar" the result is FALSE and Scalar__Equals() is not called.
# `...` is accepted but not used here.
Equals = function(other, ...) {
inherits(other, "Scalar") && Scalar__Equals(self, other)
},
# Compare this Scalar with `other` using Scalar__ApproxEquals(). When `other`
# does not inherit from "Scalar" the result is FALSE and the binding is not
# called. `...` is accepted but not used here.
ApproxEquals = function(other, ...) {
inherits(other, "Scalar") && Scalar__ApproxEquals(self, other)
}
),
active = list(
is_valid = function() Scalar__is_valid(self),
Expand Down Expand Up @@ -68,3 +74,6 @@ length.Scalar <- function(x) 1L

#' @export
is.na.Scalar <- function(x) !x$is_valid

#' @export
sort.Scalar <- function(x, decreasing = FALSE, ...) {
  # A Scalar has length one, so it is already sorted; `decreasing` and `...`
  # have no effect
  x
}
Loading