From cc9cf829a14d19e7a2f9a376e8647e94e056c8e5 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 6 May 2022 16:11:39 -0700 Subject: [PATCH 1/7] Add union node --- r/R/arrow-package.R | 2 +- r/R/arrowExports.R | 5 +++- r/R/dplyr-union.R | 35 ++++++++++++++++++++++++++ r/R/query-engine.R | 7 ++++++ r/src/arrowExports.cpp | 10 ++++++++ r/src/compute-exec.cpp | 7 ++++++ r/tests/testthat/test-dplyr-union.R | 38 +++++++++++++++++++++++++++++ 7 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 r/R/dplyr-union.R create mode 100644 r/tests/testthat/test-dplyr-union.R diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 03a9f8a1161..7b59854f1e1 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -41,7 +41,7 @@ "group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute", "arrange", "rename", "pull", "relocate", "compute", "collapse", "distinct", "left_join", "right_join", "inner_join", "full_join", - "semi_join", "anti_join", "count", "tally", "rename_with" + "semi_join", "anti_join", "count", "tally", "rename_with", "union", "union_all" ) ) for (cl in c("Dataset", "ArrowTabular", "RecordBatchReader", "arrow_dplyr_query")) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 7c7b1f3cea2..3414c9b21c5 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -440,6 +440,10 @@ ExecNode_Join <- function(input, type, right_data, left_keys, right_keys, left_o .Call(`_arrow_ExecNode_Join`, input, type, right_data, left_keys, right_keys, left_output, right_output, output_suffix_for_left, output_suffix_for_right) } +ExecNode_Union <- function(input, right_data) { + .Call(`_arrow_ExecNode_Union`, input, right_data) +} + ExecNode_SourceNode <- function(plan, reader) { .Call(`_arrow_ExecNode_SourceNode`, plan, reader) } @@ -2003,4 +2007,3 @@ SetIOThreadPoolCapacity <- function(threads) { Array__infer_type <- function(x) { .Call(`_arrow_Array__infer_type`, x) } - diff --git a/r/R/dplyr-union.R b/r/R/dplyr-union.R new file mode 100644 index 00000000000..95335518b9f --- /dev/null +++ b/r/R/dplyr-union.R @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +union.arrow_dplyr_query <- function(x, y, ...) { + x <- as_adq(x) + y <- as_adq(y) + + distinct(union_all(x, y)) +} + +union.Dataset <- union.ArrowTabular <- union.RecordBatchReader <- union.arrow_dplyr_query + +union_all.arrow_dplyr_query <- function(x, y, ...) { + x <- as_adq(x) + y <- as_adq(y) + + x$union_all <- list(right_data = y) + collapse.arrow_dplyr_query(x) +} + +union_all.Dataset <- union_all.ArrowTabular <- union_all.RecordBatchReader <- union_all.arrow_dplyr_query \ No newline at end of file diff --git a/r/R/query-engine.R b/r/R/query-engine.R index e4cc4197b34..4dfb5d3e9bb 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -138,6 +138,10 @@ ExecPlan <- R6Class("ExecPlan", right_suffix = .data$join$suffix[[2]] ) } + + if (!is.null(.data$union_all)) { + node <- node$UnionAll(self$Build(.data$union_all$right_data)) + } } # Apply sorting: this is currently not an ExecNode itself, it is a @@ -271,6 +275,9 @@ ExecNode <- R6Class("ExecNode", output_suffix_for_right = right_suffix ) ) + }, + UnionAll = function(right_node) { + self$preserve_sort(ExecNode_Union(self, right_node)) } ), active = list( diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 312b778aada..1cfc2ddf2cc 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -991,6 +991,15 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp +std::shared_ptr ExecNode_Union(const std::shared_ptr& input, const std::shared_ptr& right_data); +extern "C" SEXP _arrow_ExecNode_Union(SEXP input_sexp, SEXP right_data_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type input(input_sexp); + arrow::r::Input&>::type right_data(right_data_sexp); + return cpp11::as_sexp(ExecNode_Union(input, right_data)); +END_CPP11 +} +// compute-exec.cpp std::shared_ptr ExecNode_SourceNode(const std::shared_ptr& plan, const std::shared_ptr& reader); extern "C" SEXP _arrow_ExecNode_SourceNode(SEXP plan_sexp, SEXP reader_sexp){ BEGIN_CPP11 @@ -5212,6 +5221,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ExecNode_Project", (DL_FUNC) &_arrow_ExecNode_Project, 3}, { "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 5}, { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 9}, + { "_arrow_ExecNode_Union", (DL_FUNC) &_arrow_ExecNode_Union, 2}, { "_arrow_ExecNode_SourceNode", (DL_FUNC) &_arrow_ExecNode_SourceNode, 2}, { "_arrow_ExecNode_TableSourceNode", (DL_FUNC) &_arrow_ExecNode_TableSourceNode, 2}, { "_arrow_substrait__internal__SubstraitToJSON", (DL_FUNC) &_arrow_substrait__internal__SubstraitToJSON, 1}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index a8ae3b03b06..b94e346480c 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -309,6 +309,13 @@ std::shared_ptr ExecNode_Join( std::move(output_suffix_for_left), std::move(output_suffix_for_right)}); } +// [[arrow::export]] +std::shared_ptr ExecNode_Union( + const std::shared_ptr& input, + const std::shared_ptr& right_data) { + return MakeExecNodeOrStop("union", input->plan(), {input.get(), right_data.get()}, {}); +} + // [[arrow::export]] std::shared_ptr ExecNode_SourceNode( const std::shared_ptr& plan, diff --git a/r/tests/testthat/test-dplyr-union.R b/r/tests/testthat/test-dplyr-union.R new file mode 100644 index 00000000000..189ca9ed52c --- /dev/null +++ b/r/tests/testthat/test-dplyr-union.R @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +skip_if(on_old_windows()) + +library(dplyr, warn.conflicts = FALSE) + +test_that("union_all", { + compare_dplyr_binding( + .input %>% + union_all(example_data) %>% + collect(), + example_data + ) +}) + +test_that("union", { + compare_dplyr_binding( + .input %>% + dplyr::union(example_data) %>% + collect(), + example_data + ) +}) From 68982f2e17ab4f03d22fd52b00769a828ae02a09 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 9 May 2022 08:04:28 -0700 Subject: [PATCH 2/7] Format Cpp --- r/src/compute-exec.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index b94e346480c..d18c47260b9 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -311,8 +311,8 @@ std::shared_ptr ExecNode_Join( // [[arrow::export]] std::shared_ptr ExecNode_Union( - const std::shared_ptr& input, - const std::shared_ptr& right_data) { + const std::shared_ptr& input, + const std::shared_ptr& right_data) { return MakeExecNodeOrStop("union", input->plan(), {input.get(), right_data.get()}, {}); } From aca0f673a8a1c43e189bb5e6c895d99a8d339510 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 9 May 2022 09:56:48 -0700 Subject: [PATCH 3/7] Add to collate --- r/DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 2e3b777c82f..5385877696e 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -109,6 +109,7 @@ Collate: 'dplyr-mutate.R' 'dplyr-select.R' 'dplyr-summarize.R' + 'dplyr-union.R' 'record-batch.R' 'table.R' 'dplyr.R' From 69d857c1fc0538a8f48a8437301066289d316eb8 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 9 May 2022 10:58:28 -0700 Subject: [PATCH 4/7] Add more tests --- r/R/dplyr-union.R | 4 +++- r/tests/testthat/test-dplyr-union.R | 34 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/r/R/dplyr-union.R b/r/R/dplyr-union.R index 95335518b9f..3252d4cecf0 100644 --- a/r/R/dplyr-union.R +++ b/r/R/dplyr-union.R @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +# The following S3 methods are registered on load if dplyr is present + union.arrow_dplyr_query <- function(x, y, ...) { x <- as_adq(x) y <- as_adq(y) @@ -32,4 +34,4 @@ union_all.arrow_dplyr_query <- function(x, y, ...) { collapse.arrow_dplyr_query(x) } -union_all.Dataset <- union_all.ArrowTabular <- union_all.RecordBatchReader <- union_all.arrow_dplyr_query \ No newline at end of file +union_all.Dataset <- union_all.ArrowTabular <- union_all.RecordBatchReader <- union_all.arrow_dplyr_query diff --git a/r/tests/testthat/test-dplyr-union.R b/r/tests/testthat/test-dplyr-union.R index 189ca9ed52c..ee70ea91ff4 100644 --- a/r/tests/testthat/test-dplyr-union.R +++ b/r/tests/testthat/test-dplyr-union.R @@ -26,6 +26,23 @@ test_that("union_all", { collect(), example_data ) + + test_table <- arrow_table(x = 1:10) + + # Union with empty table produces same dataset + expect_equal( + test_table |> + union_all(test_table$Slice(0, 0)) |> + collect(test_table, as_data_frame = FALSE), + test_table + ) + + expect_error( + test_table |> + union_all(arrow_table(y = 1:10)) |> + collect(), + regex = "input schemas must all match" + ) }) test_that("union", { @@ -35,4 +52,21 @@ test_that("union", { collect(), example_data ) + + test_table <- arrow_table(x = 1:10) + + # Union with empty table produces same dataset + expect_equal( + test_table |> + dplyr::union(test_table$Slice(0, 0)) |> + collect(test_table, as_data_frame = FALSE), + test_table + ) + + expect_error( + test_table |> + dplyr::union(arrow_table(y = 1:10)) |> + collect(), + regex = "input schemas must all match" + ) }) From 5128fed20261e493db37fde968e5b4b97c7218f6 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 9 May 2022 12:38:40 -0700 Subject: [PATCH 5/7] Use dplyr pipe --- r/tests/testthat/test-dplyr-union.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/r/tests/testthat/test-dplyr-union.R b/r/tests/testthat/test-dplyr-union.R index ee70ea91ff4..5875877d233 100644 --- a/r/tests/testthat/test-dplyr-union.R +++ b/r/tests/testthat/test-dplyr-union.R @@ -31,15 +31,15 @@ test_that("union_all", { # Union with empty table produces same dataset expect_equal( - test_table |> - union_all(test_table$Slice(0, 0)) |> + test_table %>% + union_all(test_table$Slice(0, 0)) %>% collect(test_table, as_data_frame = FALSE), test_table ) expect_error( - test_table |> - union_all(arrow_table(y = 1:10)) |> + test_table %>% + union_all(arrow_table(y = 1:10)) %>% collect(), regex = "input schemas must all match" ) @@ -57,15 +57,15 @@ test_that("union", { # Union with empty table produces same dataset expect_equal( - test_table |> - dplyr::union(test_table$Slice(0, 0)) |> + test_table %>% + dplyr::union(test_table$Slice(0, 0)) %>% collect(test_table, as_data_frame = FALSE), test_table ) expect_error( - test_table |> - dplyr::union(arrow_table(y = 1:10)) |> + test_table %>% + dplyr::union(arrow_table(y = 1:10)) %>% collect(), regex = "input schemas must all match" ) From 1640771944fb4dc685048246d87caa31806c2712 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 9 May 2022 13:34:18 -0700 Subject: [PATCH 6/7] Make sure to set sort to FALSE --- r/tests/testthat/test-dplyr-union.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/r/tests/testthat/test-dplyr-union.R b/r/tests/testthat/test-dplyr-union.R index 5875877d233..f7c8a830697 100644 --- a/r/tests/testthat/test-dplyr-union.R +++ b/r/tests/testthat/test-dplyr-union.R @@ -19,6 +19,8 @@ skip_if(on_old_windows()) library(dplyr, warn.conflicts = FALSE) +withr::local_options(list(arrow.summarise.sort = FALSE)) + test_that("union_all", { compare_dplyr_binding( .input %>% From 009db683a2254bbed4b9623d55b0b218a12261dc Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 23 May 2022 12:48:49 -0700 Subject: [PATCH 7/7] fix: use compute() instead of collect() Co-authored-by: Neal Richardson --- r/tests/testthat/test-dplyr-union.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/tests/testthat/test-dplyr-union.R b/r/tests/testthat/test-dplyr-union.R index f7c8a830697..5cc6f8eea57 100644 --- a/r/tests/testthat/test-dplyr-union.R +++ b/r/tests/testthat/test-dplyr-union.R @@ -35,7 +35,7 @@ test_that("union_all", { expect_equal( test_table %>% union_all(test_table$Slice(0, 0)) %>% - collect(test_table, as_data_frame = FALSE), + compute(), test_table ) @@ -61,7 +61,7 @@ test_that("union", { expect_equal( test_table %>% dplyr::union(test_table$Slice(0, 0)) %>% - collect(test_table, as_data_frame = FALSE), + compute(), test_table )