diff --git a/r/NAMESPACE b/r/NAMESPACE index 59055ff2b77..f1f4bd80570 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -30,6 +30,7 @@ S3method(as.character,FileFormat) S3method(as.character,FragmentScanOptions) S3method(as.data.frame,ArrowTabular) S3method(as.data.frame,RecordBatchReader) +S3method(as.data.frame,Schema) S3method(as.data.frame,StructArray) S3method(as.data.frame,arrow_dplyr_query) S3method(as.double,ArrowDatum) @@ -48,6 +49,7 @@ S3method(as_arrow_array,pyarrow.lib.Array) S3method(as_arrow_array,vctrs_list_of) S3method(as_arrow_table,RecordBatch) S3method(as_arrow_table,RecordBatchReader) +S3method(as_arrow_table,Schema) S3method(as_arrow_table,Table) S3method(as_arrow_table,arrow_dplyr_query) S3method(as_arrow_table,data.frame) @@ -478,6 +480,7 @@ importFrom(stats,runif) importFrom(tidyselect,all_of) importFrom(tidyselect,contains) importFrom(tidyselect,ends_with) +importFrom(tidyselect,eval_rename) importFrom(tidyselect,eval_select) importFrom(tidyselect,everything) importFrom(tidyselect,last_col) @@ -486,8 +489,6 @@ importFrom(tidyselect,num_range) importFrom(tidyselect,one_of) importFrom(tidyselect,starts_with) importFrom(tidyselect,vars_pull) -importFrom(tidyselect,vars_rename) -importFrom(tidyselect,vars_select) importFrom(utils,capture.output) importFrom(utils,getFromNamespace) importFrom(utils,head) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 477fa67e7c6..da005efac21 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -27,7 +27,7 @@ #' @importFrom rlang is_list call2 is_empty as_function as_label arg_match is_symbol is_call call_args #' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call f_rhs parse_expr f_env new_quosure #' @importFrom rlang new_quosures expr_text caller_env check_dots_empty dots_list -#' @importFrom tidyselect vars_pull vars_rename vars_select eval_select +#' @importFrom tidyselect vars_pull eval_select eval_rename #' @importFrom glue glue #' @useDynLib arrow, .registration = TRUE #' @keywords internal diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index b73bef71023..dc49278236c 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -2024,6 +2024,10 @@ Table__from_record_batches <- function(batches, schema_sxp) { .Call(`_arrow_Table__from_record_batches`, batches, schema_sxp) } +Table__from_schema <- function(schema_sxp) { + .Call(`_arrow_Table__from_schema`, schema_sxp) +} + Table__ReferencedBufferSize <- function(table) { .Call(`_arrow_Table__ReferencedBufferSize`, table) } @@ -2051,3 +2055,4 @@ SetIOThreadPoolCapacity <- function(threads) { Array__infer_type <- function(x) { .Call(`_arrow_Array__infer_type`, x) } + diff --git a/r/R/csv.R b/r/R/csv.R index 7b474c137eb..b814a50f2f8 100644 --- a/r/R/csv.R +++ b/r/R/csv.R @@ -102,7 +102,7 @@ #' an Arrow [Schema], or `NULL` (the default) to infer types from the data. #' @param col_select A character vector of column names to keep, as in the #' "select" argument to `data.table::fread()`, or a -#' [tidy selection specification][tidyselect::vars_select()] +#' [tidy selection specification][tidyselect::eval_select()] #' of columns, as used in `dplyr::select()`. #' @param na A character vector of strings to interpret as missing values. #' @param quoted_na Should missing values inside quotes be treated as missing @@ -226,7 +226,8 @@ read_delim_arrow <- function(file, # TODO: move this into convert_options using include_columns col_select <- enquo(col_select) if (!quo_is_null(col_select)) { - tab <- tab[vars_select(names(tab), !!col_select)] + sim_df <- as.data.frame(tab$schema) + tab <- tab[eval_select(col_select, sim_df)] } if (isTRUE(as_data_frame)) { diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index d399e37e101..bce7010c17f 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -190,7 +190,7 @@ #' #' ## dplyr #' -#' * [`across()`][dplyr::across()]: Use of `where()` selection helper not yet supported +#' * [`across()`][dplyr::across()] #' * [`between()`][dplyr::between()] #' * [`case_when()`][dplyr::case_when()] #' * [`coalesce()`][dplyr::coalesce()] diff --git a/r/R/dplyr-select.R b/r/R/dplyr-select.R index 6e7dc7a1aa2..3a9d82f9752 100644 --- a/r/R/dplyr-select.R +++ b/r/R/dplyr-select.R @@ -21,14 +21,12 @@ tbl_vars.arrow_dplyr_query <- function(x) names(x$selected_columns) select.arrow_dplyr_query <- function(.data, ...) { - check_select_helpers(enexprs(...)) - column_select(as_adq(.data), !!!enquos(...)) + column_select(.data, enquos(...), op = "select") } select.Dataset <- select.ArrowTabular <- select.RecordBatchReader <- select.arrow_dplyr_query rename.arrow_dplyr_query <- function(.data, ...) { - check_select_helpers(enexprs(...)) - column_select(as_adq(.data), !!!enquos(...), .FUN = vars_rename) + column_select(.data, enquos(...), op = "rename") } rename.Dataset <- rename.ArrowTabular <- rename.RecordBatchReader <- rename.arrow_dplyr_query @@ -39,29 +37,6 @@ rename_with.arrow_dplyr_query <- function(.data, .fn, .cols = everything(), ...) } rename_with.Dataset <- rename_with.ArrowTabular <- rename_with.RecordBatchReader <- rename_with.arrow_dplyr_query -column_select <- function(.data, ..., .FUN = vars_select) { - # .FUN is either tidyselect::vars_select or tidyselect::vars_rename - # It operates on the names() of selected_columns, i.e. the column names - # factoring in any renaming that may already have happened - out <- .FUN(names(.data), !!!enquos(...)) - # Make sure that the resulting selected columns map back to the original data, - # as in when there are multiple renaming steps - .data$selected_columns <- set_names(.data$selected_columns[out], names(out)) - - # If we've renamed columns, we need to project that renaming into other - # query parameters we've collected - renamed <- out[names(out) != out] - if (length(renamed)) { - # Massage group_by - gbv <- .data$group_by_vars - renamed_groups <- gbv %in% renamed - gbv[renamed_groups] <- names(renamed)[match(gbv[renamed_groups], renamed)] - .data$group_by_vars <- gbv - # No need to massage filters because those contain references to Arrow objects - } - .data -} - relocate.arrow_dplyr_query <- function(.data, ..., .before = NULL, .after = NULL) { # The code in this function is adapted from the code in dplyr::relocate.data.frame # at https://github.com/tidyverse/dplyr/blob/master/R/relocate.R @@ -115,18 +90,39 @@ relocate.arrow_dplyr_query <- function(.data, ..., .before = NULL, .after = NULL } relocate.Dataset <- relocate.ArrowTabular <- relocate.RecordBatchReader <- relocate.arrow_dplyr_query -check_select_helpers <- function(exprs) { - # Throw an error if unsupported tidyselect selection helpers in `exprs` - exprs <- lapply(exprs, function(x) if (is_quosure(x)) quo_get_expr(x) else x) - unsup_select_helpers <- "where" - funs_in_exprs <- unlist(lapply(exprs, all_funs)) - unsup_funs <- funs_in_exprs[funs_in_exprs %in% unsup_select_helpers] - if (length(unsup_funs)) { - stop( - "Unsupported selection ", - ngettext(length(unsup_funs), "helper: ", "helpers: "), - oxford_paste(paste0(unsup_funs, "()"), quote = FALSE), - call. = FALSE - ) +column_select <- function(.data, select_expression, op = c("select", "rename")) { + op <- match.arg(op) + + .data <- as_adq(.data) + sim_df <- as.data.frame(implicit_schema(.data)) + old_names <- names(sim_df) + + if (op == "select") { + out <- eval_select(expr(c(!!!select_expression)), sim_df) + # select only columns from `out` + subset <- out + } else if (op == "rename") { + out <- eval_rename(expr(c(!!!select_expression)), sim_df) + # select all columns as only renaming + subset <- set_names(seq_along(old_names), old_names) + names(subset)[out] <- names(out) + } + + .data$selected_columns <- set_names(.data$selected_columns[subset], names(subset)) + + # check if names have updated + new_names <- old_names + new_names[out] <- names(out) + names_compared <- set_names(old_names, new_names) + renamed <- names_compared[old_names != new_names] + + # Update names in group_by if changed in select() or rename() + if (length(renamed)) { + gbv <- .data$group_by_vars + renamed_groups <- gbv %in% renamed + gbv[renamed_groups] <- names(renamed)[match(gbv[renamed_groups], renamed)] + .data$group_by_vars <- gbv } + + .data } diff --git a/r/R/feather.R b/r/R/feather.R index 4e2e9947cb9..7791b9e8aa8 100644 --- a/r/R/feather.R +++ b/r/R/feather.R @@ -178,8 +178,11 @@ read_feather <- function(file, col_select = NULL, as_data_frame = TRUE, mmap = T reader <- FeatherReader$create(file) col_select <- enquo(col_select) + columns <- if (!quo_is_null(col_select)) { - vars_select(names(reader), !!col_select) + sim_df <- as.data.frame(reader$schema) + indices <- eval_select(col_select, sim_df) + names(reader)[indices] } out <- tryCatch( diff --git a/r/R/json.R b/r/R/json.R index c4061f066b1..1d03045947b 100644 --- a/r/R/json.R +++ b/r/R/json.R @@ -68,7 +68,8 @@ read_json_arrow <- function(file, col_select <- enquo(col_select) if (!quo_is_null(col_select)) { - tab <- tab[vars_select(names(tab), !!col_select)] + sim_df <- as.data.frame(tab$schema) + tab <- tab[eval_select(col_select, sim_df)] } if (isTRUE(as_data_frame)) { diff --git a/r/R/parquet.R b/r/R/parquet.R index 0b3f93b20e1..ac3ca616741 100644 --- a/r/R/parquet.R +++ b/r/R/parquet.R @@ -55,9 +55,8 @@ read_parquet <- function(file, col_select <- enquo(col_select) if (!quo_is_null(col_select)) { # infer which columns to keep from schema - schema <- reader$GetSchema() - names <- names(schema) - indices <- match(vars_select(names, !!col_select), names) - 1L + sim_df <- as.data.frame(reader$GetSchema()) + indices <- eval_select(col_select, sim_df) - 1L tab <- tryCatch( reader$ReadTable(indices), error = read_compressed_error diff --git a/r/R/schema.R b/r/R/schema.R index 86a968b5003..c7e26652c90 100644 --- a/r/R/schema.R +++ b/r/R/schema.R @@ -383,3 +383,8 @@ as_schema.Schema <- function(x, ...) { as_schema.StructType <- function(x, ...) { schema(!!!x$fields()) } + +#' @export +as.data.frame.Schema <- function(x, row.names = NULL, optional = FALSE, ...) { + as.data.frame(Table__from_schema(x)) +} diff --git a/r/R/table.R b/r/R/table.R index c5291257792..2007a3887bc 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -134,6 +134,11 @@ Table$create <- function(..., schema = NULL) { if (is.null(names(dots))) { names(dots) <- rep_len("", length(dots)) } + + if (length(dots) == 0 && inherits(schema, "Schema")) { + return(Table__from_schema(schema)) + } + stopifnot(length(dots) > 0) if (all_record_batches(dots)) { @@ -330,3 +335,9 @@ as_arrow_table.RecordBatchReader <- function(x, ...) { as_arrow_table.arrow_dplyr_query <- function(x, ...) { as_arrow_table(as_record_batch_reader(x)) } + +#' @rdname as_arrow_table +#' @export +as_arrow_table.Schema <- function(x, ...) { + Table__from_schema(x) +} diff --git a/r/data-raw/docgen.R b/r/data-raw/docgen.R index 8db3bb7e804..33a1fe49668 100644 --- a/r/data-raw/docgen.R +++ b/r/data-raw/docgen.R @@ -127,10 +127,7 @@ docs <- arrow:::.cache$docs # Add some functions # across() is handled by manipulating the quosures, not by nse_funcs -docs[["dplyr::across"]] <- c( - # TODO(ARROW-17384): implement where - "Use of `where()` selection helper not yet supported" -) +docs[["dplyr::across"]] <- character(0) # if_any() and if_all() are used instead of across() in filter() # they are both handled by manipulating the quosures, not by nse_funcs diff --git a/r/man/acero.Rd b/r/man/acero.Rd index 5cbe211d00d..472e79e82e8 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -180,7 +180,7 @@ as \code{arrow_ascii_is_decimal}. \subsection{dplyr}{ \itemize{ -\item \code{\link[dplyr:across]{across()}}: Use of \code{where()} selection helper not yet supported +\item \code{\link[dplyr:across]{across()}} \item \code{\link[dplyr:between]{between()}} \item \code{\link[dplyr:case_when]{case_when()}} \item \code{\link[dplyr:coalesce]{coalesce()}} diff --git a/r/man/as_arrow_table.Rd b/r/man/as_arrow_table.Rd index aac4495e7c6..22d4ea1c191 100644 --- a/r/man/as_arrow_table.Rd +++ b/r/man/as_arrow_table.Rd @@ -8,6 +8,7 @@ \alias{as_arrow_table.data.frame} \alias{as_arrow_table.RecordBatchReader} \alias{as_arrow_table.arrow_dplyr_query} +\alias{as_arrow_table.Schema} \title{Convert an object to an Arrow Table} \usage{ as_arrow_table(x, ..., schema = NULL) @@ -23,6 +24,8 @@ as_arrow_table(x, ..., schema = NULL) \method{as_arrow_table}{RecordBatchReader}(x, ...) \method{as_arrow_table}{arrow_dplyr_query}(x, ...) + +\method{as_arrow_table}{Schema}(x, ...) } \arguments{ \item{x}{An object to convert to an Arrow Table} diff --git a/r/man/read_delim_arrow.Rd b/r/man/read_delim_arrow.Rd index 5b91fc0ec9c..740aff3711b 100644 --- a/r/man/read_delim_arrow.Rd +++ b/r/man/read_delim_arrow.Rd @@ -101,7 +101,7 @@ an Arrow \link{Schema}, or \code{NULL} (the default) to infer types from the dat \item{col_select}{A character vector of column names to keep, as in the "select" argument to \code{data.table::fread()}, or a -\link[tidyselect:vars_select]{tidy selection specification} +\link[tidyselect:eval_select]{tidy selection specification} of columns, as used in \code{dplyr::select()}.} \item{na}{A character vector of strings to interpret as missing values.} diff --git a/r/man/read_feather.Rd b/r/man/read_feather.Rd index 218a163b990..000aa541aac 100644 --- a/r/man/read_feather.Rd +++ b/r/man/read_feather.Rd @@ -18,7 +18,7 @@ open.} \item{col_select}{A character vector of column names to keep, as in the "select" argument to \code{data.table::fread()}, or a -\link[tidyselect:vars_select]{tidy selection specification} +\link[tidyselect:eval_select]{tidy selection specification} of columns, as used in \code{dplyr::select()}.} \item{as_data_frame}{Should the function return a \code{data.frame} (default) or diff --git a/r/man/read_json_arrow.Rd b/r/man/read_json_arrow.Rd index cc821c33014..926b1454c13 100644 --- a/r/man/read_json_arrow.Rd +++ b/r/man/read_json_arrow.Rd @@ -22,7 +22,7 @@ open.} \item{col_select}{A character vector of column names to keep, as in the "select" argument to \code{data.table::fread()}, or a -\link[tidyselect:vars_select]{tidy selection specification} +\link[tidyselect:eval_select]{tidy selection specification} of columns, as used in \code{dplyr::select()}.} \item{as_data_frame}{Should the function return a \code{data.frame} (default) or diff --git a/r/man/read_parquet.Rd b/r/man/read_parquet.Rd index d509f8068e7..68e56903d14 100644 --- a/r/man/read_parquet.Rd +++ b/r/man/read_parquet.Rd @@ -21,7 +21,7 @@ open.} \item{col_select}{A character vector of column names to keep, as in the "select" argument to \code{data.table::fread()}, or a -\link[tidyselect:vars_select]{tidy selection specification} +\link[tidyselect:eval_select]{tidy selection specification} of columns, as used in \code{dplyr::select()}.} \item{as_data_frame}{Should the function return a \code{data.frame} (default) or diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index aa4fd01af49..336b12cb482 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -5101,6 +5101,14 @@ BEGIN_CPP11 END_CPP11 } // table.cpp +std::shared_ptr Table__from_schema(SEXP schema_sxp); +extern "C" SEXP _arrow_Table__from_schema(SEXP schema_sxp_sexp){ +BEGIN_CPP11 + arrow::r::Input::type schema_sxp(schema_sxp_sexp); + return cpp11::as_sexp(Table__from_schema(schema_sxp)); +END_CPP11 +} +// table.cpp r_vec_size Table__ReferencedBufferSize(const std::shared_ptr& table); extern "C" SEXP _arrow_Table__ReferencedBufferSize(SEXP table_sexp){ BEGIN_CPP11 @@ -5724,6 +5732,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Table__SelectColumns", (DL_FUNC) &_arrow_Table__SelectColumns, 2}, { "_arrow_all_record_batches", (DL_FUNC) &_arrow_all_record_batches, 1}, { "_arrow_Table__from_record_batches", (DL_FUNC) &_arrow_Table__from_record_batches, 2}, + { "_arrow_Table__from_schema", (DL_FUNC) &_arrow_Table__from_schema, 1}, { "_arrow_Table__ReferencedBufferSize", (DL_FUNC) &_arrow_Table__ReferencedBufferSize, 1}, { "_arrow_Table__ConcatenateTables", (DL_FUNC) &_arrow_Table__ConcatenateTables, 2}, { "_arrow_GetCpuThreadPoolCapacity", (DL_FUNC) &_arrow_GetCpuThreadPoolCapacity, 0}, diff --git a/r/src/table.cpp b/r/src/table.cpp index f31aac33eff..062f85b719c 100644 --- a/r/src/table.cpp +++ b/r/src/table.cpp @@ -18,6 +18,7 @@ #include "./arrow_types.h" #include +#include #include #include #include @@ -302,6 +303,37 @@ std::shared_ptr Table__from_record_batches( return tab; } +// [[arrow::export]] +std::shared_ptr Table__from_schema(SEXP schema_sxp) { + auto schema = cpp11::as_cpp>(schema_sxp); + + int num_fields = schema->num_fields(); + + std::vector> columns; + + for (int i = 0; i < num_fields; i++) { + bool is_extension_type = schema->field(i)->type()->name() == "extension"; + std::shared_ptr type; + + // need to handle extension types a bit differently + if (is_extension_type) { + // TODO: ARROW-18043 - update this to properly construct extension types instead of + // converting to null + type = arrow::null(); + } else { + type = schema->field(i)->type(); + } + + std::shared_ptr array; + std::unique_ptr type_builder; + StopIfNotOk(arrow::MakeBuilder(gc_memory_pool(), type, &type_builder)); + StopIfNotOk(type_builder->Finish(&array)); + columns.push_back(array); + } + + return (arrow::Table::Make(schema, std::move(columns))); +} + // [[arrow::export]] r_vec_size Table__ReferencedBufferSize(const std::shared_ptr& table) { return r_vec_size(ValueOrStop(arrow::util::ReferencedBufferSize(*table))); diff --git a/r/tests/testthat/test-Table.R b/r/tests/testthat/test-Table.R index d2818943823..409df85f06d 100644 --- a/r/tests/testthat/test-Table.R +++ b/r/tests/testthat/test-Table.R @@ -693,3 +693,11 @@ test_that("num_rows method not susceptible to integer overflow", { expect_identical(big_string_array$data()$buffers[[3]]$size, 2148007936) }) + +test_that("can create empty table from schema", { + schema <- schema(col1 = float64(), col2 = string()) + out <- Table$create(schema = schema) + expect_r6_class(out, "Table") + expect_equal(nrow(out), 0) + expect_equal(out$schema, schema) +}) diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index 9bb6aa9600d..0c93f530faa 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -265,14 +265,10 @@ test_that("Can use across() within group_by()", { tbl ) - # ARROW-12778 - `where()` is not yet supported - expect_error( - compare_dplyr_binding( - .input %>% - group_by(across(where(is.numeric))) %>% - collect(), - tbl - ), - "Unsupported selection helper" + compare_dplyr_binding( + .input %>% + group_by(across(where(is.numeric))) %>% + collect(), + tbl ) }) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 86c243e5490..ee13c8be2e3 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -616,15 +616,11 @@ test_that("Can use across() within mutate()", { ) ) - # ARROW-12778 - `where()` is not yet supported - expect_error( - compare_dplyr_binding( - .input %>% - mutate(across(where(is.double))) %>% - collect(), - example_data - ), - "Unsupported selection helper" + compare_dplyr_binding( + .input %>% + mutate(across(where(is.double))) %>% + collect(), + example_data ) # gives the right error with window functions @@ -642,7 +638,6 @@ test_that("Can use across() within mutate()", { }) test_that("Can use across() within transmute()", { - compare_dplyr_binding( .input %>% transmute( @@ -654,5 +649,4 @@ test_that("Can use across() within transmute()", { collect(), example_data ) - }) diff --git a/r/tests/testthat/test-dplyr-select.R b/r/tests/testthat/test-dplyr-select.R index 98dcd6396d9..f71c4000442 100644 --- a/r/tests/testthat/test-dplyr-select.R +++ b/r/tests/testthat/test-dplyr-select.R @@ -87,15 +87,14 @@ test_that("select/rename/rename_with using selection helpers", { collect(), tbl ) - expect_error( - compare_dplyr_binding( - .input %>% - select(where(is.numeric)) %>% - collect(), - tbl - ), - "Unsupported selection helper" + + compare_dplyr_binding( + .input %>% + select(where(is.numeric)) %>% + collect(), + tbl ) + compare_dplyr_binding( .input %>% rename_with(toupper) %>% @@ -187,3 +186,32 @@ test_that("relocate with selection helpers", { df ) }) + +test_that("multiple select/rename and group_by", { + compare_dplyr_binding( + .input %>% + group_by(chr) %>% + rename(string = chr, dub = dbl2) %>% + rename(chr_actually = string) %>% + collect(), + tbl + ) + + compare_dplyr_binding( + .input %>% + group_by(chr) %>% + select(string = chr, dub = dbl2) %>% + rename(chr_actually = string) %>% + collect(), + tbl + ) + + compare_dplyr_binding( + .input %>% + group_by(chr) %>% + rename(string = chr, dub = dbl2) %>% + select(chr_actually = string) %>% + collect(), + tbl + ) +})