From 34952fafa2b7353ba05b5666ce3433cd7684f8ea Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 10 Jan 2023 13:40:37 -0500 Subject: [PATCH 1/9] Proof of concept building a nested field ref --- r/R/arrow-object.R | 2 +- r/R/arrowExports.R | 4 ++++ r/R/expression.R | 18 ++++++++++++++++++ r/src/arrowExports.cpp | 10 ++++++++++ r/src/expression.cpp | 19 +++++++++++++++++++ r/tests/testthat/test-expression.R | 12 ++++++++++++ 6 files changed, 64 insertions(+), 1 deletion(-) diff --git a/r/R/arrow-object.R b/r/R/arrow-object.R index 516f407aafd..0fd822cf864 100644 --- a/r/R/arrow-object.R +++ b/r/R/arrow-object.R @@ -32,7 +32,7 @@ ArrowObject <- R6Class("ArrowObject", assign(".:xp:.", xp, envir = self) }, class_title = function() { - if (!is.null(self$.class_title)) { + if (".class_title" %in% ls(self)) { # Allow subclasses to override just printing the class name first class_title <- self$.class_title() } else { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 38f1ecfb971..72b409d8e64 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1096,6 +1096,10 @@ compute___expr__field_ref <- function(name) { .Call(`_arrow_compute___expr__field_ref`, name) } +compute___expr__nested_field_ref <- function(x, name) { + .Call(`_arrow_compute___expr__nested_field_ref`, x, name) +} + compute___expr__scalar <- function(x) { .Call(`_arrow_compute___expr__scalar`, x) } diff --git a/r/R/expression.R b/r/R/expression.R index a1163c12a85..ab0257e8051 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -89,6 +89,24 @@ Expression$create <- function(function_name, expr } + +#' @export +`[[.Expression` <- function(x, i, ...) { + # TODO: integer (positional) field refs are supported in C++ + assert_that(is.string(name)) + compute___expr__nested_field_ref(x, i) +} + +#' @export +`$.Expression` <- function(x, name, ...) { + assert_that(is.string(name)) + if (name %in% ls(x)) { + get(name, x) + } else { + compute___expr__nested_field_ref(x, name) + } +} + Expression$field_ref <- function(name) { assert_that(is.string(name)) compute___expr__field_ref(name) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index b7bda1870f9..4cd9cd24720 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2756,6 +2756,15 @@ BEGIN_CPP11 END_CPP11 } // expression.cpp +std::shared_ptr compute___expr__nested_field_ref(const std::shared_ptr& x, std::string name); +extern "C" SEXP _arrow_compute___expr__nested_field_ref(SEXP x_sexp, SEXP name_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type x(x_sexp); + arrow::r::Input::type name(name_sexp); + return cpp11::as_sexp(compute___expr__nested_field_ref(x, name)); +END_CPP11 +} +// expression.cpp std::shared_ptr compute___expr__scalar(const std::shared_ptr& x); extern "C" SEXP _arrow_compute___expr__scalar(SEXP x_sexp){ BEGIN_CPP11 @@ -5572,6 +5581,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_field_names_in_expression", (DL_FUNC) &_arrow_field_names_in_expression, 1}, { "_arrow_compute___expr__get_field_ref_name", (DL_FUNC) &_arrow_compute___expr__get_field_ref_name, 1}, { "_arrow_compute___expr__field_ref", (DL_FUNC) &_arrow_compute___expr__field_ref, 1}, + { "_arrow_compute___expr__nested_field_ref", (DL_FUNC) &_arrow_compute___expr__nested_field_ref, 2}, { "_arrow_compute___expr__scalar", (DL_FUNC) &_arrow_compute___expr__scalar, 1}, { "_arrow_compute___expr__ToString", (DL_FUNC) &_arrow_compute___expr__ToString, 1}, { "_arrow_compute___expr__type", (DL_FUNC) &_arrow_compute___expr__type, 2}, diff --git a/r/src/expression.cpp b/r/src/expression.cpp index a845137e09d..1530d541fdd 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -71,6 +71,25 @@ std::shared_ptr compute___expr__field_ref(std::string name) return std::make_shared(compute::field_ref(std::move(name))); } +// [[arrow::export]] +std::shared_ptr compute___expr__nested_field_ref( + const std::shared_ptr& x, std::string name) { + if (auto field_ref = x->field_ref()) { + std::vector ref_vec; + if (field_ref->IsNested()) { + ref_vec = *field_ref->nested_refs(); + } else { + // There's just one + ref_vec.push_back(*field_ref); + } + // Add the new ref + ref_vec.push_back(arrow::FieldRef(std::move(name))); + return std::make_shared(compute::field_ref(std::move(ref_vec))); + } else { + // error + } +} + // [[arrow::export]] std::shared_ptr compute___expr__scalar( const std::shared_ptr& x) { diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 2b6039b04ce..aee9f9c3933 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -76,6 +76,18 @@ test_that("Field reference expression schemas and types", { expect_equal(x$type(), int32()) }) +test_that("Nested field refs", { + x <- Expression$field_ref("x") + nested <- x$y + # TODO: + # * ToString if nested (not `FieldRef.Nested(FieldRef.Name(x) FieldRef.Name(y))`) + # * field_name method? + # * Error on trying to make nested field ref with non field ref + # * Test with [[ too + # * dplyr tests + print(nested) +}) + test_that("Scalar expression schemas and types", { # type() works on scalars without setting the schema expect_equal( From b2499e1d1c52f4200dbd1d85d38b3c6b7ea60d4a Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 10 Jan 2023 14:09:22 -0500 Subject: [PATCH 2/9] Basic test of using in dplyr --- r/src/expression.cpp | 7 ++++++- r/tests/testthat/test-dplyr-mutate.R | 12 ++++++++++++ r/tests/testthat/test-expression.R | 7 ++++--- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 1530d541fdd..1791684a64c 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -61,7 +61,11 @@ std::vector field_names_in_expression( std::string compute___expr__get_field_ref_name( const std::shared_ptr& x) { if (auto field_ref = x->field_ref()) { - return *field_ref->name(); + // Exclude nested field refs because we only use this to determine if we have simple + // field refs + if (!field_ref->IsNested()) { + return *field_ref->name(); + } } return ""; } @@ -98,6 +102,7 @@ std::shared_ptr compute___expr__scalar( // [[arrow::export]] std::string compute___expr__ToString(const std::shared_ptr& x) { + // TODO: something different if is field ref and IsNested? return x->ToString(); } diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 5d431089ce7..e94e5cf63d9 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -650,3 +650,15 @@ test_that("Can use across() within transmute()", { example_data ) }) + +test_that("Can use nested field refs", { + compare_dplyr_binding( + .input %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + collect(), + tibble(int = 1:5, df_col = tibble(a = 6:10, b = 11:15)) + ) +}) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index aee9f9c3933..cba47259f46 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -80,11 +80,12 @@ test_that("Nested field refs", { x <- Expression$field_ref("x") nested <- x$y # TODO: - # * ToString if nested (not `FieldRef.Nested(FieldRef.Name(x) FieldRef.Name(y))`) - # * field_name method? + # * ToString if nested (not `FieldRef.Nested(FieldRef.Name(x) FieldRef.Name(y))`)? + # * field_names_in_expression? any other places where it is assumed field refs have a single name? # * Error on trying to make nested field ref with non field ref # * Test with [[ too - # * dplyr tests + # * R Expression method to determine if is field ref? + # * more dplyr tests, including print print(nested) }) From b582cdec714cee455bf4fb33009060195ed92af2 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 11 Jan 2023 09:42:24 -0500 Subject: [PATCH 3/9] More tests and fixes --- r/R/expression.R | 12 ++++++-- r/src/expression.cpp | 14 +++++++-- r/tests/testthat/test-dplyr-mutate.R | 12 -------- r/tests/testthat/test-dplyr-query.R | 45 ++++++++++++++++++++++++++++ r/tests/testthat/test-expression.R | 23 +++++++++----- 5 files changed, 80 insertions(+), 26 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index ab0257e8051..6f585d9d3ee 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -93,8 +93,11 @@ Expression$create <- function(function_name, #' @export `[[.Expression` <- function(x, i, ...) { # TODO: integer (positional) field refs are supported in C++ - assert_that(is.string(name)) - compute___expr__nested_field_ref(x, i) + assert_that(is.string(i)) + out <- compute___expr__nested_field_ref(x, i) + # Schema bookkeeping + out$schema <- x$schema + out } #' @export @@ -103,7 +106,10 @@ Expression$create <- function(function_name, if (name %in% ls(x)) { get(name, x) } else { - compute___expr__nested_field_ref(x, name) + out <- compute___expr__nested_field_ref(x, name) + # Schema bookkeeping + out$schema <- x$schema + out } } diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 1791684a64c..fec16034bcb 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -19,7 +19,7 @@ #include #include - +#include namespace compute = ::arrow::compute; std::shared_ptr make_compute_options(std::string func_name, @@ -50,9 +50,17 @@ std::shared_ptr compute___expr__call(std::string func_name, std::vector field_names_in_expression( const std::shared_ptr& x) { std::vector out; + std::vector nested; + auto field_refs = FieldsInExpression(*x); for (auto f : field_refs) { - out.push_back(*f.name()); + if (f.IsNested()) { + // We keep the top-level field name. + nested = *f.nested_refs(); + out.push_back(*nested[0].name()); + } else { + out.push_back(*f.name()); + } } return out; } @@ -90,7 +98,7 @@ std::shared_ptr compute___expr__nested_field_ref( ref_vec.push_back(arrow::FieldRef(std::move(name))); return std::make_shared(compute::field_ref(std::move(ref_vec))); } else { - // error + cpp11::stop("'x' must be a FieldRef Expression"); } } diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index e94e5cf63d9..5d431089ce7 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -650,15 +650,3 @@ test_that("Can use across() within transmute()", { example_data ) }) - -test_that("Can use nested field refs", { - compare_dplyr_binding( - .input %>% - mutate( - nested = df_col$a, - times2 = df_col$a * 2 - ) %>% - collect(), - tibble(int = 1:5, df_col = tibble(a = 6:10, b = 11:15)) - ) -}) diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index ee11cd6678b..7f7eb685dc7 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -714,3 +714,48 @@ test_that("Scalars in expressions match the type of the field, if possible", { collect() expect_equal(result$tpc_h_1, result$as_dbl) }) + +test_that("Can use nested field refs", { + nested_data <- tibble(int = 1:5, df_col = tibble(a = 6:10, b = 11:15)) + + compare_dplyr_binding( + .input %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) %>% + collect(), + nested_data + ) + + compare_dplyr_binding( + .input %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) %>% + summarize(sum(times2)) %>% + collect(), + nested_data + ) + + # Now with Dataset + expect_equal( + nested_data %>% + InMemoryDataset$create() %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) %>% + collect(), + nested_data %>% + mutate( + nested = df_col$a, + times2 = df_col$a * 2 + ) %>% + filter(nested > 7) + ) +}) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index cba47259f46..a47a98c6fa3 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -79,14 +79,11 @@ test_that("Field reference expression schemas and types", { test_that("Nested field refs", { x <- Expression$field_ref("x") nested <- x$y - # TODO: - # * ToString if nested (not `FieldRef.Nested(FieldRef.Name(x) FieldRef.Name(y))`)? - # * field_names_in_expression? any other places where it is assumed field refs have a single name? - # * Error on trying to make nested field ref with non field ref - # * Test with [[ too - # * R Expression method to determine if is field ref? - # * more dplyr tests, including print - print(nested) + expect_r6_class(nested, "Expression") + expect_r6_class(x[["y"]], "Expression") + expect_r6_class(nested$z, "Expression") + # Should this instead be NULL? + expect_error(Expression$scalar(42L)$y, "'x' must be a FieldRef Expression") }) test_that("Scalar expression schemas and types", { @@ -140,3 +137,13 @@ test_that("Expression schemas and types", { int32() ) }) + +test_that("Nested field ref types", { + nested <- Expression$field_ref("x")$y + schm <- schema(x = struct(y = int32(), z = double())) + expect_equal(nested$type(schm), int32()) + # implicit casting and schema propagation + x <- Expression$field_ref("x") + x$schema <- schm + expect_equal((x$y * 2)$type(), int32()) +}) From 080cad07a9ce7df044aa633a8819bcb1a15742aa Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 11 Jan 2023 10:05:59 -0500 Subject: [PATCH 4/9] Clean up --- r/src/expression.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/expression.cpp b/r/src/expression.cpp index fec16034bcb..00892db3ecc 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -19,7 +19,7 @@ #include #include -#include + namespace compute = ::arrow::compute; std::shared_ptr make_compute_options(std::string func_name, From 9afafd6961dfb25c22077b0ecf7e18e424b41897 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 11 Jan 2023 10:50:14 -0500 Subject: [PATCH 5/9] Clean up comments and update NAMESPACE --- r/NAMESPACE | 2 ++ r/src/expression.cpp | 1 - r/tests/testthat/test-dplyr-query.R | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index 3df107a2d8f..bd742e2418d 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -2,6 +2,7 @@ S3method("!=",ArrowObject) S3method("$",ArrowTabular) +S3method("$",Expression) S3method("$",Schema) S3method("$",StructArray) S3method("$",SubTreeFileSystem) @@ -14,6 +15,7 @@ S3method("[",Dataset) S3method("[",Schema) S3method("[",arrow_dplyr_query) S3method("[[",ArrowTabular) +S3method("[[",Expression) S3method("[[",Schema) S3method("[[",StructArray) S3method("[[<-",ArrowTabular) diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 00892db3ecc..716ca0daed4 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -110,7 +110,6 @@ std::shared_ptr compute___expr__scalar( // [[arrow::export]] std::string compute___expr__ToString(const std::shared_ptr& x) { - // TODO: something different if is field ref and IsNested? return x->ToString(); } diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index 7f7eb685dc7..adf52830597 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -741,7 +741,7 @@ test_that("Can use nested field refs", { nested_data ) - # Now with Dataset + # Now with Dataset: make sure column pushdown in ScanNode works expect_equal( nested_data %>% InMemoryDataset$create() %>% From bbb317c18facc4fc1c6cb24a4c726e369149dc11 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 12 Jan 2023 09:02:53 -0500 Subject: [PATCH 6/9] Fix test and add error handling test --- r/R/arrow-object.R | 2 +- r/tests/testthat/test-dplyr-query.R | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/r/R/arrow-object.R b/r/R/arrow-object.R index 0fd822cf864..5c2cf4691fc 100644 --- a/r/R/arrow-object.R +++ b/r/R/arrow-object.R @@ -32,7 +32,7 @@ ArrowObject <- R6Class("ArrowObject", assign(".:xp:.", xp, envir = self) }, class_title = function() { - if (".class_title" %in% ls(self)) { + if (".class_title" %in% ls(self, all.names = TRUE)) { # Allow subclasses to override just printing the class name first class_title <- self$.class_title() } else { diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index adf52830597..7c2a8ae356d 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -759,3 +759,12 @@ test_that("Can use nested field refs", { filter(nested > 7) ) }) + +test_that("nested field ref error handling", { + expect_error( + example_data %>% + arrow_table() %>% + mutate(x = int$nested) %>% + compute() + ) +}) From 202915133eaa00bcf4cfb6093b946007f4193edc Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 12 Jan 2023 09:58:19 -0500 Subject: [PATCH 7/9] Use struct_field kernel on non-field-refs --- r/NAMESPACE | 1 + r/R/arrowExports.R | 5 +++- r/R/expression.R | 45 ++++++++++++++++++++++++----- r/R/type.R | 3 ++ r/src/arrowExports.cpp | 9 ++++++ r/src/compute.cpp | 5 ++++ r/src/expression.cpp | 5 ++++ r/tests/testthat/test-dplyr-query.R | 15 ++++++++++ r/tests/testthat/test-expression.R | 10 +++++-- 9 files changed, 87 insertions(+), 11 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index bd742e2418d..3ab828a9587 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -139,6 +139,7 @@ S3method(names,Scanner) S3method(names,ScannerBuilder) S3method(names,Schema) S3method(names,StructArray) +S3method(names,StructType) S3method(names,Table) S3method(names,arrow_dplyr_query) S3method(print,"arrow-enum") diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 72b409d8e64..2eeca24dbdc 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1084,6 +1084,10 @@ compute___expr__call <- function(func_name, argument_list, options) { .Call(`_arrow_compute___expr__call`, func_name, argument_list, options) } +compute___expr__is_field_ref <- function(x) { + .Call(`_arrow_compute___expr__is_field_ref`, x) +} + field_names_in_expression <- function(x) { .Call(`_arrow_field_names_in_expression`, x) } @@ -2091,4 +2095,3 @@ SetIOThreadPoolCapacity <- function(threads) { Array__infer_type <- function(x) { .Call(`_arrow_Array__infer_type`, x) } - diff --git a/r/R/expression.R b/r/R/expression.R index 6f585d9d3ee..fc4e5b49a1c 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -57,6 +57,9 @@ Expression <- R6Class("Expression", assert_that(!is.null(schema)) compute___expr__type_id(self, schema) }, + is_field_ref = function() { + compute___expr__is_field_ref(self) + }, cast = function(to_type, safe = TRUE, ...) { opts <- cast_options(safe, ...) opts$to_type <- as_type(to_type) @@ -94,10 +97,7 @@ Expression$create <- function(function_name, `[[.Expression` <- function(x, i, ...) { # TODO: integer (positional) field refs are supported in C++ assert_that(is.string(i)) - out <- compute___expr__nested_field_ref(x, i) - # Schema bookkeeping - out$schema <- x$schema - out + get_nested_field(x, i) } #' @export @@ -106,13 +106,42 @@ Expression$create <- function(function_name, if (name %in% ls(x)) { get(name, x) } else { - out <- compute___expr__nested_field_ref(x, name) - # Schema bookkeeping - out$schema <- x$schema - out + get_nested_field(x, name) } } +get_nested_field <- function(expr, name) { + if (expr$is_field_ref()) { + # Make a nested field ref + out <- compute___expr__nested_field_ref(expr, name) + } else { + # Use the struct_field kernel, but that only works if: + # * expr has a knowable type (has a schema set) + # * that type is struct + # * `name` exists in the struct (bc we have to map to an integer position) + expr_type <- expr$type() # errors if no schema set + if (inherits(expr_type, "StructType")) { + ind <- match(name, names(expr_type)) - 1L + if (is.na(ind)) { + stop( + "field '", name, "' not found in ", + expr_type$ToString(), + call. = FALSE + ) + } + out <- Expression$create("struct_field", expr, options = list(indices = ind)) + } else { + stop( + "Cannot extract a field from an Expression of type ", expr_type$ToString(), + call. = FALSE + ) + } + } + # Schema bookkeeping + out$schema <- expr$schema + out +} + Expression$field_ref <- function(name) { assert_that(is.string(name)) compute___expr__field_ref(name) diff --git a/r/R/type.R b/r/R/type.R index d1578dd822e..bd69311b258 100644 --- a/r/R/type.R +++ b/r/R/type.R @@ -641,6 +641,9 @@ StructType$create <- function(...) struct__(.fields(list(...))) #' @export struct <- StructType$create +#' @export +names.StructType <- function(x) StructType__field_names(x) + ListType <- R6Class("ListType", inherit = NestedType, public = list( diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 4cd9cd24720..e918390e269 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2732,6 +2732,14 @@ BEGIN_CPP11 END_CPP11 } // expression.cpp +bool compute___expr__is_field_ref(const std::shared_ptr& x); +extern "C" SEXP _arrow_compute___expr__is_field_ref(SEXP x_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type x(x_sexp); + return cpp11::as_sexp(compute___expr__is_field_ref(x)); +END_CPP11 +} +// expression.cpp std::vector field_names_in_expression(const std::shared_ptr& x); extern "C" SEXP _arrow_field_names_in_expression(SEXP x_sexp){ BEGIN_CPP11 @@ -5578,6 +5586,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_MapType__keys_sorted", (DL_FUNC) &_arrow_MapType__keys_sorted, 1}, { "_arrow_compute___expr__equals", (DL_FUNC) &_arrow_compute___expr__equals, 2}, { "_arrow_compute___expr__call", (DL_FUNC) &_arrow_compute___expr__call, 3}, + { "_arrow_compute___expr__is_field_ref", (DL_FUNC) &_arrow_compute___expr__is_field_ref, 1}, { "_arrow_field_names_in_expression", (DL_FUNC) &_arrow_field_names_in_expression, 1}, { "_arrow_compute___expr__get_field_ref_name", (DL_FUNC) &_arrow_compute___expr__get_field_ref_name, 1}, { "_arrow_compute___expr__field_ref", (DL_FUNC) &_arrow_compute___expr__field_ref, 1}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index b4b4c5fdc8d..e6c4820cbc3 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -564,6 +564,11 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "struct_field") { + using Options = arrow::compute::StructFieldOptions; + return std::make_shared(cpp11::as_cpp>(options["indices"])); + } + return nullptr; } diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 716ca0daed4..d7a511e7600 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -46,6 +46,11 @@ std::shared_ptr compute___expr__call(std::string func_name, compute::call(std::move(func_name), std::move(arguments), std::move(options_ptr))); } +// [[arrow::export]] +bool compute___expr__is_field_ref(const std::shared_ptr& x) { + return x->field_ref() != nullptr; +} + // [[arrow::export]] std::vector field_names_in_expression( const std::shared_ptr& x) { diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index 7c2a8ae356d..242610cea69 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -760,6 +760,21 @@ test_that("Can use nested field refs", { ) }) +test_that("Use struct_field for $ on non-field-ref", { + compare_dplyr_binding( + .input %>% + mutate( + df_col = tibble(i = int, d = dbl) + ) %>% + transmute( + int2 = df_col$i, + dbl2 = df_col$d + ) %>% + collect(), + example_data + ) +}) + test_that("nested field ref error handling", { expect_error( example_data %>% diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index a47a98c6fa3..ccb09b9eb00 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -82,8 +82,7 @@ test_that("Nested field refs", { expect_r6_class(nested, "Expression") expect_r6_class(x[["y"]], "Expression") expect_r6_class(nested$z, "Expression") - # Should this instead be NULL? - expect_error(Expression$scalar(42L)$y, "'x' must be a FieldRef Expression") + expect_error(Expression$scalar(42L)$y, "Cannot extract a field from an Expression of type int32") }) test_that("Scalar expression schemas and types", { @@ -147,3 +146,10 @@ test_that("Nested field ref types", { x$schema <- schm expect_equal((x$y * 2)$type(), int32()) }) + +test_that("Nested field from a non-field-ref (struct_field kernel)", { + x <- Expression$scalar(data.frame(a = 1, b = "two")) + expect_true(inherits(x$a, "Expression")) + expect_equal(x$a$type(), float64()) + expect_error(x$c, "field 'c' not found in struct") +}) From ded54d659a66a799ef584bba21773f492a72d4e5 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 18 Jan 2023 10:42:46 -0500 Subject: [PATCH 8/9] Add error expectation --- r/tests/testthat/test-dplyr-query.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index 242610cea69..a91c0b6ccb5 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -780,6 +780,7 @@ test_that("nested field ref error handling", { example_data %>% arrow_table() %>% mutate(x = int$nested) %>% - compute() + compute(), + "No match" ) }) From f5217fc60398f29311fe3c2e9d88a12c85a68565 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 18 Jan 2023 11:27:11 -0500 Subject: [PATCH 9/9] Create struct_field with FieldRef instead of integer; add TODO links --- r/R/expression.R | 26 ++++++++++++++------------ r/src/compute.cpp | 11 ++++++++++- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index fc4e5b49a1c..8f84b4b31ec 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -94,11 +94,7 @@ Expression$create <- function(function_name, #' @export -`[[.Expression` <- function(x, i, ...) { - # TODO: integer (positional) field refs are supported in C++ - assert_that(is.string(i)) - get_nested_field(x, i) -} +`[[.Expression` <- function(x, i, ...) get_nested_field(x, i) #' @export `$.Expression` <- function(x, name, ...) { @@ -113,24 +109,29 @@ Expression$create <- function(function_name, get_nested_field <- function(expr, name) { if (expr$is_field_ref()) { # Make a nested field ref + # TODO(#33756): integer (positional) field refs are supported in C++ + assert_that(is.string(name)) out <- compute___expr__nested_field_ref(expr, name) } else { - # Use the struct_field kernel, but that only works if: - # * expr has a knowable type (has a schema set) - # * that type is struct - # * `name` exists in the struct (bc we have to map to an integer position) + # Use the struct_field kernel if expr is a struct: expr_type <- expr$type() # errors if no schema set if (inherits(expr_type, "StructType")) { - ind <- match(name, names(expr_type)) - 1L - if (is.na(ind)) { + # Because we have the type, we can validate that the field exists + if (!(name %in% names(expr_type))) { stop( "field '", name, "' not found in ", expr_type$ToString(), call. = FALSE ) } - out <- Expression$create("struct_field", expr, options = list(indices = ind)) + out <- Expression$create( + "struct_field", + expr, + options = list(field_ref = Expression$field_ref(name)) + ) } else { + # TODO(#33757): if expr is list type and name is integer or Expression, + # call list_element stop( "Cannot extract a field from an Expression of type ", expr_type$ToString(), call. = FALSE @@ -143,6 +144,7 @@ get_nested_field <- function(expr, name) { } Expression$field_ref <- function(name) { + # TODO(#33756): allow construction of field ref from integer assert_that(is.string(name)) compute___expr__field_ref(name) } diff --git a/r/src/compute.cpp b/r/src/compute.cpp index e6c4820cbc3..578ce74d05d 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -566,7 +566,16 @@ std::shared_ptr make_compute_options( if (func_name == "struct_field") { using Options = arrow::compute::StructFieldOptions; - return std::make_shared(cpp11::as_cpp>(options["indices"])); + if (!Rf_isNull(options["indices"])) { + return std::make_shared( + cpp11::as_cpp>(options["indices"])); + } else { + // field_ref + return std::make_shared( + *cpp11::as_cpp>( + options["field_ref"]) + ->field_ref()); + } } return nullptr;