r/NEWS.md (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@

## dplyr methods

- * `dplyr::mutate()` on Arrow `Table` and `RecordBatch` is now supported in Arrow for many applications. Where not yet supported, the implementation falls back to pulling data into an R `data.frame` first.
+ * `dplyr::mutate()` is now supported in Arrow for many applications. For queries on `Table` and `RecordBatch` that are not yet supported in Arrow, the implementation falls back to pulling data into an R `data.frame` first, as in the previous release. For queries on `Dataset`, it raises an error if the feature is not implemented.
* String functions `nchar()`, `tolower()`, and `toupper()`, along with their `stringr` spellings `str_length()`, `str_to_lower()`, and `str_to_upper()`, are supported in Arrow `dplyr` calls. `str_trim()` is also supported.

## Other improvements
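To make the new behavior concrete, a minimal sketch (hypothetical objects; `mean()` stands in for an expression the Arrow engine cannot yet evaluate):

library(arrow)
library(dplyr)

tab <- Table$create(int = 1:5)

# Supported expressions are evaluated by Arrow itself
tab %>% mutate(twice = int * 2) %>% collect()

# On a Table, an unsupported expression warns and falls back to R
tab %>% mutate(avg = mean(int)) %>% collect()

# On a Dataset `ds`, the same query errors rather than silently reading
# every file into memory:
# ds %>% mutate(avg = mean(int))
# Error: Expression mean(int) not supported in Arrow
# Call collect() first to pull data into R.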
r/R/arrowExports.R (8 changes: 6 additions & 2 deletions)

Some generated files are not rendered by default.

r/R/dataset-scan.R (18 changes: 12 additions & 6 deletions)
@@ -157,13 +157,19 @@ ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject,
   public = list(
     Project = function(cols) {
       # cols is either a character vector or a named list of Expressions
-      if (!is.character(cols)) {
-        # We don't yet support mutate() on datasets, so this is just a list
-        # of FieldRefs, and we need to back out the field names
-        cols <- get_field_names(cols)
+      if (is.character(cols)) {
+        dataset___ScannerBuilder__ProjectNames(self, cols)
+      } else {
+        # If we have expressions, but they all turn out to be field_refs,
+        # we can still call the simple method
+        field_names <- get_field_names(cols)
+        if (all(nzchar(field_names))) {
+          dataset___ScannerBuilder__ProjectNames(self, field_names)
+        } else {
+          # Else, we are projecting/mutating
+          dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols))
+        }
       }
-      assert_is(cols, "character")
-      dataset___ScannerBuilder__Project(self, cols)
       self
     },
     Filter = function(expr) {
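For reference, a sketch of what the two branches receive (hypothetical values; `Expression$field_ref()` and arithmetic Ops on `Expression` objects are the wrappers this patch builds on):

# Pure selection/rename: every element is a field reference, so
# get_field_names() recovers a name for each and ProjectNames is called
cols <- list(chr = Expression$field_ref("chr"), int = Expression$field_ref("int"))

# Projection with a computed column: `twice` is a call expression, not a
# field_ref, so get_field_names() yields "" for it (see expression.cpp below)
# and the new ProjectExprs binding is called with the expressions and names
cols <- list(
  int = Expression$field_ref("int"),
  twice = Expression$field_ref("int") * 2
)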
r/R/dplyr.R (15 changes: 11 additions & 4 deletions)
@@ -506,9 +506,6 @@ mutate.arrow_dplyr_query <- function(.data,
   }
 
   .data <- arrow_dplyr_query(.data)
-  if (query_on_dataset(.data)) {
-    not_implemented_for_dataset("mutate()")
-  }
 
   .keep <- match.arg(.keep)
   .before <- enquo(.before)
@@ -529,6 +526,7 @@
   # Deparse and take the first element in case they're long expressions
   names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label)
 
+  is_dataset <- query_on_dataset(.data)
   mask <- arrow_mask(.data)
   results <- list()
   for (i in seq_along(exprs)) {
@@ -539,6 +537,15 @@
     if (inherits(results[[new_var]], "try-error")) {
       msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow')
       return(abandon_ship(call, .data, msg))
+    } else if (is_dataset &&
+               !inherits(results[[new_var]], "Expression") &&
+               !is.null(results[[new_var]])) {
+      # We need some wrapping to handle literal values
+      if (length(results[[new_var]]) != 1) {
+        msg <- paste0('In ', new_var, " = ", as_label(exprs[[i]]), ", only values of size one are recycled")
+        return(abandon_ship(call, .data, msg))
+      }
+      results[[new_var]] <- Expression$scalar(results[[new_var]])
     }
     # Put it in the data mask too
     mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]]
@@ -583,7 +590,7 @@ abandon_ship <- function(call, .data, msg = NULL) {
     # Default message: function not implemented
     not_implemented_for_dataset(paste0(dplyr_fun_name, "()"))
   } else {
-    stop(msg, call. = FALSE)
+    stop(msg, "\nCall collect() first to pull data into R.", call. = FALSE)
   }
 }

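The user-visible effect of the literal handling above, sketched against a hypothetical dataset `ds` (mirrored by the tests below):

# A length-1 literal is not an Expression, so it is wrapped in
# Expression$scalar() and recycled to every row at scan time
ds %>% mutate(the_answer = 42) %>% collect()

# Longer vectors cannot be recycled by the scanner, so abandon_ship()
# stops with the new, more actionable message:
# ds %>% mutate(bad = c(1, 2))
# Error: In bad = c(1, 2), only values of size one are recycled
# Call collect() first to pull data into R.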
r/src/arrowExports.cpp (20 changes: 16 additions & 4 deletions)

Some generated files are not rendered by default.

r/src/dataset.cpp (17 changes: 15 additions & 2 deletions)
@@ -308,11 +308,24 @@ std::shared_ptr<ds::PartitioningFactory> dataset___HivePartitioning__MakeFactory
 // ScannerBuilder, Scanner
 
 // [[arrow::export]]
-void dataset___ScannerBuilder__Project(const std::shared_ptr<ds::ScannerBuilder>& sb,
-                                       const std::vector<std::string>& cols) {
+void dataset___ScannerBuilder__ProjectNames(const std::shared_ptr<ds::ScannerBuilder>& sb,
+                                            const std::vector<std::string>& cols) {
   StopIfNotOk(sb->Project(cols));
 }
 
+// [[arrow::export]]
+void dataset___ScannerBuilder__ProjectExprs(
+    const std::shared_ptr<ds::ScannerBuilder>& sb,
+    const std::vector<std::shared_ptr<ds::Expression>>& exprs,
+    const std::vector<std::string>& names) {
+  // We have shared_ptrs of expressions but need the Expressions
+  std::vector<ds::Expression> expressions;
+  for (auto expr : exprs) {
+    expressions.push_back(*expr);
+  }
+  StopIfNotOk(sb->Project(expressions, names));
+}
+
 // [[arrow::export]]
 void dataset___ScannerBuilder__Filter(const std::shared_ptr<ds::ScannerBuilder>& sb,
                                       const std::shared_ptr<ds::Expression>& expr) {
r/src/expression.cpp (7 changes: 5 additions & 2 deletions)
@@ -50,8 +50,11 @@ std::shared_ptr<ds::Expression> dataset___expr__field_ref(std::string name) {
 // [[arrow::export]]
 std::string dataset___expr__get_field_ref_name(
     const std::shared_ptr<ds::Expression>& ref) {
-  auto refname = ref->field_ref()->name();
-  return *refname;
+  auto field_ref = ref->field_ref();
+  if (field_ref == nullptr) {
+    return "";
+  }
+  return *field_ref->name();
 }
 
 // [[arrow::export]]
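The empty string returned for non-field_ref expressions is what `Project()` in dataset-scan.R keys on. A sketch using the internal bindings (not public API; hypothetical usage):

ref <- Expression$field_ref("int")
dataset___expr__get_field_ref_name(ref)      # "int"

# A call expression has no field_ref; this used to dereference a null
# pointer, and now returns "" so the caller falls back to ProjectExprs
dataset___expr__get_field_ref_name(ref * 2)  # ""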
r/tests/testthat/test-dataset.R (93 changes: 91 additions & 2 deletions)
@@ -717,6 +717,96 @@ test_that("filter() with expressions", {
   )
 })

test_that("mutate()", {
ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
mutated <- ds %>%
select(chr, dbl, int) %>%
filter(dbl * 2 > 14 & dbl - 50 < 3L) %>%
mutate(twice = int * 2)
expect_output(
print(mutated),
"FileSystemDataset (query)
chr: string
dbl: double
int: int32
twice: expr

* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
See $.data for the source Arrow object",
fixed = TRUE
)
expect_equivalent(
mutated %>%
collect() %>%
arrange(dbl),
rbind(
df1[8:10, c("chr", "dbl", "int")],
df2[1:2, c("chr", "dbl", "int")]
) %>%
mutate(
twice = int * 2
)
)
})

test_that("transmute()", {
ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
mutated <-
expect_equivalent(
ds %>%
select(chr, dbl, int) %>%
filter(dbl * 2 > 14 & dbl - 50 < 3L) %>%
transmute(twice = int * 2) %>%
collect() %>%
arrange(twice),
rbind(
df1[8:10, "int", drop = FALSE],
df2[1:2, "int", drop = FALSE]
) %>%
transmute(
twice = int * 2
)
)
})

test_that("mutate() features not yet implemented", {
expect_error(
ds %>%
group_by(int) %>%
mutate(avg = mean(int)),
"mutate() on grouped data not supported in Arrow\nCall collect() first to pull data into R.",
fixed = TRUE
)
})


test_that("mutate() with scalar (length 1) literal inputs", {
expect_equal(
ds %>%
mutate(the_answer = 42) %>%
collect() %>%
pull(the_answer),
rep(42, nrow(ds))
)

expect_error(
ds %>% mutate(the_answer = c(42, 42)),
"In the_answer = c(42, 42), only values of size one are recycled\nCall collect() first to pull data into R.",
fixed = TRUE
)
})

test_that("mutate() with NULL inputs", {
expect_equal(
ds %>%
mutate(int = NULL) %>%
collect(),
ds %>%
select(-int) %>%
collect()
)
})

test_that("filter scalar validation doesn't crash (ARROW-7772)", {
expect_error(
ds %>%
@@ -832,7 +922,6 @@ test_that("dplyr method not implemented messages", {
     expect_error(x, "is not currently implemented for Arrow Datasets")
   }
   expect_not_implemented(ds %>% arrange(int))
-  expect_not_implemented(ds %>% mutate(int = int + 2))
   expect_not_implemented(ds %>% filter(int == 1) %>% summarize(n()))
 })

@@ -1137,7 +1226,7 @@ test_that("Dataset writing: no partitioning", {
test_that("Dataset writing: partition on null", {
skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651
ds <- open_dataset(hive_dir)

dst_dir <- tempfile()
partitioning = hive_partition(lgl = boolean())
write_dataset(ds, dst_dir, partitioning = partitioning)