diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index e452460ffd4..fece77d9cc9 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -41,7 +41,7 @@ "semi_join", "anti_join", "count", "tally" ) ) - for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { + for (cl in c("Dataset", "ArrowTabular", "RecordBatchReader", "arrow_dplyr_query")) { for (m in dplyr_methods) { s3_register(m, cl) } diff --git a/r/R/dplyr-arrange.R b/r/R/dplyr-arrange.R index 4c8c687a3cb..247a539f527 100644 --- a/r/R/dplyr-arrange.R +++ b/r/R/dplyr-arrange.R @@ -51,7 +51,7 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { .data$arrange_desc <- c(descs, .data$arrange_desc) .data } -arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query +arrange.Dataset <- arrange.ArrowTabular <- arrange.RecordBatchReader <- arrange.arrow_dplyr_query # Helper to handle desc() in arrange() # * Takes a quosure as input diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index c62f2559310..732c3325f75 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -43,11 +43,11 @@ collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) { x } } -collect.Dataset <- function(x, ...) dplyr::collect(as_adq(x), ...) +collect.Dataset <- collect.RecordBatchReader <- function(x, ...) dplyr::collect(as_adq(x), ...) compute.arrow_dplyr_query <- function(x, ...) dplyr::collect(x, as_data_frame = FALSE) compute.ArrowTabular <- function(x, ...) x -compute.Dataset <- compute.arrow_dplyr_query +compute.Dataset <- compute.RecordBatchReader <- compute.arrow_dplyr_query pull.arrow_dplyr_query <- function(.data, var = -1) { .data <- as_adq(.data) @@ -55,7 +55,7 @@ pull.arrow_dplyr_query <- function(.data, var = -1) { .data$selected_columns <- set_names(.data$selected_columns[var], var) dplyr::collect(.data)[[1]] } -pull.Dataset <- pull.ArrowTabular <- pull.arrow_dplyr_query +pull.Dataset <- pull.ArrowTabular <- pull.RecordBatchReader <- pull.arrow_dplyr_query restore_dplyr_features <- function(df, query) { # An arrow_dplyr_query holds some attributes that Arrow doesn't know about @@ -85,7 +85,7 @@ collapse.arrow_dplyr_query <- function(x, ...) { # Nest inside a new arrow_dplyr_query (and keep groups) restore_dplyr_features(arrow_dplyr_query(x), x) } -collapse.Dataset <- collapse.ArrowTabular <- function(x, ...) { +collapse.Dataset <- collapse.ArrowTabular <- collapse.RecordBatchReader <- function(x, ...) { arrow_dplyr_query(x) } diff --git a/r/R/dplyr-count.R b/r/R/dplyr-count.R index c567c285f54..c194d400fba 100644 --- a/r/R/dplyr-count.R +++ b/r/R/dplyr-count.R @@ -34,7 +34,7 @@ count.arrow_dplyr_query <- function(x, ..., wt = NULL, sort = FALSE, name = NULL out } -count.Dataset <- count.ArrowTabular <- count.arrow_dplyr_query +count.Dataset <- count.ArrowTabular <- count.RecordBatchReader <- count.arrow_dplyr_query #' @importFrom rlang sym := tally.arrow_dplyr_query <- function(x, wt = NULL, sort = FALSE, name = NULL) { @@ -54,7 +54,7 @@ tally.arrow_dplyr_query <- function(x, wt = NULL, sort = FALSE, name = NULL) { } } -tally.Dataset <- tally.ArrowTabular <- tally.arrow_dplyr_query +tally.Dataset <- tally.ArrowTabular <- tally.RecordBatchReader <- tally.arrow_dplyr_query # we don't want to depend on dplyr, but we refrence these above utils::globalVariables(c("n", "desc")) diff --git a/r/R/dplyr-distinct.R b/r/R/dplyr-distinct.R index 5dfcb641f87..d5a8c81e6b0 100644 --- a/r/R/dplyr-distinct.R +++ b/r/R/dplyr-distinct.R @@ -43,4 +43,4 @@ distinct.arrow_dplyr_query <- function(.data, ..., .keep_all = FALSE) { out } -distinct.Dataset <- distinct.ArrowTabular <- distinct.arrow_dplyr_query +distinct.Dataset <- distinct.ArrowTabular <- distinct.RecordBatchReader <- distinct.arrow_dplyr_query diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R index 3c8c08ea5c0..7db68b43e93 100644 --- a/r/R/dplyr-filter.R +++ b/r/R/dplyr-filter.R @@ -67,7 +67,7 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { set_filters(.data, filters) } -filter.Dataset <- filter.ArrowTabular <- filter.arrow_dplyr_query +filter.Dataset <- filter.ArrowTabular <- filter.RecordBatchReader <- filter.arrow_dplyr_query set_filters <- function(.data, expressions) { if (length(expressions)) { diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R index 66b867210a8..1014c20f4da 100644 --- a/r/R/dplyr-group-by.R +++ b/r/R/dplyr-group-by.R @@ -55,10 +55,10 @@ group_by.arrow_dplyr_query <- function(.data, .data$drop_empty_groups <- ifelse(length(gv), .drop, dplyr::group_by_drop_default(.data)) .data } -group_by.Dataset <- group_by.ArrowTabular <- group_by.arrow_dplyr_query +group_by.Dataset <- group_by.ArrowTabular <- group_by.RecordBatchReader <- group_by.arrow_dplyr_query groups.arrow_dplyr_query <- function(x) syms(dplyr::group_vars(x)) -groups.Dataset <- groups.ArrowTabular <- function(x) NULL +groups.Dataset <- groups.ArrowTabular <- groups.RecordBatchReader <- function(x) NULL group_vars.arrow_dplyr_query <- function(x) x$group_by_vars group_vars.Dataset <- function(x) NULL @@ -71,7 +71,7 @@ group_vars.ArrowTabular <- function(x) { # the .drop argument to group_by() group_by_drop_default.arrow_dplyr_query <- function(.tbl) .tbl$drop_empty_groups %||% TRUE -group_by_drop_default.Dataset <- group_by_drop_default.ArrowTabular <- +group_by_drop_default.Dataset <- group_by_drop_default.ArrowTabular <- group_by_drop_default.RecordBatchReader <- function(.tbl) TRUE ungroup.arrow_dplyr_query <- function(x, ...) { @@ -79,7 +79,7 @@ ungroup.arrow_dplyr_query <- function(x, ...) { x$drop_empty_groups <- NULL x } -ungroup.Dataset <- force +ungroup.Dataset <- ungroup.RecordBatchReader <- force ungroup.ArrowTabular <- function(x) { x$r_metadata$attributes$.group_vars <- NULL x diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R index c14b1a8f3dd..42048137abd 100644 --- a/r/R/dplyr-join.R +++ b/r/R/dplyr-join.R @@ -52,7 +52,7 @@ left_join.arrow_dplyr_query <- function(x, keep = FALSE) { do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_OUTER") } -left_join.Dataset <- left_join.ArrowTabular <- left_join.arrow_dplyr_query +left_join.Dataset <- left_join.ArrowTabular <- left_join.RecordBatchReader <- left_join.arrow_dplyr_query right_join.arrow_dplyr_query <- function(x, y, @@ -63,7 +63,7 @@ right_join.arrow_dplyr_query <- function(x, keep = FALSE) { do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "RIGHT_OUTER") } -right_join.Dataset <- right_join.ArrowTabular <- right_join.arrow_dplyr_query +right_join.Dataset <- right_join.ArrowTabular <- right_join.RecordBatchReader <- right_join.arrow_dplyr_query inner_join.arrow_dplyr_query <- function(x, y, @@ -74,7 +74,7 @@ inner_join.arrow_dplyr_query <- function(x, keep = FALSE) { do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "INNER") } -inner_join.Dataset <- inner_join.ArrowTabular <- inner_join.arrow_dplyr_query +inner_join.Dataset <- inner_join.ArrowTabular <- inner_join.RecordBatchReader <- inner_join.arrow_dplyr_query full_join.arrow_dplyr_query <- function(x, y, @@ -85,7 +85,7 @@ full_join.arrow_dplyr_query <- function(x, keep = FALSE) { do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "FULL_OUTER") } -full_join.Dataset <- full_join.ArrowTabular <- full_join.arrow_dplyr_query +full_join.Dataset <- full_join.ArrowTabular <- full_join.RecordBatchReader <- full_join.arrow_dplyr_query semi_join.arrow_dplyr_query <- function(x, y, @@ -96,7 +96,7 @@ semi_join.arrow_dplyr_query <- function(x, keep = FALSE) { do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_SEMI") } -semi_join.Dataset <- semi_join.ArrowTabular <- semi_join.arrow_dplyr_query +semi_join.Dataset <- semi_join.ArrowTabular <- semi_join.RecordBatchReader <- semi_join.arrow_dplyr_query anti_join.arrow_dplyr_query <- function(x, y, @@ -107,7 +107,7 @@ anti_join.arrow_dplyr_query <- function(x, keep = FALSE) { do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_ANTI") } -anti_join.Dataset <- anti_join.ArrowTabular <- anti_join.arrow_dplyr_query +anti_join.Dataset <- anti_join.ArrowTabular <- anti_join.RecordBatchReader <- anti_join.arrow_dplyr_query handle_join_by <- function(by, x, y) { if (is.null(by)) { diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R index 2e5239484f6..986f29cc1d5 100644 --- a/r/R/dplyr-mutate.R +++ b/r/R/dplyr-mutate.R @@ -108,13 +108,13 @@ mutate.arrow_dplyr_query <- function(.data, # Even if "none", we still keep group vars ensure_group_vars(.data) } -mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query +mutate.Dataset <- mutate.ArrowTabular <- mutate.RecordBatchReader <- mutate.arrow_dplyr_query transmute.arrow_dplyr_query <- function(.data, ...) { dots <- check_transmute_args(...) dplyr::mutate(.data, !!!dots, .keep = "none") } -transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query +transmute.Dataset <- transmute.ArrowTabular <- transmute.RecordBatchReader <- transmute.arrow_dplyr_query # This function is a copy of dplyr:::check_transmute_args at # https://github.com/tidyverse/dplyr/blob/master/R/mutate.R diff --git a/r/R/dplyr-select.R b/r/R/dplyr-select.R index 9a867ced964..c9624d090ac 100644 --- a/r/R/dplyr-select.R +++ b/r/R/dplyr-select.R @@ -24,13 +24,13 @@ select.arrow_dplyr_query <- function(.data, ...) { check_select_helpers(enexprs(...)) column_select(as_adq(.data), !!!enquos(...)) } -select.Dataset <- select.ArrowTabular <- select.arrow_dplyr_query +select.Dataset <- select.ArrowTabular <- select.RecordBatchReader <- select.arrow_dplyr_query rename.arrow_dplyr_query <- function(.data, ...) { check_select_helpers(enexprs(...)) column_select(as_adq(.data), !!!enquos(...), .FUN = vars_rename) } -rename.Dataset <- rename.ArrowTabular <- rename.arrow_dplyr_query +rename.Dataset <- rename.ArrowTabular <- rename.RecordBatchReader <- rename.arrow_dplyr_query column_select <- function(.data, ..., .FUN = vars_select) { # .FUN is either tidyselect::vars_select or tidyselect::vars_rename @@ -106,7 +106,7 @@ relocate.arrow_dplyr_query <- function(.data, ..., .before = NULL, .after = NULL } .data } -relocate.Dataset <- relocate.ArrowTabular <- relocate.arrow_dplyr_query +relocate.Dataset <- relocate.ArrowTabular <- relocate.RecordBatchReader <- relocate.arrow_dplyr_query check_select_helpers <- function(exprs) { # Throw an error if unsupported tidyselect selection helpers in `exprs` diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 7cb9a3483d5..d8e6c46d921 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -187,7 +187,7 @@ summarise.arrow_dplyr_query <- function(.data, ...) { return(out) } } -summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query +summarise.Dataset <- summarise.ArrowTabular <- summarise.RecordBatchReader <- summarise.arrow_dplyr_query # This is the Arrow summarize implementation do_arrow_summarize <- function(.data, ..., .groups = NULL) { diff --git a/r/R/duckdb.R b/r/R/duckdb.R index 960892fbf88..d5845ab5c13 100644 --- a/r/R/duckdb.R +++ b/r/R/duckdb.R @@ -104,12 +104,11 @@ unique_arrow_tablename <- function() { # Creates an environment that disconnects the database when it's GC'd duckdb_disconnector <- function(con, tbl_name) { + force(tbl_name) reg.finalizer(environment(), function(...) { # remote the table we ephemerally created (though only if the connection is # still valid) - if (DBI::dbIsValid(con)) { - duckdb::duckdb_unregister_arrow(con, tbl_name) - } + duckdb::duckdb_unregister_arrow(con, tbl_name) }) environment() } @@ -120,8 +119,11 @@ duckdb_disconnector <- function(con, tbl_name) { #' other processes (like DuckDB). #' #' @param .data the object to be converted +#' @param as_arrow_query should the returned object be wrapped as an +#' `arrow_dplyr_query`? (logical, default: `TRUE`) #' -#' @return an `arrow_dplyr_query` object, to be used in dplyr pipelines. +#' @return a `RecordBatchReader` object, wrapped as an arrow dplyr query which +#' can be used in dplyr pipelines. #' @export #' #' @examplesIf getFromNamespace("run_duckdb_examples", "arrow")() @@ -136,7 +138,7 @@ duckdb_disconnector <- function(con, tbl_name) { #' summarize(mean_mpg = mean(mpg, na.rm = TRUE)) %>% #' to_arrow() %>% #' collect() -to_arrow <- function(.data) { +to_arrow <- function(.data, as_arrow_query = TRUE) { # If this is an Arrow object already, return quickly since we're already Arrow if (inherits(.data, c("arrow_dplyr_query", "ArrowObject"))) { return(.data) @@ -155,6 +157,9 @@ to_arrow <- function(.data) { # Run the query res <- DBI::dbSendQuery(dbplyr::remote_con(.data), dbplyr::remote_query(.data), arrow = TRUE) - # TODO: we shouldn't need $read_table(), but we get segfaults when we do. - arrow_dplyr_query(duckdb::duckdb_fetch_record_batch(res)$read_table()) + if (as_arrow_query) { + arrow_dplyr_query(duckdb::duckdb_fetch_record_batch(res)) + } else { + duckdb::duckdb_fetch_record_batch(res) + } } diff --git a/r/man/to_arrow.Rd b/r/man/to_arrow.Rd index e0c31b8dc35..4d300521011 100644 --- a/r/man/to_arrow.Rd +++ b/r/man/to_arrow.Rd @@ -4,13 +4,17 @@ \alias{to_arrow} \title{Create an Arrow object from others} \usage{ -to_arrow(.data) +to_arrow(.data, as_arrow_query = TRUE) } \arguments{ \item{.data}{the object to be converted} + +\item{as_arrow_query}{should the returned object be wrapped as an +\code{arrow_dplyr_query}? (logical, default: \code{TRUE})} } \value{ -an \code{arrow_dplyr_query} object, to be used in dplyr pipelines. +a \code{RecordBatchReader} object, wrapped as an arrow dplyr query which +can be used in dplyr pipelines. } \description{ This can be used in pipelines that pass data back and forth between Arrow and diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index d2c5615bc59..4d2b1b7acac 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -841,7 +841,7 @@ test_that("Collecting zero columns from a dataset doesn't return entire dataset" test_that("dataset RecordBatchReader to C-interface to arrow_dplyr_query", { - ds <- open_dataset(ipc_dir, partitioning = "part", format = "feather") + ds <- open_dataset(hive_dir) # export the RecordBatchReader via the C-interface stream_ptr <- allocate_arrow_array_stream() @@ -849,25 +849,55 @@ test_that("dataset RecordBatchReader to C-interface to arrow_dplyr_query", { reader <- scan$ToRecordBatchReader() reader$export_to_c(stream_ptr) + expect_equal( + RecordBatchStreamReader$import_from_c(stream_ptr) %>% + filter(int < 8 | int > 55) %>% + mutate(part_plus = group + 6) %>% + arrange(dbl) %>% + collect(), + ds %>% + filter(int < 8 | int > 55) %>% + mutate(part_plus = group + 6) %>% + arrange(dbl) %>% + collect() + ) + + # must clean up the pointer or we leak + delete_arrow_array_stream(stream_ptr) +}) + +test_that("dataset to C-interface to arrow_dplyr_query with proj/filter", { + ds <- open_dataset(hive_dir) + + # filter the dataset + ds <- ds %>% + filter(int > 2) + + # export the RecordBatchReader via the C-interface + stream_ptr <- allocate_arrow_array_stream() + scan <- Scanner$create( + ds, + projection = names(ds), + filter = Expression$create("less", Expression$field_ref("int"), Expression$scalar(8L))) + reader <- scan$ToRecordBatchReader() + reader$export_to_c(stream_ptr) + # then import it and check that the roundtripped value is the same circle <- RecordBatchStreamReader$import_from_c(stream_ptr) # create an arrow_dplyr_query() from the recordbatch reader reader_adq <- arrow_dplyr_query(circle) - # TODO: ARROW-14321 should be able to arrange then collect - tab_from_c_new <- reader_adq %>% - filter(int < 8, int > 55) %>% - mutate(part_plus = part + 6) %>% - collect() expect_equal( - tab_from_c_new %>% - arrange(dbl), + reader_adq %>% + mutate(part_plus = group + 6) %>% + arrange(dbl) %>% + collect(), ds %>% - filter(int < 8, int > 55) %>% - mutate(part_plus = part + 6) %>% - collect() %>% - arrange(dbl) + filter(int < 8, int > 2) %>% + mutate(part_plus = group + 6) %>% + arrange(dbl) %>% + collect() ) # must clean up the pointer or we leak diff --git a/r/tests/testthat/test-duckdb.R b/r/tests/testthat/test-duckdb.R index b31fffda95a..48803575dc5 100644 --- a/r/tests/testthat/test-duckdb.R +++ b/r/tests/testthat/test-duckdb.R @@ -66,7 +66,6 @@ test_that("to_duckdb", { }) test_that("to_duckdb then to_arrow", { - skip("Flaky, unskip when ARROW-14745 is merged") ds <- InMemoryDataset$create(example_data) ds_rt <- ds %>% @@ -113,6 +112,64 @@ test_that("to_duckdb then to_arrow", { ) }) +test_that("to_arrow roundtrip, with dataset", { + # these will continue to error until 0.3.2 is released + # https://github.com/duckdb/duckdb/pull/2957 + skip_if_not_installed("duckdb", minimum_version = "0.3.2") + # With a multi-part dataset + tf <- tempfile() + new_ds <- rbind( + cbind(example_data, part = 1), + cbind(example_data, part = 2), + cbind(mutate(example_data, dbl = dbl * 3, dbl2 = dbl2 * 3), part = 3), + cbind(mutate(example_data, dbl = dbl * 4, dbl2 = dbl2 * 4), part = 4) + ) + write_dataset(new_ds, tf, partitioning = "part") + + ds <- open_dataset(tf) + + expect_identical( + ds %>% + to_duckdb() %>% + select(-fct) %>% + mutate(dbl_plus = dbl + 1) %>% + to_arrow() %>% + filter(int > 5 & part > 1) %>% + collect() %>% + arrange(part, int) %>% + as.data.frame(), + ds %>% + select(-fct) %>% + filter(int > 5 & part > 1) %>% + mutate(dbl_plus = dbl + 1) %>% + collect() %>% + arrange(part, int) + ) +}) + +test_that("to_arrow roundtrip, with dataset (without wrapping", { + # these will continue to error until 0.3.2 is released + # https://github.com/duckdb/duckdb/pull/2957 + skip_if_not_installed("duckdb", minimum_version = "0.3.2") + # With a multi-part dataset + tf <- tempfile() + new_ds <- rbind( + cbind(example_data, part = 1), + cbind(example_data, part = 2), + cbind(mutate(example_data, dbl = dbl * 3, dbl2 = dbl2 * 3), part = 3), + cbind(mutate(example_data, dbl = dbl * 4, dbl2 = dbl2 * 4), part = 4) + ) + write_dataset(new_ds, tf, partitioning = "part") + + out <- ds %>% + to_duckdb() %>% + select(-fct) %>% + mutate(dbl_plus = dbl + 1) %>% + to_arrow(as_arrow_query = FALSE) + + expect_r6_class(out, "RecordBatchReader") +}) + # The next set of tests use an already-extant connection to test features of # persistence and querying against the table without using the `tbl` itself, so # we need to create a connection separate from the ephemeral one that is made