Skip to content
24 changes: 20 additions & 4 deletions r/R/dplyr-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,22 @@ nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
)
}

nse_funcs$pmin <- function(..., na.rm = FALSE) {
build_expr(
"min_element_wise",
...,
options = list(skip_nulls = na.rm)
)
}

nse_funcs$pmax <- function(..., na.rm = FALSE) {
build_expr(
"max_element_wise",
...,
options = list(skip_nulls = na.rm)
)
}

# String function helpers

# format `pattern` as needed for case insensitivity and literal matching by RE2
Expand Down Expand Up @@ -511,24 +527,24 @@ nse_funcs$second <- function(x) {
}

# After ARROW-13054 is completed, we can refactor this for simplicity
#
#
# Arrow's `day_of_week` kernel counts from 0 (Monday) to 6 (Sunday), whereas
# `lubridate::wday` counts from 1 to 7, and allows users to specify which day
# of the week is first (Sunday by default). This Expression converts the returned
# day of the week back to the value that would be returned by lubridate by
# providing offset values based on the specified week_start day, and adding 1
# so the returned value is 1-indexed instead of 0-indexed.
nse_funcs$wday <- function(x, label = FALSE, abbr = TRUE, week_start = getOption("lubridate.week.start", 7)) {

# The "day_of_week" compute function returns numeric days of week and not locale-aware strftime
# When the ticket below is resolved, we should be able to support the label argument
# https://issues.apache.org/jira/browse/ARROW-13133
if (label) {
arrow_not_supported("Label argument")
}

# overall formula to convert from arrow::wday to lubridate::wday is:
# ((wday(day) - start + 8) %% 7) + 1
((Expression$create("day_of_week", x) - Expression$scalar(week_start) + 8) %% 7) + 1

}
9 changes: 9 additions & 0 deletions r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return out;
}

if (func_name == "min_element_wise" || func_name == "max_element_wise") {
using Options = arrow::compute::ElementWiseAggregateOptions;
bool skip_nulls = true;
if (!Rf_isNull(options["skip_nulls"])) {
skip_nulls = cpp11::as_cpp<bool>(options["skip_nulls"]);
}
return std::make_shared<Options>(skip_nulls);
}

if (func_name == "quantile") {
using Options = arrow::compute::QuantileOptions;
auto out = std::make_shared<Options>(Options::Defaults());
Expand Down
31 changes: 31 additions & 0 deletions r/tests/testthat/test-dplyr-mutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,34 @@ test_that("mutate and write_dataset", {
summarize(mean = mean(integer))
)
})

test_that("mutate and pmin/pmax", {
df <- tibble(
city = c("Chillan", "Valdivia", "Osorno"),
val1 = c(200, 300, NA),
val2 = c(100, NA, NA),
val3 = c(0, NA, NA)
)

expect_dplyr_equal(
input %>%
mutate(
max_val_1 = pmax(val1, val2, val3),
max_val_2 = pmax(val1, val2, val3, na.rm = T),
min_val_1 = pmin(val1, val2, val3),
min_val_2 = pmin(val1, val2, val3, na.rm = T)
) %>%
collect(),
df
)

expect_dplyr_equal(
input %>%
mutate(
max_val_1 = pmax(val1 - 100, 200, val1 * 100, na.rm = T),
min_val_1 = pmin(val1 - 100, 100, val1 * 100, na.rm = T),
) %>%
collect(),
df
)
})