diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 3dd88210e78..9022781cbdb 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -155,9 +155,9 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { output_all(true), output_prefix_for_left(std::move(output_prefix_for_left)), output_prefix_for_right(std::move(output_prefix_for_right)) { - key_cmp.resize(left_keys.size()); - for (size_t i = 0; i < left_keys.size(); ++i) { - key_cmp[i] = JoinKeyCmp::EQ; + this->key_cmp.resize(this->left_keys.size()); + for (size_t i = 0; i < this->left_keys.size(); ++i) { + this->key_cmp[i] = JoinKeyCmp::EQ; } } HashJoinNodeOptions( @@ -174,9 +174,9 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { right_output(std::move(right_output)), output_prefix_for_left(std::move(output_prefix_for_left)), output_prefix_for_right(std::move(output_prefix_for_right)) { - key_cmp.resize(left_keys.size()); - for (size_t i = 0; i < left_keys.size(); ++i) { - key_cmp[i] = JoinKeyCmp::EQ; + this->key_cmp.resize(this->left_keys.size()); + for (size_t i = 0; i < this->left_keys.size(); ++i) { + this->key_cmp[i] = JoinKeyCmp::EQ; } } HashJoinNodeOptions( diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 0dc44277ccd..89d94305dd1 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -92,6 +92,7 @@ Collate: 'expression.R' 'dplyr-functions.R' 'dplyr-group-by.R' + 'dplyr-join.R' 'dplyr-mutate.R' 'dplyr-select.R' 'dplyr-summarize.R' diff --git a/r/NAMESPACE b/r/NAMESPACE index e90fcdf3451..c909ad4a938 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -142,6 +142,7 @@ export(HivePartitioning) export(HivePartitioningFactory) export(InMemoryDataset) export(IpcFileFormat) +export(JoinType) export(JsonParseOptions) export(JsonReadOptions) export(JsonTableReader) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 32f6c6ba05c..807aea207b7 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ 
-36,7 +36,8 @@ "select", "filter", "collect", "summarise", "group_by", "groups", "group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute", "arrange", "rename", "pull", "relocate", "compute", "collapse", - "distinct" + "distinct", "left_join", "right_join", "inner_join", "full_join", + "semi_join", "anti_join" ) ) for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index b852a3d8ca9..a67a9306bc2 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -284,6 +284,10 @@ ExecPlan_run <- function(plan, final_node, sort_options) { .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options) } +ExecNode_output_schema <- function(node) { + .Call(`_arrow_ExecNode_output_schema`, node) +} + ExecNode_Scan <- function(plan, dataset, filter, materialized_field_names) { .Call(`_arrow_ExecNode_Scan`, plan, dataset, filter, materialized_field_names) } @@ -300,6 +304,10 @@ ExecNode_Aggregate <- function(input, options, target_names, out_field_names, ke .Call(`_arrow_ExecNode_Aggregate`, input, options, target_names, out_field_names, key_names) } +ExecNode_Join <- function(input, type, right_data, left_keys, right_keys, left_output, right_output) { + .Call(`_arrow_ExecNode_Join`, input, type, right_data, left_keys, right_keys, left_output, right_output) +} + RecordBatch__cast <- function(batch, schema, options) { .Call(`_arrow_RecordBatch__cast`, batch, schema, options) } diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index 06914abe072..519d6421625 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -89,6 +89,14 @@ implicit_schema <- function(.data) { if (is.null(.data$aggregations)) { new_fields <- map(.data$selected_columns, ~ .$type(old_schm)) + if (!is.null(.data$join) && !(.data$join$type %in% JoinType[1:4])) { + # Add cols from right side, except for semi/anti joins + right_cols <- .data$join$right_data$selected_columns + new_fields <- c(new_fields, map( + 
right_cols[setdiff(names(right_cols), .data$join$by)], + ~ .$type(.data$join$right_data$.data$schema) + )) + } } else { new_fields <- map(summarize_projection(.data), ~ .$type(old_schm)) # * Put group_by_vars first (this can't be done by summarize, diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R new file mode 100644 index 00000000000..c14b1a8f3dd --- /dev/null +++ b/r/R/dplyr-join.R @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# The following S3 methods are registered on load if dplyr is present + +do_join <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE, + na_matches, + join_type) { + # TODO: handle `copy` arg: ignore? + # TODO: handle `suffix` arg: Arrow does prefix + # TODO: handle `keep` arg: "Should the join keys from both ‘x’ and ‘y’ be preserved in the output?" 
+ # TODO: handle `na_matches` arg + x <- as_adq(x) + y <- as_adq(y) + by <- handle_join_by(by, x, y) + + x$join <- list( + type = JoinType[[join_type]], + right_data = y, + by = by + ) + collapse.arrow_dplyr_query(x) +} + +left_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_OUTER") +} +left_join.Dataset <- left_join.ArrowTabular <- left_join.arrow_dplyr_query + +right_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "RIGHT_OUTER") +} +right_join.Dataset <- right_join.ArrowTabular <- right_join.arrow_dplyr_query + +inner_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "INNER") +} +inner_join.Dataset <- inner_join.ArrowTabular <- inner_join.arrow_dplyr_query + +full_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "FULL_OUTER") +} +full_join.Dataset <- full_join.ArrowTabular <- full_join.arrow_dplyr_query + +semi_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_SEMI") +} +semi_join.Dataset <- semi_join.ArrowTabular <- semi_join.arrow_dplyr_query + +anti_join.arrow_dplyr_query <- function(x, + y, + by = NULL, + copy = FALSE, + suffix = c(".x", ".y"), + ..., + keep = FALSE) { + do_join(x, y, by, copy, suffix, ..., keep = keep, join_type = "LEFT_ANTI") +} +anti_join.Dataset <- anti_join.ArrowTabular <- anti_join.arrow_dplyr_query + +handle_join_by <- function(by, x, y) { + if 
(is.null(by)) { + return(set_names(intersect(names(x), names(y)))) + } + stopifnot(is.character(by)) + if (is.null(names(by))) { + by <- set_names(by) + } + # TODO: nicer messages? + stopifnot( + all(names(by) %in% names(x)), + all(by %in% names(y)) + ) + by +} diff --git a/r/R/enums.R b/r/R/enums.R index d9cb3a55c1d..4e69b7a190e 100644 --- a/r/R/enums.R +++ b/r/R/enums.R @@ -163,3 +163,16 @@ RoundMode <- enum("RoundMode", HALF_TO_EVEN = 8L, HALF_TO_ODD = 9L ) + +#' @export +#' @rdname enums +JoinType <- enum("JoinType", + LEFT_SEMI = 0L, + RIGHT_SEMI = 1L, + LEFT_ANTI = 2L, + RIGHT_ANTI = 3L, + INNER = 4L, + LEFT_OUTER = 5L, + RIGHT_OUTER = 6L, + FULL_OUTER = 7L +) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 8f8514ad1cb..ccd3ee9832a 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -143,7 +143,15 @@ ExecPlan <- R6Class("ExecPlan", ) } } - } else { + } else if (!is.null(.data$join)) { + node <- node$Join( + type = .data$join$type, + right_node = self$Build(.data$join$right_data), + by = .data$join$by, + left_output = names(.data), + right_output = setdiff(names(.data$join$right_data), .data$join$by) + ) + } else if (length(node$schema)) { # If any columns are derived, reordered, or renamed we need to Project # If there are aggregations, the projection was already handled above # We have to project at least once to eliminate some junk columns @@ -206,6 +214,22 @@ ExecNode <- R6Class("ExecNode", self$preserve_sort( ExecNode_Aggregate(self, options, target_names, out_field_names, key_names) ) + }, + Join = function(type, right_node, by, left_output, right_output) { + self$preserve_sort( + ExecNode_Join( + self, + type, + right_node, + left_keys = names(by), + right_keys = by, + left_output = left_output, + right_output = right_output + ) + ) } + ), + active = list( + schema = function() ExecNode_output_schema(self) ) ) diff --git a/r/man/enums.Rd b/r/man/enums.Rd index e21a9f4d9cb..7ec126a0198 100644 --- a/r/man/enums.Rd +++ b/r/man/enums.Rd 
@@ -17,6 +17,7 @@ \alias{NullEncodingBehavior} \alias{NullHandlingBehavior} \alias{RoundMode} +\alias{JoinType} \title{Arrow enums} \format{ An object of class \code{TimeUnit::type} (inherits from \code{arrow-enum}) of length 4. @@ -46,6 +47,8 @@ An object of class \code{NullEncodingBehavior} (inherits from \code{arrow-enum}) An object of class \code{NullHandlingBehavior} (inherits from \code{arrow-enum}) of length 3. An object of class \code{RoundMode} (inherits from \code{arrow-enum}) of length 10. + +An object of class \code{JoinType} (inherits from \code{arrow-enum}) of length 8. } \usage{ TimeUnit @@ -75,6 +78,8 @@ NullEncodingBehavior NullHandlingBehavior RoundMode + +JoinType } \description{ Arrow enums diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 0e6aebeca2c..9ad145b211a 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1109,6 +1109,21 @@ extern "C" SEXP _arrow_ExecPlan_run(SEXP plan_sexp, SEXP final_node_sexp, SEXP s } #endif +// compute-exec.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr ExecNode_output_schema(const std::shared_ptr& node); +extern "C" SEXP _arrow_ExecNode_output_schema(SEXP node_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type node(node_sexp); + return cpp11::as_sexp(ExecNode_output_schema(node)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExecNode_output_schema(SEXP node_sexp){ + Rf_error("Cannot call ExecNode_output_schema(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); +} +#endif + // compute-exec.cpp #if defined(ARROW_R_WITH_DATASET) std::shared_ptr ExecNode_Scan(const std::shared_ptr& plan, const std::shared_ptr& dataset, const std::shared_ptr& filter, std::vector materialized_field_names); @@ -1179,6 +1194,27 @@ extern "C" SEXP _arrow_ExecNode_Aggregate(SEXP input_sexp, SEXP options_sexp, SE } #endif +// compute-exec.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr ExecNode_Join(const std::shared_ptr& input, int type, const std::shared_ptr& right_data, std::vector left_keys, std::vector right_keys, std::vector left_output, std::vector right_output); +extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP type_sexp, SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP left_output_sexp, SEXP right_output_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type input(input_sexp); + arrow::r::Input::type type(type_sexp); + arrow::r::Input&>::type right_data(right_data_sexp); + arrow::r::Input>::type left_keys(left_keys_sexp); + arrow::r::Input>::type right_keys(right_keys_sexp); + arrow::r::Input>::type left_output(left_output_sexp); + arrow::r::Input>::type right_output(right_output_sexp); + return cpp11::as_sexp(ExecNode_Join(input, type, right_data, left_keys, right_keys, left_output, right_output)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExecNode_Join(SEXP input_sexp, SEXP type_sexp, SEXP right_data_sexp, SEXP left_keys_sexp, SEXP right_keys_sexp, SEXP left_output_sexp, SEXP right_output_sexp){ + Rf_error("Cannot call ExecNode_Join(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); +} +#endif + // compute.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr RecordBatch__cast(const std::shared_ptr& batch, const std::shared_ptr& schema, cpp11::list options); @@ -7123,10 +7159,12 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_io___CompressedInputStream__Make", (DL_FUNC) &_arrow_io___CompressedInputStream__Make, 2}, { "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create, 1}, { "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 3}, + { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, { "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter, 2}, { "_arrow_ExecNode_Project", (DL_FUNC) &_arrow_ExecNode_Project, 3}, { "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 5}, + { "_arrow_ExecNode_Join", (DL_FUNC) &_arrow_ExecNode_Join, 7}, { "_arrow_RecordBatch__cast", (DL_FUNC) &_arrow_RecordBatch__cast, 3}, { "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3}, { "_arrow_compute__CallFunction", (DL_FUNC) &_arrow_compute__CallFunction, 3}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 9404e016c98..95033d4c124 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -98,6 +98,12 @@ std::shared_ptr ExecPlan_run( #include +// [[arrow::export]] +std::shared_ptr ExecNode_output_schema( + const std::shared_ptr& node) { + return node->output_schema(); +} + // [[dataset::export]] std::shared_ptr ExecNode_Scan( const std::shared_ptr& plan, @@ -187,4 +193,57 @@ std::shared_ptr ExecNode_Aggregate( std::move(out_field_names), std::move(keys)}); } +// [[arrow::export]] +std::shared_ptr ExecNode_Join( + const std::shared_ptr& input, int type, + const std::shared_ptr& right_data, + std::vector left_keys, std::vector right_keys, + std::vector left_output, std::vector right_output) { + std::vector left_refs, right_refs, left_out_refs, right_out_refs; + for (auto&& name : left_keys) { + 
left_refs.emplace_back(std::move(name)); + } + for (auto&& name : right_keys) { + right_refs.emplace_back(std::move(name)); + } + for (auto&& name : left_output) { + left_out_refs.emplace_back(std::move(name)); + } + if (type != 0 && type != 2) { + // Don't include out_refs in semi/anti join + for (auto&& name : right_output) { + right_out_refs.emplace_back(std::move(name)); + } + } + + // TODO: we should be able to use this enum directly + compute::JoinType join_type; + if (type == 0) { + join_type = compute::JoinType::LEFT_SEMI; + } else if (type == 1) { + // Not readily called from R bc dplyr::semi_join is LEFT_SEMI + join_type = compute::JoinType::RIGHT_SEMI; + } else if (type == 2) { + join_type = compute::JoinType::LEFT_ANTI; + } else if (type == 3) { + // Not readily called from R bc dplyr::anti_join is LEFT_ANTI + join_type = compute::JoinType::RIGHT_ANTI; + } else if (type == 4) { + join_type = compute::JoinType::INNER; + } else if (type == 5) { + join_type = compute::JoinType::LEFT_OUTER; + } else if (type == 6) { + join_type = compute::JoinType::RIGHT_OUTER; + } else if (type == 7) { + join_type = compute::JoinType::FULL_OUTER; + } else { + cpp11::stop("todo"); + } + + return MakeExecNodeOrStop( + "hashjoin", input->plan(), {input.get(), right_data.get()}, + compute::HashJoinNodeOptions{join_type, std::move(left_refs), std::move(right_refs), + std::move(left_out_refs), std::move(right_out_refs)}); +} + #endif diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R new file mode 100644 index 00000000000..b189bb83af5 --- /dev/null +++ b/r/tests/testthat/test-dplyr-join.R @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +skip_if_not_available("dataset") + +library(dplyr) + +left <- example_data +# Error: Invalid: Dictionary type support for join output field +# is not yet implemented, output field reference: FieldRef.Name(fct) +# on left side of the join +# (select(-fct) also solves this; remove this workaround once dictionary-typed join output fields are supported) +left$fct <- NULL +left$some_grouping <- rep(c(1, 2), 5) + +left_tab <- Table$create(left) + +to_join <- tibble::tibble( + some_grouping = c(1, 2), + capital_letters = c("A", "B"), + another_column = TRUE +) +to_join_tab <- Table$create(to_join) + + + +test_that("left_join", { + expect_message( + expect_dplyr_equal( + input %>% + left_join(to_join) %>% + collect(), + left + ), + 'Joining, by = "some_grouping"' + ) +}) + +test_that("left_join `by` args", { + expect_dplyr_equal( + input %>% + left_join(to_join, by = "some_grouping") %>% + collect(), + left + ) + expect_dplyr_equal( + input %>% + left_join( + to_join %>% + rename(the_grouping = some_grouping), + by = c(some_grouping = "the_grouping") + ) %>% + collect(), + left + ) + + # TODO: allow renaming columns on the right side as well + skip("ARROW-14184") + expect_dplyr_equal( + input %>% + rename(the_grouping = some_grouping) %>% + left_join( + to_join, + by = c(the_grouping = "some_grouping") + ) %>% + collect(), + left + ) +}) + + +test_that("join two tables", { + expect_identical( + left_tab %>% + left_join(to_join_tab, by = "some_grouping") %>% + 
collect(), + left %>% + left_join(to_join, by = "some_grouping") %>% + collect() + ) +}) + +test_that("Error handling", { + expect_error( + left_tab %>% + left_join(to_join, by = "not_a_col") %>% + collect(), + "all(names(by) %in% names(x)) is not TRUE", + fixed = TRUE + ) +}) + +# TODO: test duplicate col names +# TODO: casting: int and float columns? + +test_that("right_join", { + expect_dplyr_equal( + input %>% + right_join(to_join, by = "some_grouping") %>% + collect(), + left + ) +}) + +test_that("inner_join", { + expect_dplyr_equal( + input %>% + inner_join(to_join, by = "some_grouping") %>% + collect(), + left + ) +}) + +test_that("full_join", { + expect_dplyr_equal( + input %>% + full_join(to_join, by = "some_grouping") %>% + collect(), + left + ) +}) + +test_that("semi_join", { + expect_dplyr_equal( + input %>% + semi_join(to_join, by = "some_grouping") %>% + collect(), + left + ) +}) + +test_that("anti_join", { + expect_dplyr_equal( + input %>% + anti_join(to_join, by = "some_grouping") %>% + collect(), + left + ) +})