diff --git a/LICENSE.txt b/LICENSE.txt index 52f471ed2eb..519a73f04f2 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -2242,3 +2242,32 @@ be copied, modified, or distributed except according to those terms. Distributed on an "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, either express or implied. See your chosen license for details. + +-------------------------------------------------------------------------------- +r/R/dplyr-count-tally.R (some portions) + +Some portions of this file are derived from code from + +https://github.com/tidyverse/dplyr/ + +which is made available under the MIT license + +Copyright (c) 2013-2019 RStudio and others. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the “Software”), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 89d94305dd1..8497784b485 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -86,6 +86,7 @@ Collate: 'dictionary.R' 'dplyr-arrange.R' 'dplyr-collect.R' + 'dplyr-count-tally.R' 'dplyr-distinct.R' 'dplyr-eval.R' 'dplyr-filter.R' diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 807aea207b7..7e39f5e22ac 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -37,7 +37,7 @@ "group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute", "arrange", "rename", "pull", "relocate", "compute", "collapse", "distinct", "left_join", "right_join", "inner_join", "full_join", - "semi_join", "anti_join" + "semi_join", "anti_join", "count", "tally" ) ) for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { diff --git a/r/R/dplyr-count-tally.R b/r/R/dplyr-count-tally.R new file mode 100644 index 00000000000..147a1ac84f1 --- /dev/null +++ b/r/R/dplyr-count-tally.R @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# The following S3 methods are registered on load if dplyr is present + +count.arrow_dplyr_query <- function(x, ..., wt = NULL, sort = FALSE, name = NULL) { + if (!missing(...)) { + out <- group_by(x, ..., .add = TRUE) + } else { + out <- x + } + out <- tally(out, wt = {{ wt }}, sort = sort, name = name) + + # Restore original group vars + gv <- dplyr::group_vars(x) + if (length(gv)) { + out$group_by_vars <- gv + } + + out +} + +count.Dataset <- count.ArrowTabular <- count.arrow_dplyr_query + +tally.arrow_dplyr_query <- function(x, wt = NULL, sort = FALSE, name = NULL) { + check_name <- getFromNamespace("check_name", "dplyr") + name <- check_name(name, dplyr::group_vars(x)) + + if (quo_is_null(enquo(wt))) { + out <- dplyr::summarize(x, !!name := n()) + } else { + out <- dplyr::summarize(x, !!name := sum({{ wt }}, na.rm = TRUE)) + } + + if (sort) { + arrange(out, desc(!!sym(name))) + } else { + out + } +} + +tally.Dataset <- tally.ArrowTabular <- tally.arrow_dplyr_query diff --git a/r/R/util.R b/r/R/util.R index ab6753b09c1..9e3ade6a967 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -66,6 +66,10 @@ r_symbolic_constants <- c( ) is_function <- function(expr, name) { + # We could have a quosure here if we have an expression like `sum({{ var }})` + if (is_quosure(expr)) { + expr <- quo_get_expr(expr) + } if (!is.call(expr)) { return(FALSE) } else { diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R index db5541507fb..e1a9d1cb6a9 100644 --- a/r/tests/testthat/test-dataset-dplyr.R +++ b/r/tests/testthat/test-dataset-dplyr.R @@ -186,13 +186,14 @@ test_that("collect() on Dataset works (if fits in memory)", { }) test_that("count()", { - skip("count() is not a generic so we have to get here through summarize()") ds <- open_dataset(dataset_dir) df <- rbind(df1, df2) expect_equal( ds %>% filter(int > 6, int < 108) %>% - count(chr), + count(chr) %>% + arrange(chr) %>% + collect(), df %>% filter(int > 6, int < 108) %>% count(chr) diff --git a/r/tests/testthat/test-dplyr-count-tally.R b/r/tests/testthat/test-dplyr-count-tally.R new file mode 100644 index 00000000000..1a852e19999 --- /dev/null +++ b/r/tests/testthat/test-dplyr-count-tally.R @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +skip_if_not_available("dataset") + +library(dplyr, warn.conflicts = FALSE) + +tbl <- example_data +tbl$some_grouping <- rep(c(1, 2), 5) + +test_that("count/tally", { + expect_dplyr_equal( + input %>% + count() %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + tally() %>% + collect(), + tbl + ) +}) + +test_that("count/tally with wt and grouped data", { + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + count(wt = int) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + tally(wt = int) %>% + collect(), + tbl + ) +}) + +test_that("count/tally with sort", { + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + count(wt = int, sort = TRUE) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + tally(wt = int, sort = TRUE) %>% + collect(), + tbl + ) +}) + +test_that("count/tally with name arg", { + expect_dplyr_equal( + input %>% + count(name = "new_col") %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + tally(name = "new_col") %>% + collect(), + tbl + ) +})