Skip to content
24 changes: 22 additions & 2 deletions r/R/dplyr-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -839,22 +839,42 @@ agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) {
options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
)
}

agg_funcs$n_distinct <- function(x, na.rm = FALSE) {
list(
fun = "count_distinct",
data = x,
options = list(na.rm = na.rm)
)
}

agg_funcs$n <- function() {
list(
fun = "sum",
data = Expression$scalar(1L),
options = list()
)
}
agg_funcs$min <- function(..., na.rm = FALSE) {
args <- list2(...)
if (length(args) > 1) {
arrow_not_supported("Multiple arguments to min()")
}
list(
fun = "min",
data = args[[1]],
options = list(skip_nulls = na.rm, min_count = 0L)
)
}
agg_funcs$max <- function(..., na.rm = FALSE) {
args <- list2(...)
if (length(args) > 1) {
arrow_not_supported("Multiple arguments to max()")
}
list(
fun = "max",
data = args[[1]],
options = list(skip_nulls = na.rm, min_count = 0L)
)
}

output_type <- function(fun, input_type) {
# These are quick and dirty heuristics.
Expand Down
9 changes: 5 additions & 4 deletions r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,11 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return out;
}

if (func_name == "min_max" || func_name == "sum" || func_name == "mean" ||
func_name == "any" || func_name == "all" || func_name == "hash_min_max" ||
func_name == "hash_sum" || func_name == "hash_mean" || func_name == "hash_any" ||
func_name == "hash_all") {
if (func_name == "all" || func_name == "hash_all" || func_name == "any" ||
func_name == "hash_any" || func_name == "mean" || func_name == "hash_mean" ||
func_name == "min_max" || func_name == "hash_min_max" || func_name == "min" ||
func_name == "hash_min" || func_name == "max" || func_name == "hash_max" ||
func_name == "sum" || func_name == "hash_sum") {
using Options = arrow::compute::ScalarAggregateOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["min_count"])) {
Expand Down
16 changes: 7 additions & 9 deletions r/tests/testthat/test-dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -621,17 +621,15 @@ test_that("Creating UnionDataset", {
})

test_that("map_batches", {
skip("map_batches() is broken (ARROW-14029)")
skip_if_not_available("parquet")
ds <- open_dataset(dataset_dir, partitioning = "part")
expect_warning(
expect_equivalent(
ds %>%
filter(int > 5) %>%
select(int, lgl) %>%
map_batches(~ summarize(., min_int = min(int))),
tibble(min_int = c(6L, 101L))
),
"pulling data into R" # ARROW-13502
expect_equivalent(
ds %>%
filter(int > 5) %>%
select(int, lgl) %>%
map_batches(~ summarize(., min_int = min(int))),
tibble(min_int = c(6L, 101L))
)
})

Expand Down
25 changes: 19 additions & 6 deletions r/tests/testthat/test-dplyr-group-by.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ test_that("group_by groupings are recorded", {
group_by(chr) %>%
select(int, chr) %>%
filter(int > 5) %>%
summarize(min_int = min(int)),
tbl,
warning = TRUE
collect(),
tbl
)
})

Expand Down Expand Up @@ -62,9 +61,23 @@ test_that("ungroup", {
select(int, chr) %>%
ungroup() %>%
filter(int > 5) %>%
summarize(min_int = min(int)),
tbl,
warning = TRUE
collect(),
tbl
)

# to confirm that the above expectation is actually testing what we think it's
# testing, verify that expect_dplyr_equal() distinguishes between grouped and
# ungrouped tibbles
expect_error(
expect_dplyr_equal(
input %>%
group_by(chr) %>%
select(int, chr) %>%
(function(x) if (inherits(x, "tbl_df")) ungroup(x) else x) %>%
filter(int > 5) %>%
collect(),
tbl
)
)
})

Expand Down
147 changes: 125 additions & 22 deletions r/tests/testthat/test-dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,6 @@ tbl$verses <- verses[[1]]
tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2 * (1:10) + 1, side = "both")
tbl$some_grouping <- rep(c(1, 2), 5)

test_that("summarize", {
expect_dplyr_equal(
input %>%
select(int, chr) %>%
filter(int > 5) %>%
summarize(min_int = min(int)) %>%
collect(),
tbl,
warning = TRUE
)

expect_dplyr_equal(
input %>%
select(int, chr) %>%
filter(int > 5) %>%
summarize(min_int = min(int) / 2) %>%
collect(),
tbl,
warning = TRUE
)
})

test_that("summarize() doesn't evaluate eagerly", {
expect_s3_class(
Table$create(tbl) %>%
Expand Down Expand Up @@ -251,6 +229,131 @@ test_that("Group by n_distinct() on dataset", {
)
})

test_that("summarize() with min() and max()", {
expect_dplyr_equal(
input %>%
select(int, chr) %>%
filter(int > 5) %>% # this filters out the NAs in `int`
summarize(min_int = min(int), max_int = max(int)) %>%
collect(),
tbl,
)
expect_dplyr_equal(
input %>%
select(int, chr) %>%
filter(int > 5) %>% # this filters out the NAs in `int`
summarize(
min_int = min(int + 4) / 2,
max_int = 3 / max(42 - int)
) %>%
collect(),
tbl,
)
expect_dplyr_equal(
input %>%
select(int, chr) %>%
summarize(min_int = min(int), max_int = max(int)) %>%
collect(),
tbl,
)
expect_dplyr_equal(
input %>%
select(int) %>%
summarize(
min_int = min(int, na.rm = TRUE),
max_int = max(int, na.rm = TRUE)
) %>%
collect(),
tbl,
)
expect_dplyr_equal(
input %>%
select(dbl, int) %>%
summarize(
min_int = -min(log(ceiling(dbl)), na.rm = TRUE),
max_int = log(max(as.double(int), na.rm = TRUE))
) %>%
collect(),
tbl,
)

# multiple dots arguments to min(), max() not supported
expect_dplyr_equal(
input %>%
summarize(min_mult = min(dbl, int)) %>%
collect(),
tbl,
warning = "Multiple arguments to min\\(\\) not supported by Arrow"
)
expect_dplyr_equal(
input %>%
select(int, dbl, dbl2) %>%
summarize(max_mult = max(int, dbl, dbl2)) %>%
collect(),
tbl,
warning = "Multiple arguments to max\\(\\) not supported by Arrow"
)

# min(logical) or max(logical) yields integer in R
# min(Boolean) or max(Boolean) yields Boolean in Arrow
expect_dplyr_equal(
input %>%
select(lgl) %>%
summarize(
max_lgl = as.logical(max(lgl, na.rm = TRUE)),
min_lgl = as.logical(min(lgl, na.rm = TRUE))
) %>%
collect(),
tbl,
)
})

test_that("min() and max() on character strings", {
expect_dplyr_equal(
input %>%
summarize(
min_chr = min(chr, na.rm = TRUE),
max_chr = max(chr, na.rm = TRUE)
) %>%
collect(),
tbl,
)
skip("Strings not supported by hash_min_max (ARROW-13988)")
expect_dplyr_equal(
input %>%
group_by(fct) %>%
summarize(
min_chr = min(chr, na.rm = TRUE),
max_chr = max(chr, na.rm = TRUE)
) %>%
collect(),
tbl,
)
})

test_that("summarise() with !!sym()", {
test_chr_col <- "int"
test_dbl_col <- "dbl"
test_lgl_col <- "lgl"
expect_dplyr_equal(
input %>%
group_by(false) %>%
summarise(
sum = sum(!!sym(test_dbl_col)),
any = any(!!sym(test_lgl_col)),
all = all(!!sym(test_lgl_col)),
mean = mean(!!sym(test_dbl_col)),
sd = sd(!!sym(test_dbl_col)),
var = var(!!sym(test_dbl_col)),
n_distinct = n_distinct(!!sym(test_chr_col)),
min = min(!!sym(test_dbl_col)),
max = max(!!sym(test_dbl_col))
) %>%
collect(),
tbl
)
})

test_that("Filter and aggregate", {
expect_dplyr_equal(
input %>%
Expand Down