diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 2e3b777c82f..5385877696e 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -109,6 +109,7 @@ Collate: 'dplyr-mutate.R' 'dplyr-select.R' 'dplyr-summarize.R' + 'dplyr-union.R' 'record-batch.R' 'table.R' 'dplyr.R' diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 03a9f8a1161..7b59854f1e1 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -41,7 +41,7 @@ "group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute", "arrange", "rename", "pull", "relocate", "compute", "collapse", "distinct", "left_join", "right_join", "inner_join", "full_join", - "semi_join", "anti_join", "count", "tally", "rename_with" + "semi_join", "anti_join", "count", "tally", "rename_with", "union", "union_all" ) ) for (cl in c("Dataset", "ArrowTabular", "RecordBatchReader", "arrow_dplyr_query")) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 7c7b1f3cea2..3414c9b21c5 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -440,6 +440,10 @@ ExecNode_Join <- function(input, type, right_data, left_keys, right_keys, left_o .Call(`_arrow_ExecNode_Join`, input, type, right_data, left_keys, right_keys, left_output, right_output, output_suffix_for_left, output_suffix_for_right) } +ExecNode_Union <- function(input, right_data) { + .Call(`_arrow_ExecNode_Union`, input, right_data) +} + ExecNode_SourceNode <- function(plan, reader) { .Call(`_arrow_ExecNode_SourceNode`, plan, reader) } @@ -2003,4 +2007,3 @@ SetIOThreadPoolCapacity <- function(threads) { Array__infer_type <- function(x) { .Call(`_arrow_Array__infer_type`, x) } - diff --git a/r/R/dplyr-union.R b/r/R/dplyr-union.R new file mode 100644 index 00000000000..3252d4cecf0 --- /dev/null +++ b/r/R/dplyr-union.R @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# The following S3 methods are registered on load if dplyr is present + +union.arrow_dplyr_query <- function(x, y, ...) { + x <- as_adq(x) + y <- as_adq(y) + + distinct(union_all(x, y)) +} + +union.Dataset <- union.ArrowTabular <- union.RecordBatchReader <- union.arrow_dplyr_query + +union_all.arrow_dplyr_query <- function(x, y, ...) { + x <- as_adq(x) + y <- as_adq(y) + + x$union_all <- list(right_data = y) + collapse.arrow_dplyr_query(x) +} + +union_all.Dataset <- union_all.ArrowTabular <- union_all.RecordBatchReader <- union_all.arrow_dplyr_query diff --git a/r/R/query-engine.R b/r/R/query-engine.R index e4cc4197b34..4dfb5d3e9bb 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -138,6 +138,10 @@ ExecPlan <- R6Class("ExecPlan", right_suffix = .data$join$suffix[[2]] ) } + + if (!is.null(.data$union_all)) { + node <- node$UnionAll(self$Build(.data$union_all$right_data)) + } } # Apply sorting: this is currently not an ExecNode itself, it is a @@ -271,6 +275,9 @@ ExecNode <- R6Class("ExecNode", output_suffix_for_right = right_suffix ) ) + }, + UnionAll = function(right_node) { + self$preserve_sort(ExecNode_Union(self, right_node)) } ), active = list( diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 312b778aada..1cfc2ddf2cc 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -991,6 +991,15 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp +std::shared_ptr ExecNode_Union(const std::shared_ptr& input, const std::shared_ptr& right_data); +extern "C" SEXP _arrow_ExecNode_Union(SEXP input_sexp, SEXP right_data_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type input(input_sexp); + arrow::r::Input&>::type right_data(right_data_sexp); + return cpp11::as_sexp(ExecNode_Union(input, right_data)); +END_CPP11 +} +// compute-exec.cpp std::shared_ptr ExecNode_SourceNode(const std::shared_ptr& plan, const std::shared_ptr& reader); extern "C" SEXP _arrow_ExecNode_SourceNode(SEXP plan_sexp, SEXP reader_sexp){ BEGIN_CPP11 @@ -5212,6 +5221,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ExecNode_Project", (DL_FUNC) &_arrow_ExecNode_Project, 3}, { "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 5}, { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 9}, + { "_arrow_ExecNode_Union", (DL_FUNC) &_arrow_ExecNode_Union, 2}, { "_arrow_ExecNode_SourceNode", (DL_FUNC) &_arrow_ExecNode_SourceNode, 2}, { "_arrow_ExecNode_TableSourceNode", (DL_FUNC) &_arrow_ExecNode_TableSourceNode, 2}, { "_arrow_substrait__internal__SubstraitToJSON", (DL_FUNC) &_arrow_substrait__internal__SubstraitToJSON, 1}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index a8ae3b03b06..d18c47260b9 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -309,6 +309,13 @@ std::shared_ptr ExecNode_Join( std::move(output_suffix_for_left), std::move(output_suffix_for_right)}); } +// [[arrow::export]] +std::shared_ptr ExecNode_Union( + const std::shared_ptr& input, + const std::shared_ptr& right_data) { + return MakeExecNodeOrStop("union", input->plan(), {input.get(), right_data.get()}, {}); +} + // [[arrow::export]] std::shared_ptr ExecNode_SourceNode( const std::shared_ptr& plan, diff --git a/r/tests/testthat/test-dplyr-union.R b/r/tests/testthat/test-dplyr-union.R new file mode 100644 index 00000000000..5cc6f8eea57 --- /dev/null +++ b/r/tests/testthat/test-dplyr-union.R @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +skip_if(on_old_windows()) + +library(dplyr, warn.conflicts = FALSE) + +withr::local_options(list(arrow.summarise.sort = FALSE)) + +test_that("union_all", { + compare_dplyr_binding( + .input %>% + union_all(example_data) %>% + collect(), + example_data + ) + + test_table <- arrow_table(x = 1:10) + + # Union with empty table produces same dataset + expect_equal( + test_table %>% + union_all(test_table$Slice(0, 0)) %>% + compute(), + test_table + ) + + expect_error( + test_table %>% + union_all(arrow_table(y = 1:10)) %>% + collect(), + regex = "input schemas must all match" + ) +}) + +test_that("union", { + compare_dplyr_binding( + .input %>% + dplyr::union(example_data) %>% + collect(), + example_data + ) + + test_table <- arrow_table(x = 1:10) + + # Union with empty table produces same dataset + expect_equal( + test_table %>% + dplyr::union(test_table$Slice(0, 0)) %>% + compute(), + test_table + ) + + expect_error( + test_table %>% + dplyr::union(arrow_table(y = 1:10)) %>% + collect(), + regex = "input schemas must all match" + ) +})