From 6b4cf6425debf96f43de79fa7e0b54e1be5a9062 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 14 Sep 2021 13:19:15 -0400 Subject: [PATCH 1/7] Construct a Join node --- r/DESCRIPTION | 1 + r/R/arrow-package.R | 2 +- r/R/arrowExports.R | 8 +++++ r/R/dplyr-collect.R | 7 ++++ r/R/dplyr-join.R | 54 ++++++++++++++++++++++++++++++ r/R/enums.R | 13 +++++++ r/R/query-engine.R | 27 ++++++++++++++- r/src/arrowExports.cpp | 38 +++++++++++++++++++++ r/src/compute-exec.cpp | 33 ++++++++++++++++++ r/tests/testthat/test-dplyr-join.R | 43 ++++++++++++++++++++++++ 10 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 r/R/dplyr-join.R create mode 100644 r/tests/testthat/test-dplyr-join.R diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 0dc44277ccd..89d94305dd1 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -92,6 +92,7 @@ Collate: 'expression.R' 'dplyr-functions.R' 'dplyr-group-by.R' + 'dplyr-join.R' 'dplyr-mutate.R' 'dplyr-select.R' 'dplyr-summarize.R' diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 32f6c6ba05c..ff1229454e9 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -36,7 +36,7 @@ "select", "filter", "collect", "summarise", "group_by", "groups", "group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute", "arrange", "rename", "pull", "relocate", "compute", "collapse", - "distinct" + "distinct", "left_join" ) ) for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index b852a3d8ca9..a67a9306bc2 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -284,6 +284,10 @@ ExecPlan_run <- function(plan, final_node, sort_options) { .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options) } +ExecNode_output_schema <- function(node) { + .Call(`_arrow_ExecNode_output_schema`, node) +} + ExecNode_Scan <- function(plan, dataset, filter, materialized_field_names) { .Call(`_arrow_ExecNode_Scan`, plan, dataset, filter, materialized_field_names) } @@ -300,6 
+304,10 @@ ExecNode_Aggregate <- function(input, options, target_names, out_field_names, ke .Call(`_arrow_ExecNode_Aggregate`, input, options, target_names, out_field_names, key_names) } +ExecNode_Join <- function(input, type, right_data, left_keys, right_keys, left_output, right_output) { + .Call(`_arrow_ExecNode_Join`, input, type, right_data, left_keys, right_keys, left_output, right_output) +} + RecordBatch__cast <- function(batch, schema, options) { .Call(`_arrow_RecordBatch__cast`, batch, schema, options) } diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index 06914abe072..17ccac265e6 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -89,6 +89,13 @@ implicit_schema <- function(.data) { if (is.null(.data$aggregations)) { new_fields <- map(.data$selected_columns, ~ .$type(old_schm)) + if (!is.null(.data$join)) { + right_cols <- .data$join$right_data$selected_columns + new_fields <- c(new_fields, map( + right_cols[setdiff(names(right_cols), .data$join$by)], + ~ .$type(.data$join$right_data$.data$schema) + )) + } } else { new_fields <- map(summarize_projection(.data), ~ .$type(old_schm)) # * Put group_by_vars first (this can't be done by summarize, diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R new file mode 100644 index 00000000000..799ac9e9afa --- /dev/null +++ b/r/R/dplyr-join.R @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# The following S3 methods are registered on load if dplyr is present + +left_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + x <- as_adq(x) + y <- as_adq(y) + by <- handle_join_by(by, x, y) + + x$join <- list( + type = JoinType[["LEFT_OUTER"]], + right_data = y, + by = by + ) + collapse.arrow_dplyr_query(x) +} +left_join.Dataset <- left_join.ArrowTabular <- left_join.arrow_dplyr_query + +handle_join_by <- function(by, x, y) { + if (is.null(by)) { + return(set_names(intersect(names(x), names(y)))) + } + stopifnot(is.character(by)) + if (is.null(names(by))) { + by <- set_names(by) + } + stopifnot( + all(names(by) %in% names(x)), + all(by %in% names(y)) + ) + by +} \ No newline at end of file diff --git a/r/R/enums.R b/r/R/enums.R index d9cb3a55c1d..4e69b7a190e 100644 --- a/r/R/enums.R +++ b/r/R/enums.R @@ -163,3 +163,16 @@ RoundMode <- enum("RoundMode", HALF_TO_EVEN = 8L, HALF_TO_ODD = 9L ) + +#' @export +#' @rdname enums +JoinType <- enum("JoinType", + LEFT_SEMI = 0L, + RIGHT_SEMI = 1L, + LEFT_ANTI = 2L, + RIGHT_ANTI = 3L, + INNER = 4L, + LEFT_OUTER = 5L, + RIGHT_OUTER = 6L, + FULL_OUTER = 7L +) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 8f8514ad1cb..18d183a9688 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -143,7 +143,15 @@ ExecPlan <- R6Class("ExecPlan", ) } } - } else { + } else if (!is.null(.data$join)) { + node <- node$Join( + type = .data$join$type, + right_node = self$Build(.data$join$right_data), + by = .data$join$by, + left_output = 
names(.data), + right_output = setdiff(names(.data$join$right_data), .data$join$by) + ) + } else if (length(node$schema)) { # If any columns are derived, reordered, or renamed we need to Project # If there are aggregations, the projection was already handled above # We have to project at least once to eliminate some junk columns @@ -187,6 +195,7 @@ ExecNode <- R6Class("ExecNode", # in the SinkNode (in ExecPlan$run()) sort = NULL, preserve_sort = function(new_node) { + print(new_node$schema) new_node$sort <- self$sort new_node }, @@ -206,6 +215,22 @@ ExecNode <- R6Class("ExecNode", self$preserve_sort( ExecNode_Aggregate(self, options, target_names, out_field_names, key_names) ) + }, + Join = function(type, right_node, by, left_output, right_output) { + self$preserve_sort( + ExecNode_Join( + self, + type, + right_node, + left_keys = names(by), + right_keys = by, + left_output = left_output, + right_output = right_output + ) + ) } + ), + active = list( + schema = function() ExecNode_output_schema(self) ) ) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 0e6aebeca2c..9ad145b211a 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1109,6 +1109,21 @@ extern "C" SEXP _arrow_ExecPlan_run(SEXP plan_sexp, SEXP final_node_sexp, SEXP s } #endif +// compute-exec.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr ExecNode_output_schema(const std::shared_ptr& node); +extern "C" SEXP _arrow_ExecNode_output_schema(SEXP node_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type node(node_sexp); + return cpp11::as_sexp(ExecNode_output_schema(node)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExecNode_output_schema(SEXP node_sexp){ + Rf_error("Cannot call ExecNode_output_schema(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); +} +#endif + // compute-exec.cpp #if defined(ARROW_R_WITH_DATASET) std::shared_ptr ExecNode_Scan(const std::shared_ptr& plan, const std::shared_ptr& dataset, const std::shared_ptr& filter, std::vector materialized_field_names); @@ -1179,6 +1194,27 @@ extern "C" SEXP _arrow_ExecNode_Aggregate(SEXP input_sexp, SEXP options_sexp, SE } #endif +// compute-exec.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr ExecNode_Join(const std::shared_ptr& input, int type, const std::shared_ptr& right_data, std::vector left_keys, std::vector right_keys, std::vector left_output, std::vector right_output); +extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP type_sexp, SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP left_output_sexp, SEXP right_output_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type input(input_sexp); + arrow::r::Input::type type(type_sexp); + arrow::r::Input&>::type right_data(right_data_sexp); + arrow::r::Input>::type left_keys(left_keys_sexp); + arrow::r::Input>::type right_keys(right_keys_sexp); + arrow::r::Input>::type left_output(left_output_sexp); + arrow::r::Input>::type right_output(right_output_sexp); + return cpp11::as_sexp(ExecNode_Join(input, type, right_data, left_keys, right_keys, left_output, right_output)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP type_sexp, SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP left_output_sexp, SEXP right_output_sexp){ + Rf_error("Cannot call ExecNode_Join(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); +} +#endif + // compute.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr RecordBatch__cast(const std::shared_ptr& batch, const std::shared_ptr& schema, cpp11::list options); @@ -7123,10 +7159,12 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_io___CompressedInputStream__Make", (DL_FUNC) &_arrow_io___CompressedInputStream__Make, 2}, { "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create, 1}, { "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 3}, + { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, { "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter, 2}, { "_arrow_ExecNode_Project", (DL_FUNC) &_arrow_ExecNode_Project, 3}, { "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 5}, + { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 7}, { "_arrow_RecordBatch__cast", (DL_FUNC) &_arrow_RecordBatch__cast, 3}, { "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3}, { "_arrow_compute__CallFunction", (DL_FUNC) &_arrow_compute__CallFunction, 3}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 9404e016c98..aa6498b41e8 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -98,6 +98,12 @@ std::shared_ptr ExecPlan_run( #include +// [[arrow::export]] +std::shared_ptr ExecNode_output_schema( + const std::shared_ptr& node) { + return node->output_schema(); +} + // [[dataset::export]] std::shared_ptr ExecNode_Scan( const std::shared_ptr& plan, @@ -187,4 +193,31 @@ std::shared_ptr ExecNode_Aggregate( std::move(out_field_names), std::move(keys)}); } +// [[arrow::export]] +std::shared_ptr ExecNode_Join( + const std::shared_ptr& input, int type, + const std::shared_ptr& right_data, + std::vector left_keys, std::vector right_keys, + std::vector left_output, std::vector right_output) { + std::vector left_refs, right_refs, left_out_refs, right_out_refs; + for (auto&& name : left_keys) { + 
left_refs.emplace_back(std::move(name)); + } + for (auto&& name : right_keys) { + right_refs.emplace_back(std::move(name)); + } + for (auto&& name : left_output) { + left_out_refs.emplace_back(std::move(name)); + } + for (auto&& name : right_output) { + right_out_refs.emplace_back(std::move(name)); + } + + return MakeExecNodeOrStop( + "hashjoin", input->plan(), {input.get(), right_data.get()}, + compute::HashJoinNodeOptions{compute::JoinType::LEFT_OUTER, std::move(left_refs), + std::move(right_refs), std::move(left_out_refs), + std::move(right_out_refs)}); +} + #endif diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R new file mode 100644 index 00000000000..545a9f02dfe --- /dev/null +++ b/r/tests/testthat/test-dplyr-join.R @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +skip_if_not_available("dataset") + +library(dplyr) + +left <- example_data +# Error: Invalid: Dictionary type support for join output field is not yet implemented, output field reference: FieldRef.Name(fct) on left side of the join +# (and select(-fct) doesn't solve this somehow) +left$fct <- NULL +left$some_grouping <- rep(c(1, 2), 5) + +to_join <- tibble::tibble( + # Error: Invalid: Output field name collision in join, name: some_grouping + # (so call it something else) + the_grouping = c(1, 2), + capital_letters = c("A", "B"), + another_column = TRUE +) + +test_that("left_join", { + expect_dplyr_equal( + input %>% + left_join(to_join, by = c(some_grouping = "the_grouping")) %>% + collect(), + left + ) +}) \ No newline at end of file From f4ab4eb1edeb256210f90dd3937d05c5b02b833b Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Mon, 20 Sep 2021 14:35:04 -0400 Subject: [PATCH 2/7] Add all of the other join functions --- r/R/arrow-package.R | 3 +- r/R/dplyr-collect.R | 3 +- r/R/dplyr-join.R | 89 +++++++++++++++++++++++++++--- r/R/query-engine.R | 1 - r/src/compute-exec.cpp | 36 ++++++++++-- r/tests/testthat/test-dplyr-join.R | 77 ++++++++++++++++++++++++-- 6 files changed, 186 insertions(+), 23 deletions(-) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index ff1229454e9..807aea207b7 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -36,7 +36,8 @@ "select", "filter", "collect", "summarise", "group_by", "groups", "group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute", "arrange", "rename", "pull", "relocate", "compute", "collapse", - "distinct", "left_join" + "distinct", "left_join", "right_join", "inner_join", "full_join", + "semi_join", "anti_join" ) ) for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index 17ccac265e6..519d6421625 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -89,7 +89,8 @@ implicit_schema <- function(.data) { if 
(is.null(.data$aggregations)) { new_fields <- map(.data$selected_columns, ~ .$type(old_schm)) - if (!is.null(.data$join)) { + if (!is.null(.data$join) && !(.data$join$type %in% JoinType[1:4])) { + # Add cols from right side, except for semi/anti joins right_cols <- .data$join$right_data$selected_columns new_fields <- c(new_fields, map( right_cols[setdiff(names(right_cols), .data$join$by)], diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R index 799ac9e9afa..475fd819834 100644 --- a/r/R/dplyr-join.R +++ b/r/R/dplyr-join.R @@ -18,26 +18,97 @@ # The following S3 methods are registered on load if dplyr is present -left_join.arrow_dplyr_query <- function(x, - y, - by = NULL, - copy = FALSE, - suffix = c(".x", ".y"), - ..., - keep = FALSE) { +do_join <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE, + na_matches, + join_type) { + # TODO: handle `copy` arg: ignore? + # TODO: handle `suffix` arg: Arrow does prefix + # TODO: handle `keep` arg: "Should the join keys from both ‘x’ and ‘y’ be preserved in the output?" 
+ # TODO: handle `na_matches` arg x <- as_adq(x) y <- as_adq(y) by <- handle_join_by(by, x, y) x$join <- list( - type = JoinType[["LEFT_OUTER"]], + type = JoinType[[join_type]], right_data = y, by = by ) collapse.arrow_dplyr_query(x) } + +left_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_OUTER") +} left_join.Dataset <- left_join.ArrowTabular <- left_join.arrow_dplyr_query +right_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "RIGHT_OUTER") +} +right_join.Dataset <- right_join.ArrowTabular <- right_join.arrow_dplyr_query + +inner_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "INNER") +} +inner_join.Dataset <- inner_join.ArrowTabular <- inner_join.arrow_dplyr_query + +full_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "FULL_OUTER") +} +full_join.Dataset <- full_join.ArrowTabular <- full_join.arrow_dplyr_query + +semi_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_SEMI") +} +semi_join.Dataset <- semi_join.ArrowTabular <- semi_join.arrow_dplyr_query + +anti_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_ANTI") +} +anti_join.Dataset <- anti_join.ArrowTabular <- anti_join.arrow_dplyr_query + handle_join_by <- function(by, x, 
y) { if (is.null(by)) { return(set_names(intersect(names(x), names(y)))) @@ -51,4 +122,4 @@ handle_join_by <- function(by, x, y) { all(by %in% names(y)) ) by -} \ No newline at end of file +} diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 18d183a9688..ccd3ee9832a 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -195,7 +195,6 @@ ExecNode <- R6Class("ExecNode", # in the SinkNode (in ExecPlan$run()) sort = NULL, preserve_sort = function(new_node) { - print(new_node$schema) new_node$sort <- self$sort new_node }, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index aa6498b41e8..95033d4c124 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -209,15 +209,41 @@ std::shared_ptr ExecNode_Join( for (auto&& name : left_output) { left_out_refs.emplace_back(std::move(name)); } - for (auto&& name : right_output) { - right_out_refs.emplace_back(std::move(name)); + if (type != 0 && type != 2) { + // Don't include out_refs in semi/anti join + for (auto&& name : right_output) { + right_out_refs.emplace_back(std::move(name)); + } + } + + // TODO: we should be able to use this enum directly + compute::JoinType join_type; + if (type == 0) { + join_type = compute::JoinType::LEFT_SEMI; + } else if (type == 1) { + // Not readily called from R bc dplyr::semi_join is LEFT_SEMI + join_type = compute::JoinType::RIGHT_SEMI; + } else if (type == 2) { + join_type = compute::JoinType::LEFT_ANTI; + } else if (type == 3) { + // Not readily called from R bc dplyr::anti_join is LEFT_ANTI + join_type = compute::JoinType::RIGHT_ANTI; + } else if (type == 4) { + join_type = compute::JoinType::INNER; + } else if (type == 5) { + join_type = compute::JoinType::LEFT_OUTER; + } else if (type == 6) { + join_type = compute::JoinType::RIGHT_OUTER; + } else if (type == 7) { + join_type = compute::JoinType::FULL_OUTER; + } else { + cpp11::stop("todo"); } return MakeExecNodeOrStop( "hashjoin", input->plan(), {input.get(), right_data.get()}, - 
compute::HashJoinNodeOptions{compute::JoinType::LEFT_OUTER, std::move(left_refs), - std::move(right_refs), std::move(left_out_refs), - std::move(right_out_refs)}); + compute::HashJoinNodeOptions{join_type, std::move(left_refs), std::move(right_refs), + std::move(left_out_refs), std::move(right_out_refs)}); } #endif diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 545a9f02dfe..f6f98420150 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -21,14 +21,12 @@ library(dplyr) left <- example_data # Error: Invalid: Dictionary type support for join output field is not yet implemented, output field reference: FieldRef.Name(fct) on left side of the join -# (and select(-fct) doesn't solve this somehow) +# (select(-fct) also solves this but remove once) left$fct <- NULL left$some_grouping <- rep(c(1, 2), 5) to_join <- tibble::tibble( - # Error: Invalid: Output field name collision in join, name: some_grouping - # (so call it something else) - the_grouping = c(1, 2), + some_grouping = c(1, 2), capital_letters = c("A", "B"), another_column = TRUE ) @@ -36,8 +34,75 @@ to_join <- tibble::tibble( test_that("left_join", { expect_dplyr_equal( input %>% - left_join(to_join, by = c(some_grouping = "the_grouping")) %>% + left_join(to_join) %>% collect(), left ) -}) \ No newline at end of file +}) + +test_that("left_join `by` args", { + expect_dplyr_equal( + input %>% + left_join(to_join, by = "some_grouping") %>% + collect(), + left + ) + expect_dplyr_equal( + input %>% + left_join( + to_join %>% + rename(the_grouping = some_grouping), + by = c(some_grouping = "the_grouping") + ) %>% + collect(), + left + ) +}) + +# TODO: test duplicate col names +# TODO: test invalid 'by' + +test_that("right_join", { + expect_dplyr_equal( + input %>% + right_join(to_join) %>% + collect(), + left + ) +}) + +test_that("inner_join", { + expect_dplyr_equal( + input %>% + inner_join(to_join) %>% + collect(), + left + ) +}) + 
+test_that("full_join", { + expect_dplyr_equal( + input %>% + full_join(to_join) %>% + collect(), + left + ) +}) + +test_that("semi_join", { + expect_dplyr_equal( + input %>% + semi_join(to_join) %>% + collect(), + left + ) +}) + +test_that("anti_join", { + expect_dplyr_equal( + input %>% + anti_join(to_join) %>% + collect(), + left + ) +}) From c922dff44ba8326b729cb57c46c5707c7b696e90 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 29 Sep 2021 06:02:09 -0700 Subject: [PATCH 3/7] Apply suggestions from code review Co-authored-by: Jonathan Keane --- r/tests/testthat/test-dplyr-join.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index f6f98420150..a96c5556d8a 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -20,7 +20,9 @@ skip_if_not_available("dataset") library(dplyr) left <- example_data -# Error: Invalid: Dictionary type support for join output field is not yet implemented, output field reference: FieldRef.Name(fct) on left side of the join +# Error: Invalid: Dictionary type support for join output field +# is not yet implemented, output field reference: FieldRef.Name(fct) +# on left side of the join # (select(-fct) also solves this but remove once) left$fct <- NULL left$some_grouping <- rep(c(1, 2), 5) @@ -49,10 +51,10 @@ test_that("left_join `by` args", { ) expect_dplyr_equal( input %>% + rename(the_grouping = some_grouping) %>% left_join( - to_join %>% - rename(the_grouping = some_grouping), - by = c(some_grouping = "the_grouping") + to_join + by = c(the_grouping = "some_grouping") ) %>% collect(), left From 1da46cda9d8d96ee35298075011ad5716f1f6630 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Thu, 30 Sep 2021 11:38:39 -0500 Subject: [PATCH 4/7] typo, docs --- r/NAMESPACE | 1 + r/man/enums.Rd | 5 +++++ r/tests/testthat/test-dplyr-join.R | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git 
a/r/NAMESPACE b/r/NAMESPACE index e90fcdf3451..c909ad4a938 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -142,6 +142,7 @@ export(HivePartitioning) export(HivePartitioningFactory) export(InMemoryDataset) export(IpcFileFormat) +export(JoinType) export(JsonParseOptions) export(JsonReadOptions) export(JsonTableReader) diff --git a/r/man/enums.Rd b/r/man/enums.Rd index e21a9f4d9cb..7ec126a0198 100644 --- a/r/man/enums.Rd +++ b/r/man/enums.Rd @@ -17,6 +17,7 @@ \alias{NullEncodingBehavior} \alias{NullHandlingBehavior} \alias{RoundMode} +\alias{JoinType} \title{Arrow enums} \format{ An object of class \code{TimeUnit::type} (inherits from \code{arrow-enum}) of length 4. @@ -46,6 +47,8 @@ An object of class \code{NullEncodingBehavior} (inherits from \code{arrow-enum}) An object of class \code{NullHandlingBehavior} (inherits from \code{arrow-enum}) of length 3. An object of class \code{RoundMode} (inherits from \code{arrow-enum}) of length 10. + +An object of class \code{JoinType} (inherits from \code{arrow-enum}) of length 8. 
} \usage{ TimeUnit @@ -75,6 +78,8 @@ NullEncodingBehavior NullHandlingBehavior RoundMode + +JoinType } \description{ Arrow enums diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index a96c5556d8a..ebe5b33c37a 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -53,7 +53,7 @@ test_that("left_join `by` args", { input %>% rename(the_grouping = some_grouping) %>% left_join( - to_join + to_join, by = c(the_grouping = "some_grouping") ) %>% collect(), From 783623d8c1fdb0e10501c4511649b5755aa963b8 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Thu, 30 Sep 2021 12:40:47 -0500 Subject: [PATCH 5/7] Some cleanup --- r/R/dplyr-join.R | 1 + r/tests/testthat/test-dplyr-join.R | 51 +++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R index 475fd819834..c14b1a8f3dd 100644 --- a/r/R/dplyr-join.R +++ b/r/R/dplyr-join.R @@ -117,6 +117,7 @@ handle_join_by <- function(by, x, y) { if (is.null(names(by))) { by <- set_names(by) } + # TODO: nicer messages? 
stopifnot( all(names(by) %in% names(x)), all(by %in% names(y)) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index ebe5b33c37a..345c37942da 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -34,12 +34,16 @@ to_join <- tibble::tibble( ) test_that("left_join", { - expect_dplyr_equal( - input %>% - left_join(to_join) %>% - collect(), - left + expect_message( + expect_dplyr_equal( + input %>% + left_join(to_join) %>% + collect(), + left + ), + 'Joining, by = "some_grouping"' ) + }) test_that("left_join `by` args", { @@ -49,6 +53,19 @@ test_that("left_join `by` args", { collect(), left ) + expect_dplyr_equal( + input %>% + left_join( + to_join %>% + rename(the_grouping = some_grouping), + by = c(some_grouping = "the_grouping") + ) %>% + collect(), + left + ) + + # TODO: allow renaming columns on the right side as well + skip("ARROW-XXX") expect_dplyr_equal( input %>% rename(the_grouping = some_grouping) %>% @@ -61,13 +78,25 @@ test_that("left_join `by` args", { ) }) + +test_that("Error handling", { + tab <- Table$create(left) + + expect_error( + tab %>% + left_join(to_join, by = "not_a_col") %>% + collect(), + "all(names(by) %in% names(x)) is not TRUE", + fixed = TRUE + ) +}) + # TODO: test duplicate col names -# TODO: test invalid 'by' test_that("right_join", { expect_dplyr_equal( input %>% - right_join(to_join) %>% + right_join(to_join, by = "some_grouping") %>% collect(), left ) @@ -76,7 +105,7 @@ test_that("right_join", { test_that("inner_join", { expect_dplyr_equal( input %>% - inner_join(to_join) %>% + inner_join(to_join, by = "some_grouping") %>% collect(), left ) @@ -85,7 +114,7 @@ test_that("inner_join", { test_that("full_join", { expect_dplyr_equal( input %>% - full_join(to_join) %>% + full_join(to_join, by = "some_grouping") %>% collect(), left ) @@ -94,7 +123,7 @@ test_that("full_join", { test_that("semi_join", { expect_dplyr_equal( input %>% - semi_join(to_join) %>% + 
semi_join(to_join, by = "some_grouping") %>% collect(), left ) @@ -103,7 +132,7 @@ test_that("semi_join", { test_that("anti_join", { expect_dplyr_equal( input %>% - anti_join(to_join) %>% + anti_join(to_join, by = "some_grouping") %>% collect(), left ) From c3456983681bfa7654c5530a254b0bbcecf6f0b4 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Thu, 30 Sep 2021 12:56:01 -0500 Subject: [PATCH 6/7] Add in follow on tickets --- r/tests/testthat/test-dplyr-join.R | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 345c37942da..b189bb83af5 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -27,11 +27,16 @@ left <- example_data left$fct <- NULL left$some_grouping <- rep(c(1, 2), 5) +left_tab <- Table$create(left) + to_join <- tibble::tibble( some_grouping = c(1, 2), capital_letters = c("A", "B"), another_column = TRUE ) +to_join_tab <- Table$create(to_join) + + test_that("left_join", { expect_message( @@ -43,7 +48,6 @@ test_that("left_join", { ), 'Joining, by = "some_grouping"' ) - }) test_that("left_join `by` args", { @@ -65,7 +69,7 @@ test_that("left_join `by` args", { ) # TODO: allow renaming columns on the right side as well - skip("ARROW-XXX") + skip("ARROW-14184") expect_dplyr_equal( input %>% rename(the_grouping = some_grouping) %>% @@ -79,11 +83,20 @@ test_that("left_join `by` args", { }) -test_that("Error handling", { - tab <- Table$create(left) +test_that("join two tables", { + expect_identical( + left_tab %>% + left_join(to_join_tab, by = "some_grouping") %>% + collect(), + left %>% + left_join(to_join, by = "some_grouping") %>% + collect() + ) +}) +test_that("Error handling", { expect_error( - tab %>% + left_tab %>% left_join(to_join, by = "not_a_col") %>% collect(), "all(names(by) %in% names(x)) is not TRUE", @@ -92,6 +105,7 @@ test_that("Error handling", { }) # TODO: test duplicate col names +# 
TODO: casting: int and float columns? test_that("right_join", { expect_dplyr_equal( From fa9844f30ce455aee6634f57967825b531b8e68b Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Thu, 30 Sep 2021 14:07:45 -0500 Subject: [PATCH 7/7] david's patch --- cpp/src/arrow/compute/exec/options.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 3dd88210e78..9022781cbdb 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -155,9 +155,9 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { output_all(true), output_prefix_for_left(std::move(output_prefix_for_left)), output_prefix_for_right(std::move(output_prefix_for_right)) { - key_cmp.resize(left_keys.size()); - for (size_t i = 0; i < left_keys.size(); ++i) { - key_cmp[i] = JoinKeyCmp::EQ; + this->key_cmp.resize(this->left_keys.size()); + for (size_t i = 0; i < this->left_keys.size(); ++i) { + this->key_cmp[i] = JoinKeyCmp::EQ; } } HashJoinNodeOptions( @@ -174,9 +174,9 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { right_output(std::move(right_output)), output_prefix_for_left(std::move(output_prefix_for_left)), output_prefix_for_right(std::move(output_prefix_for_right)) { - key_cmp.resize(left_keys.size()); - for (size_t i = 0; i < left_keys.size(); ++i) { - key_cmp[i] = JoinKeyCmp::EQ; + this->key_cmp.resize(this->left_keys.size()); + for (size_t i = 0; i < this->left_keys.size(); ++i) { + this->key_cmp[i] = JoinKeyCmp::EQ; } } HashJoinNodeOptions(