diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 3f63bf2c405..dbacb6bb96f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -3587,10 +3587,12 @@ void AddBinaryJoin(FunctionRegistry* registry) { "binary_join_element_wise", Arity::VarArgs(/*min_args=*/1), &binary_join_element_wise_doc, &kDefaultJoinOptions); for (const auto& ty : BaseBinaryTypes()) { - DCHECK_OK( - func->AddKernel({InputType(ty)}, ty, + ScalarKernel kernel{KernelSignature::Make({InputType(ty)}, ty, /*is_varargs=*/true), GenerateTypeAgnosticVarBinaryBase(ty), - BinaryJoinElementWiseState::Init)); + BinaryJoinElementWiseState::Init}; + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); } DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/r/NAMESPACE b/r/NAMESPACE index f298ba905ee..ab45aa9985e 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -153,6 +153,7 @@ export(MessageReader) export(MessageType) export(MetadataVersion) export(NullEncodingBehavior) +export(NullHandlingBehavior) export(ParquetArrowReaderProperties) export(ParquetFileFormat) export(ParquetFileReader) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 91d1b21ad88..1cf6fabebee 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -215,6 +215,61 @@ nse_funcs$nchar <- function(x, type = "chars", allowNA = FALSE, keepNA = NA) { } } +nse_funcs$paste <- function(..., sep = " ", collapse = NULL, recycle0 = FALSE) { + assert_that( + is.null(collapse), + msg = "paste() with the collapse argument is not yet supported in Arrow" + ) + if (!inherits(sep, "Expression")) { + assert_that(!is.na(sep), msg = "Invalid separator") + } + arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., sep) +} + +nse_funcs$paste0 <- function(..., collapse = NULL, recycle0 = FALSE) { + assert_that( + is.null(collapse), + msg = "paste0() with the collapse argument is not yet supported in Arrow" + ) + arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., "") +} + +nse_funcs$str_c <- function(..., sep = "", collapse = NULL) { + assert_that( + is.null(collapse), + msg = "str_c() with the collapse argument is not yet supported in Arrow" + ) + arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)(..., sep) +} + +arrow_string_join_function <- function(null_handling, null_replacement = NULL) { + # the `binary_join_element_wise` Arrow C++ compute kernel takes the separator + # as the last argument, so pass `sep` as the last dots arg to this function + function(...) { + args <- lapply(list(...), function(arg) { + # handle scalar literal args, and cast all args to string for + # consistency with base::paste(), base::paste0(), and stringr::str_c() + if (!inherits(arg, "Expression")) { + assert_that( + length(arg) == 1, + msg = "Literal vectors of length != 1 not supported in string concatenation" + ) + Expression$scalar(as.character(arg)) + } else { + nse_funcs$as.character(arg) + } + }) + Expression$create( + "binary_join_element_wise", + args = args, + options = list( + null_handling = null_handling, + null_replacement = null_replacement + ) + ) + } +} + nse_funcs$str_trim <- function(string, side = c("both", "left", "right")) { side <- match.arg(side) trim_fun <- switch(side, diff --git a/r/R/enums.R b/r/R/enums.R index 4271f2ad138..8a5bf7366a9 100644 --- a/r/R/enums.R +++ b/r/R/enums.R @@ -140,3 +140,9 @@ QuantileInterpolation <- enum("QuantileInterpolation", NullEncodingBehavior <- enum("NullEncodingBehavior", ENCODE = 0L, MASK = 1L ) + +#' @export +#' @rdname enums +NullHandlingBehavior <- enum("NullHandlingBehavior", + EMIT_NULL = 0L, SKIP = 1L, REPLACE = 2L +) diff --git a/r/man/enums.Rd b/r/man/enums.Rd index b871516def8..57ec3ba115e 100644 --- a/r/man/enums.Rd +++ b/r/man/enums.Rd @@ -15,6 +15,7 @@ \alias{MetadataVersion} \alias{QuantileInterpolation} \alias{NullEncodingBehavior} +\alias{NullHandlingBehavior} \title{Arrow enums} \format{ An object of class \code{TimeUnit::type} (inherits from \code{arrow-enum}) of length 4. @@ -40,6 +41,8 @@ An object of class \code{MetadataVersion} (inherits from \code{arrow-enum}) of l An object of class \code{QuantileInterpolation} (inherits from \code{arrow-enum}) of length 5. An object of class \code{NullEncodingBehavior} (inherits from \code{arrow-enum}) of length 2. + +An object of class \code{NullHandlingBehavior} (inherits from \code{arrow-enum}) of length 3. } \usage{ TimeUnit @@ -65,6 +68,8 @@ MetadataVersion QuantileInterpolation NullEncodingBehavior + +NullHandlingBehavior } \description{ Arrow enums diff --git a/r/src/compute.cpp b/r/src/compute.cpp index eab9db54134..3d322ab6c71 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -218,6 +218,20 @@ std::shared_ptr make_compute_options( return make_cast_options(options); } + if (func_name == "binary_join_element_wise") { + using Options = arrow::compute::JoinOptions; + auto out = std::make_shared(Options::Defaults()); + if (!Rf_isNull(options["null_handling"])) { + out->null_handling = + cpp11::as_cpp( + options["null_handling"]); + } + if (!Rf_isNull(options["null_replacement"])) { + out->null_replacement = cpp11::as_cpp(options["null_replacement"]); + } + return out; + } + if (func_name == "match_substring" || func_name == "match_substring_regex") { using Options = arrow::compute::MatchSubstringOptions; bool ignore_case = false; diff --git a/r/tests/testthat/test-dplyr-string-functions.R b/r/tests/testthat/test-dplyr-string-functions.R index ea27aa14777..4afb88e5732 100644 --- a/r/tests/testthat/test-dplyr-string-functions.R +++ b/r/tests/testthat/test-dplyr-string-functions.R @@ -21,6 +21,162 @@ skip_if_not_available("utf8proc") library(dplyr) library(stringr) +test_that("paste, paste0, and str_c", { + df <- tibble( + v = c("A", "B", "C"), + w = c("a", "b", "c"), + x = c("d", NA_character_, "f"), + y = c(NA_character_, "h", "i"), + z = c(1.1, 2.2, NA) + ) + x <- Expression$field_ref("x") + y <- Expression$field_ref("y") + + # no NAs in data + expect_dplyr_equal( + input %>% + transmute(paste(v, w)) %>% + collect(), + df + ) + expect_dplyr_equal( + input %>% + transmute(paste(v, w, sep = "-")) %>% + collect(), + df + ) + expect_dplyr_equal( + input %>% + transmute(paste0(v, w)) %>% + collect(), + df + ) + expect_dplyr_equal( + input %>% + transmute(str_c(v, w)) %>% + collect(), + df + ) + expect_dplyr_equal( + input %>% + transmute(str_c(v, w, sep = "+")) %>% + collect(), + df + ) + + # NAs in data + expect_dplyr_equal( + input %>% + transmute(paste(x, y)) %>% + collect(), + df + ) + expect_dplyr_equal( + input %>% + transmute(paste(x, y, sep = "-")) %>% + collect(), + df + ) + expect_dplyr_equal( + input %>% + transmute(str_c(x, y)) %>% + collect(), + df + ) + + # non-character column in dots + expect_dplyr_equal( + input %>% + transmute(paste0(x, y, z)) %>% + collect(), + df + ) + + # literal string in dots + expect_dplyr_equal( + input %>% + transmute(paste(x, "foo", y)) %>% + collect(), + df + ) + + # literal NA in dots + expect_dplyr_equal( + input %>% + transmute(paste(x, NA, y)) %>% + collect(), + df + ) + + # expressions in dots + expect_dplyr_equal( + input %>% + transmute(paste0(x, toupper(y), as.character(z))) %>% + collect(), + df + ) + + # sep is literal NA + # errors in paste() (consistent with base::paste()) + expect_error( + nse_funcs$paste(x, y, sep = NA_character_), + "Invalid separator" + ) + # emits null in str_c() (consistent with stringr::str_c()) + expect_dplyr_equal( + input %>% + transmute(str_c(x, y, sep = NA_character_)) %>% + collect(), + df + ) + + # sep passed in dots to paste0 (which doesn't take a sep argument) + expect_dplyr_equal( + input %>% + transmute(paste0(x, y, sep = "-")) %>% + collect(), + df + ) + + # known differences + + # arrow allows the separator to be an array + expect_equal( + df %>% + Table$create() %>% + transmute(result = paste(x, y, sep = w)) %>% + collect(), + df %>% + transmute(result = paste(x, w, y, sep = "")) + ) + + # expected errors + + # collapse argument not supported + expect_error( + nse_funcs$paste(x, y, collapse = ""), + "collapse" + ) + expect_error( + nse_funcs$paste0(x, y, collapse = ""), + "collapse" + ) + expect_error( + nse_funcs$str_c(x, y, collapse = ""), + "collapse" + ) + + # literal vectors of length != 1 not supported + expect_error( + nse_funcs$paste(x, character(0), y), + "Literal vectors of length != 1 not supported in string concatenation" + ) + expect_error( + nse_funcs$paste(x, c(",", ";"), y), + "Literal vectors of length != 1 not supported in string concatenation" + ) +}) + test_that("grepl with ignore.case = FALSE and fixed = TRUE", { df <- tibble(x = c("Foo", "bar")) expect_dplyr_equal(