Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method("!=",ArrowObject)
S3method("$",ArrowTabular)
S3method("$",Expression)
S3method("$",Schema)
S3method("$",StructArray)
S3method("$",SubTreeFileSystem)
Expand All @@ -14,6 +15,7 @@ S3method("[",Dataset)
S3method("[",Schema)
S3method("[",arrow_dplyr_query)
S3method("[[",ArrowTabular)
S3method("[[",Expression)
S3method("[[",Schema)
S3method("[[",StructArray)
S3method("[[<-",ArrowTabular)
Expand Down Expand Up @@ -137,6 +139,7 @@ S3method(names,Scanner)
S3method(names,ScannerBuilder)
S3method(names,Schema)
S3method(names,StructArray)
S3method(names,StructType)
S3method(names,Table)
S3method(names,arrow_dplyr_query)
S3method(print,"arrow-enum")
Expand Down
2 changes: 1 addition & 1 deletion r/R/arrow-object.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ ArrowObject <- R6Class("ArrowObject",
assign(".:xp:.", xp, envir = self)
},
class_title = function() {
if (!is.null(self$.class_title)) {
if (".class_title" %in% ls(self, all.names = TRUE)) {
# Allow subclasses to override just printing the class name first
class_title <- self$.class_title()
} else {
Expand Down
9 changes: 8 additions & 1 deletion r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 55 additions & 0 deletions r/R/expression.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ Expression <- R6Class("Expression",
assert_that(!is.null(schema))
compute___expr__type_id(self, schema)
},
is_field_ref = function() {
compute___expr__is_field_ref(self)
},
cast = function(to_type, safe = TRUE, ...) {
opts <- cast_options(safe, ...)
opts$to_type <- as_type(to_type)
Expand Down Expand Up @@ -89,7 +92,59 @@ Expression$create <- function(function_name,
expr
}


# Double-bracket extraction on an Expression: hands `x` and `i` to
# get_nested_field(). Anything passed through `...` is ignored.
#' @export
`[[.Expression` <- function(x, i, ...) {
  get_nested_field(x, i)
}

# `$` accessor for Expression objects.
#
# `name` must be a single string. When it matches one of the names reported by
# ls(x), that member of the object is returned. Any other name is treated as a
# request for a nested field and is forwarded to get_nested_field().
#' @export
`$.Expression` <- function(x, name, ...) {
  assert_that(is.string(name))
  is_member <- name %in% ls(x)
  if (!is_member) {
    return(get_nested_field(x, name))
  }
  get(name, x)
}

# Extract the field called `name` from the Expression `expr`.
#
# * If `expr` is a field reference, a nested field reference is built from it.
# * Otherwise the type of `expr` is looked up; if it is a StructType the
#   "struct_field" compute function is used, after checking that `name` is one
#   of the struct's field names.
#
# `name` must be a single string in both cases. Returns a new Expression whose
# `schema` is copied from `expr`. Stops with an error when the field is not in
# the struct, or when `expr` is neither a field reference nor struct-typed.
get_nested_field <- function(expr, name) {
  if (expr$is_field_ref()) {
    # Make a nested field ref
    # TODO(#33756): integer (positional) field refs are supported in C++
    assert_that(is.string(name))
    out <- compute___expr__nested_field_ref(expr, name)
  } else {
    # Use the struct_field kernel if expr is a struct:
    expr_type <- expr$type() # errors if no schema set
    if (!inherits(expr_type, "StructType")) {
      # TODO(#33757): if expr is list type and name is integer or Expression,
      # call list_element
      stop(
        "Cannot extract a field from an Expression of type ", expr_type$ToString(),
        call. = FALSE
      )
    }
    # Validate `name` here too: `[[` passes its index through unchecked, and a
    # non-string or length > 1 value would otherwise reach the `if` below and
    # fail with an unrelated or misleading message.
    assert_that(is.string(name))
    # Because we have the type, we can validate that the field exists
    if (!(name %in% names(expr_type))) {
      stop(
        "field '", name, "' not found in ",
        expr_type$ToString(),
        call. = FALSE
      )
    }
    out <- Expression$create(
      "struct_field",
      expr,
      options = list(field_ref = Expression$field_ref(name))
    )
  }
  # Schema bookkeeping
  out$schema <- expr$schema
  out
}

# Construct a field-reference Expression from a column name.
# `name` must be a single string; the reference itself is created by
# compute___expr__field_ref().
Expression$field_ref <- function(name) {
  # TODO(#33756): allow construction of field ref from integer
  assert_that(is.string(name))
  ref <- compute___expr__field_ref(name)
  ref
}
Expand Down
3 changes: 3 additions & 0 deletions r/R/type.R
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ StructType$create <- function(...) struct__(.fields(list(...)))
#' @export
struct <- StructType$create

# names() method for StructType objects: returns the result of
# StructType__field_names() for `x`.
#' @export
names.StructType <- function(x) {
  StructType__field_names(x)
}

ListType <- R6Class("ListType",
inherit = NestedType,
public = list(
Expand Down
19 changes: 19 additions & 0 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,20 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return out;
}

if (func_name == "struct_field") {
using Options = arrow::compute::StructFieldOptions;
if (!Rf_isNull(options["indices"])) {
return std::make_shared<Options>(
cpp11::as_cpp<std::vector<int>>(options["indices"]));
} else {
// field_ref
return std::make_shared<Options>(
*cpp11::as_cpp<std::shared_ptr<arrow::compute::Expression>>(
options["field_ref"])
->field_ref());
}
}

return nullptr;
}

Expand Down
40 changes: 38 additions & 2 deletions r/src/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,26 @@ std::shared_ptr<compute::Expression> compute___expr__call(std::string func_name,
compute::call(std::move(func_name), std::move(arguments), std::move(options_ptr)));
}

// Report whether `x` is a field reference, i.e. whether x->field_ref() yields
// a non-null pointer.
// [[arrow::export]]
bool compute___expr__is_field_ref(const std::shared_ptr<compute::Expression>& x) {
  const auto* ref = x->field_ref();
  return ref != nullptr;
}

// [[arrow::export]]
std::vector<std::string> field_names_in_expression(
const std::shared_ptr<compute::Expression>& x) {
std::vector<std::string> out;
std::vector<arrow::FieldRef> nested;

auto field_refs = FieldsInExpression(*x);
for (auto f : field_refs) {
out.push_back(*f.name());
if (f.IsNested()) {
// We keep the top-level field name.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not be used in practice (in a mutate call where you select the field, you also directly specify the resulting column name), but otherwise it might also make sense to keep the innermost field name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is only used to prune columns in the dataset scanner, and IIRC that interface accepts column names, not FieldRefs, so I need the names of the top-level columns. But if I'm mistaken and we can use FieldRefs there now, we can refactor this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also specify field refs (well, generic expressions), but then you also need to pass the resulting name for the schema. See the second Project signature at

/// \brief Set the subset of columns to materialize.
///
/// Columns which are not referenced may not be read from fragments.
///
/// \param[in] columns list of columns to project. Order and duplicates will
/// be preserved.
///
/// \return Failure if any column name does not exist in the dataset's
/// Schema.
Status Project(std::vector<std::string> columns);
/// \brief Set expressions which will be evaluated to produce the materialized
/// columns.
///
/// Columns which are not referenced may not be read from fragments.
///
/// \param[in] exprs expressions to evaluate to produce columns.
/// \param[in] names list of names for the resulting columns.
///
/// \return Failure if any referenced column does not exist in the dataset's
/// Schema.
Status Project(std::vector<compute::Expression> exprs, std::vector<std::string> names);

which gets translated to ScanOptions.projection. It seems that is also what the R bindings actually do inside ExecNode_Scan (it will convert the materialized_field_names back to FieldRefs). Now, the scanner itself will also just use the top-level name of a nested field ref to do pruning of what it needs to read, so right now preserving the nested field ref is not useful. But ideally in the future we would optimize that for formats that can do that (like parquet, cfr #33167)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointers. I've deferred cleaning this up to #33760 since I see a few places where it could be more involved than just deleting code.

nested = *f.nested_refs();
out.push_back(*nested[0].name());
} else {
out.push_back(*f.name());
}
}
return out;
}
Expand All @@ -61,7 +74,11 @@ std::vector<std::string> field_names_in_expression(
std::string compute___expr__get_field_ref_name(
    const std::shared_ptr<compute::Expression>& x) {
  // Returns the name of a simple (non-nested) field reference, or an empty
  // string when `x` is not a field reference or refers to a nested field.
  if (auto field_ref = x->field_ref()) {
    // Exclude nested field refs because we only use this to determine if we have simple
    // field refs. The IsNested() check must come before name() is dereferenced:
    // an unconditional early return here would make this guard unreachable.
    if (!field_ref->IsNested()) {
      return *field_ref->name();
    }
  }
  return "";
}
Expand All @@ -71,6 +88,25 @@ std::shared_ptr<compute::Expression> compute___expr__field_ref(std::string name)
return std::make_shared<compute::Expression>(compute::field_ref(std::move(name)));
}

// Build a field-reference Expression that extends the reference held by `x`
// with one more component, `name`. Stops with an R error when `x` does not
// hold a field reference.
// [[arrow::export]]
std::shared_ptr<compute::Expression> compute___expr__nested_field_ref(
    const std::shared_ptr<compute::Expression>& x, std::string name) {
  auto parent = x->field_ref();
  if (!parent) {
    cpp11::stop("'x' must be a FieldRef Expression");
  }

  // Start from the existing path: the single reference itself, or the
  // components of a reference that is already nested.
  std::vector<arrow::FieldRef> path;
  if (!parent->IsNested()) {
    path.push_back(*parent);
  } else {
    path = *parent->nested_refs();
  }

  // Append the requested child and wrap the whole path in a new Expression.
  path.push_back(arrow::FieldRef(std::move(name)));
  return std::make_shared<compute::Expression>(compute::field_ref(std::move(path)));
}

// [[arrow::export]]
std::shared_ptr<compute::Expression> compute___expr__scalar(
const std::shared_ptr<arrow::Scalar>& x) {
Expand Down
70 changes: 70 additions & 0 deletions r/tests/testthat/test-dplyr-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,73 @@ test_that("Scalars in expressions match the type of the field, if possible", {
collect()
expect_equal(result$tpc_h_1, result$as_dbl)
})

# Exercises `$` on a struct column (`df_col$a`) inside mutate(), both directly
# and as part of an arithmetic expression, followed by filter() and optionally
# summarize().
test_that("Can use nested field refs", {
  # A tibble with a nested data-frame column, so `df_col$a` is a nested field.
  nested_data <- tibble(int = 1:5, df_col = tibble(a = 6:10, b = 11:15))

  # mutate + filter on the nested field.
  # NOTE(review): compare_dplyr_binding() is defined elsewhere; presumably it
  # checks the pipeline gives the same result on Arrow data and on the
  # plain data frame -- confirm against the helper.
  compare_dplyr_binding(
    .input %>%
      mutate(
        nested = df_col$a,
        times2 = df_col$a * 2
      ) %>%
      filter(nested > 7) %>%
      collect(),
    nested_data
  )

  # Same pipeline, with an aggregation over the derived column added.
  compare_dplyr_binding(
    .input %>%
      mutate(
        nested = df_col$a,
        times2 = df_col$a * 2
      ) %>%
      filter(nested > 7) %>%
      summarize(sum(times2)) %>%
      collect(),
    nested_data
  )

  # Now with Dataset: make sure column pushdown in ScanNode works
  # The expected value is the same pipeline applied to the tibble itself.
  expect_equal(
    nested_data %>%
      InMemoryDataset$create() %>%
      mutate(
        nested = df_col$a,
        times2 = df_col$a * 2
      ) %>%
      filter(nested > 7) %>%
      collect(),
    nested_data %>%
      mutate(
        nested = df_col$a,
        times2 = df_col$a * 2
      ) %>%
      filter(nested > 7)
  )
})

# Here `df_col` is created inside the query by mutate(), so `df_col$i` and
# `df_col$d` are applied to a computed expression instead of a plain column
# reference (the case named in the test title).
test_that("Use struct_field for $ on non-field-ref", {
  compare_dplyr_binding(
    .input %>%
      # Build a struct column from two existing columns...
      mutate(
        df_col = tibble(i = int, d = dbl)
      ) %>%
      # ...then pull its fields back out with `$`.
      transmute(
        int2 = df_col$i,
        dbl2 = df_col$d
      ) %>%
      collect(),
    example_data
  )
})

# Using `$` on the `int` column of example_data to ask for a field called
# `nested` must fail with an error matching "No match" once the query is
# computed.
test_that("nested field ref error handling", {
  expect_error(
    compute(
      mutate(
        arrow_table(example_data),
        x = int$nested
      )
    ),
    "No match"
  )
})
26 changes: 26 additions & 0 deletions r/tests/testthat/test-expression.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ test_that("Field reference expression schemas and types", {
expect_equal(x$type(), int32())
})

# `$` and `[[` on a field-reference Expression both yield Expressions, and can
# be chained; `$` on an int32 scalar Expression is an error.
test_that("Nested field refs", {
  top <- Expression$field_ref("x")
  child <- top$y
  expect_r6_class(child, "Expression")
  expect_r6_class(top[["y"]], "Expression")
  expect_r6_class(child$z, "Expression")
  expect_error(
    Expression$scalar(42L)$y,
    "Cannot extract a field from an Expression of type int32"
  )
})

test_that("Scalar expression schemas and types", {
# type() works on scalars without setting the schema
expect_equal(
Expand Down Expand Up @@ -127,3 +136,20 @@ test_that("Expression schemas and types", {
int32()
)
})

# The type of a nested field ref is resolved against a schema with a struct
# column, both when the schema is passed to type() and when it is attached to
# the parent Expression beforehand.
test_that("Nested field ref types", {
  struct_schema <- schema(x = struct(y = int32(), z = double()))
  y_of_x <- Expression$field_ref("x")$y
  expect_equal(y_of_x$type(struct_schema), int32())
  # implicit casting and schema propagation
  parent <- Expression$field_ref("x")
  parent$schema <- struct_schema
  doubled <- parent$y * 2
  expect_equal(doubled$type(), int32())
})

# `$` on a scalar Expression built from a data frame: an existing field comes
# back as an Expression with that field's type, and a missing field raises a
# "not found" error that names the struct type.
test_that("Nested field from a non-field-ref (struct_field kernel)", {
  x <- Expression$scalar(data.frame(a = 1, b = "two"))
  # Use the same class expectation as the other Expression tests in this file;
  # it reports the actual class on failure instead of a bare "not TRUE".
  expect_r6_class(x$a, "Expression")
  expect_equal(x$a$type(), float64())
  expect_error(x$c, "field 'c' not found in struct<a: double, b: string>")
})