diff --git a/r/R/ChunkedArray.R b/r/R/ChunkedArray.R index 69a022494ac..fa9aaee1ca3 100644 --- a/r/R/ChunkedArray.R +++ b/r/R/ChunkedArray.R @@ -32,7 +32,7 @@ `arrow::ChunkedArray` <- R6Class("arrow::ChunkedArray", inherit = `arrow::Object`, public = list( length = function() ChunkedArray__length(self), - chunk = function(i) shared_ptr(`arrow::Array`, ChunkedArray__chunk(self, i)), + chunk = function(i) `arrow::Array`$dispatch(ChunkedArray__chunk(self, i)), as_vector = function() ChunkedArray__as_vector(self), Slice = function(offset, length = NULL){ if (is.null(length)) { @@ -50,7 +50,7 @@ active = list( null_count = function() ChunkedArray__null_count(self), num_chunks = function() ChunkedArray__num_chunks(self), - chunks = function() map(ChunkedArray__chunks(self), shared_ptr, class = `arrow::Array`), + chunks = function() map(ChunkedArray__chunks(self), ~ `arrow::Array`$dispatch(.x)), type = function() `arrow::DataType`$dispatch(ChunkedArray__type(self)) ) ) diff --git a/r/R/Struct.R b/r/R/Struct.R index ec786996c9d..820e1a895ef 100644 --- a/r/R/Struct.R +++ b/r/R/Struct.R @@ -18,11 +18,16 @@ #' @include R6.R `arrow::StructType` <- R6Class("arrow::StructType", - inherit = `arrow::NestedType` + inherit = `arrow::NestedType`, + public = list( + GetFieldByName = function(name) shared_ptr(`arrow::Field`, StructType__GetFieldByName(self, name)), + GetFieldIndex = function(name) StructType__GetFieldIndex(self, name) + ) ) #' @rdname DataType #' @export struct <- function(...){ - shared_ptr(`arrow::StructType`, struct_(.fields(list(...)))) + xp <- struct_(.fields(list(...))) + shared_ptr(`arrow::StructType`, xp) } diff --git a/r/R/array.R b/r/R/array.R index 244cee05aeb..b6e21ef8e69 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -103,10 +103,27 @@ ) ) +`arrow::DictionaryArray` <- R6Class("arrow::DictionaryArray", inherit = `arrow::Array`, + public = list( + indices = function() `arrow::Array`$dispatch(DictionaryArray__indices(self)), + dictionary = function() `arrow::Array`$dispatch(DictionaryArray__dictionary(self)) + ) +) + +`arrow::StructArray` <- R6Class("arrow::StructArray", inherit = `arrow::Array`, + public = list( + field = function(i) `arrow::Array`$dispatch(StructArray__field(self, i)), + GetFieldByName = function(name) `arrow::Array`$dispatch(StructArray__GetFieldByName(self, name)), + Flatten = function() map(StructArray__Flatten(self), ~ `arrow::Array`$dispatch(.x)) + ) +) + `arrow::Array`$dispatch <- function(xp){ a <- shared_ptr(`arrow::Array`, xp) if(a$type_id() == Type$DICTIONARY){ a <- shared_ptr(`arrow::DictionaryArray`, xp) + } else if (a$type_id() == Type$STRUCT) { + a <- shared_ptr(`arrow::StructArray`, xp) } a } @@ -126,11 +143,3 @@ array <- function(x, type = NULL){ `arrow::Array`$dispatch(Array__from_vector(x, type)) } - -`arrow::DictionaryArray` <- R6Class("arrow::DictionaryArray", inherit = `arrow::Array`, - public = list( - indices = function() `arrow::Array`$dispatch(DictionaryArray__indices(self)), - dictionary = function() `arrow::Array`$dispatch(DictionaryArray__dictionary(self)) - ) -) - diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 52ff4921063..8609f9b85f1 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -68,6 +68,18 @@ DictionaryArray__dictionary <- function(array){ .Call(`_arrow_DictionaryArray__dictionary` , array) } +StructArray__field <- function(array, i){ + .Call(`_arrow_StructArray__field` , array, i) +} + +StructArray__GetFieldByName <- function(array, name){ + .Call(`_arrow_StructArray__GetFieldByName` , array, name) +} + +StructArray__Flatten <- function(array){ + .Call(`_arrow_StructArray__Flatten` , array) +} + Array__as_vector <- function(array){ .Call(`_arrow_Array__as_vector` , array) } @@ -436,6 +448,14 @@ DictionaryType__ordered <- function(type){ .Call(`_arrow_DictionaryType__ordered` , type) } +StructType__GetFieldByName <- function(type, name){ + .Call(`_arrow_StructType__GetFieldByName` , type, name) +} + +StructType__GetFieldIndex <- function(type, name){ + .Call(`_arrow_StructType__GetFieldIndex` , type, name) +} + ipc___feather___TableWriter__SetDescription <- function(writer, description){ invisible(.Call(`_arrow_ipc___feather___TableWriter__SetDescription` , writer, description)) } diff --git a/r/src/array.cpp b/r/src/array.cpp index 60fd7da8b9a..35da4b1e4b3 100644 --- a/r/src/array.cpp +++ b/r/src/array.cpp @@ -119,4 +119,25 @@ std::shared_ptr DictionaryArray__dictionary( return array->dictionary(); } +// [[arrow::export]] +std::shared_ptr StructArray__field( + const std::shared_ptr& array, int i) { + return array->field(i); +} + +// [[arrow::export]] +std::shared_ptr StructArray__GetFieldByName( + const std::shared_ptr& array, const std::string& name) { + return array->GetFieldByName(name); +} + +// [[arrow::export]] +arrow::ArrayVector StructArray__Flatten( + const std::shared_ptr& array) { + int nf = array->num_fields(); + arrow::ArrayVector out(nf); + STOP_IF_NOT_OK(array->Flatten(arrow::default_memory_pool(), &out)); + return out; +} + #endif diff --git a/r/src/array__to_vector.cpp b/r/src/array__to_vector.cpp index 17d0600d78a..4e26f8d53f5 100644 --- a/r/src/array__to_vector.cpp +++ b/r/src/array__to_vector.cpp @@ -345,6 +345,65 @@ class Converter_Dictionary : public Converter { } }; +class Converter_Struct : public Converter { + public: + explicit Converter_Struct(const ArrayVector& arrays) : Converter(arrays), converters() { + auto first_array = + internal::checked_cast(Converter::arrays_[0].get()); + int nf = first_array->num_fields(); + for (int i = 0; i < nf; i++) { + converters.push_back(Converter::Make({first_array->field(i)})); + } + } + + SEXP Allocate(R_xlen_t n) const { + // allocate a data frame column to host each array + auto first_array = + internal::checked_cast(Converter::arrays_[0].get()); + auto type = first_array->struct_type(); + int nf = first_array->num_fields(); + Rcpp::List out(nf); + Rcpp::CharacterVector colnames(nf); + for (int i = 0; i < nf; i++) { + out[i] = converters[i]->Allocate(n); + colnames[i] = type->child(i)->name(); + } + IntegerVector rn(2); + rn[0] = NA_INTEGER; + rn[1] = -n; + Rf_setAttrib(out, symbols::row_names, rn); + Rf_setAttrib(out, R_NamesSymbol, colnames); + Rf_setAttrib(out, R_ClassSymbol, Rf_mkString("data.frame")); + return out; + } + + Status Ingest_all_nulls(SEXP data, R_xlen_t start, R_xlen_t n) const { + int nf = converters.size(); + for (int i = 0; i < nf; i++) { + STOP_IF_NOT_OK(converters[i]->Ingest_all_nulls(VECTOR_ELT(data, i), start, n)); + } + return Status::OK(); + } + + Status Ingest_some_nulls(SEXP data, const std::shared_ptr& array, + R_xlen_t start, R_xlen_t n) const { + auto struct_array = internal::checked_cast(array.get()); + int nf = converters.size(); + // Flatten() deals with merging of nulls + ArrayVector arrays(nf); + STOP_IF_NOT_OK(struct_array->Flatten(default_memory_pool(), &arrays)); + for (int i = 0; i < nf; i++) { + STOP_IF_NOT_OK( + converters[i]->Ingest_some_nulls(VECTOR_ELT(data, i), arrays[i], start, n)); + } + + return Status::OK(); + } + + private: + std::vector> converters; +}; + double ms_to_seconds(int64_t ms) { return static_cast(ms / 1000); } class Converter_Date64 : public Converter { @@ -599,6 +658,9 @@ std::shared_ptr Converter::Make(const ArrayVector& arrays) { case Type::DECIMAL: return std::make_shared(arrays); + case Type::STRUCT: + return std::make_shared(arrays); + default: break; } diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index f16179b144e..2352184ec83 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -270,6 +270,53 @@ RcppExport SEXP _arrow_DictionaryArray__dictionary(SEXP array_sexp){ } #endif +// array.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr StructArray__field(const std::shared_ptr& array, int i); +RcppExport SEXP _arrow_StructArray__field(SEXP array_sexp, SEXP i_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type array(array_sexp); + Rcpp::traits::input_parameter::type i(i_sexp); + return Rcpp::wrap(StructArray__field(array, i)); +END_RCPP +} +#else +RcppExport SEXP _arrow_StructArray__field(SEXP array_sexp, SEXP i_sexp){ + Rf_error("Cannot call StructArray__field(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + +// array.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr StructArray__GetFieldByName(const std::shared_ptr& array, const std::string& name); +RcppExport SEXP _arrow_StructArray__GetFieldByName(SEXP array_sexp, SEXP name_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type array(array_sexp); + Rcpp::traits::input_parameter::type name(name_sexp); + return Rcpp::wrap(StructArray__GetFieldByName(array, name)); +END_RCPP +} +#else +RcppExport SEXP _arrow_StructArray__GetFieldByName(SEXP array_sexp, SEXP name_sexp){ + Rf_error("Cannot call StructArray__GetFieldByName(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + +// array.cpp +#if defined(ARROW_R_WITH_ARROW) +arrow::ArrayVector StructArray__Flatten(const std::shared_ptr& array); +RcppExport SEXP _arrow_StructArray__Flatten(SEXP array_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type array(array_sexp); + return Rcpp::wrap(StructArray__Flatten(array)); +END_RCPP +} +#else +RcppExport SEXP _arrow_StructArray__Flatten(SEXP array_sexp){ + Rf_error("Cannot call StructArray__Flatten(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + // array__to_vector.cpp #if defined(ARROW_R_WITH_ARROW) SEXP Array__as_vector(const std::shared_ptr& array); @@ -1664,6 +1711,38 @@ RcppExport SEXP _arrow_DictionaryType__ordered(SEXP type_sexp){ } #endif +// datatype.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr StructType__GetFieldByName(const std::shared_ptr& type, const std::string& name); +RcppExport SEXP _arrow_StructType__GetFieldByName(SEXP type_sexp, SEXP name_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type type(type_sexp); + Rcpp::traits::input_parameter::type name(name_sexp); + return Rcpp::wrap(StructType__GetFieldByName(type, name)); +END_RCPP +} +#else +RcppExport SEXP _arrow_StructType__GetFieldByName(SEXP type_sexp, SEXP name_sexp){ + Rf_error("Cannot call StructType__GetFieldByName(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + +// datatype.cpp +#if defined(ARROW_R_WITH_ARROW) +int StructType__GetFieldIndex(const std::shared_ptr& type, const std::string& name); +RcppExport SEXP _arrow_StructType__GetFieldIndex(SEXP type_sexp, SEXP name_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type type(type_sexp); + Rcpp::traits::input_parameter::type name(name_sexp); + return Rcpp::wrap(StructType__GetFieldIndex(type, name)); +END_RCPP +} +#else +RcppExport SEXP _arrow_StructType__GetFieldIndex(SEXP type_sexp, SEXP name_sexp){ + Rf_error("Cannot call StructType__GetFieldIndex(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + // feather.cpp #if defined(ARROW_R_WITH_ARROW) void ipc___feather___TableWriter__SetDescription(const std::unique_ptr& writer, const std::string& description); @@ -3289,6 +3368,9 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Array__Mask", (DL_FUNC) &_arrow_Array__Mask, 1}, { "_arrow_DictionaryArray__indices", (DL_FUNC) &_arrow_DictionaryArray__indices, 1}, { "_arrow_DictionaryArray__dictionary", (DL_FUNC) &_arrow_DictionaryArray__dictionary, 1}, + { "_arrow_StructArray__field", (DL_FUNC) &_arrow_StructArray__field, 2}, + { "_arrow_StructArray__GetFieldByName", (DL_FUNC) &_arrow_StructArray__GetFieldByName, 2}, + { "_arrow_StructArray__Flatten", (DL_FUNC) &_arrow_StructArray__Flatten, 1}, { "_arrow_Array__as_vector", (DL_FUNC) &_arrow_Array__as_vector, 1}, { "_arrow_ChunkedArray__as_vector", (DL_FUNC) &_arrow_ChunkedArray__as_vector, 1}, { "_arrow_RecordBatch__to_dataframe", (DL_FUNC) &_arrow_RecordBatch__to_dataframe, 2}, @@ -3381,6 +3463,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_DictionaryType__value_type", (DL_FUNC) &_arrow_DictionaryType__value_type, 1}, { "_arrow_DictionaryType__name", (DL_FUNC) &_arrow_DictionaryType__name, 1}, { "_arrow_DictionaryType__ordered", (DL_FUNC) &_arrow_DictionaryType__ordered, 1}, + { "_arrow_StructType__GetFieldByName", (DL_FUNC) &_arrow_StructType__GetFieldByName, 2}, + { "_arrow_StructType__GetFieldIndex", (DL_FUNC) &_arrow_StructType__GetFieldIndex, 2}, { "_arrow_ipc___feather___TableWriter__SetDescription", (DL_FUNC) &_arrow_ipc___feather___TableWriter__SetDescription, 2}, { "_arrow_ipc___feather___TableWriter__SetNumRows", (DL_FUNC) &_arrow_ipc___feather___TableWriter__SetNumRows, 2}, { "_arrow_ipc___feather___TableWriter__Append", (DL_FUNC) &_arrow_ipc___feather___TableWriter__Append, 3}, diff --git a/r/src/arrow_types.h b/r/src/arrow_types.h index ca1e2d6bf7f..c93d4487f1d 100644 --- a/r/src/arrow_types.h +++ b/r/src/arrow_types.h @@ -31,6 +31,7 @@ struct symbols { static SEXP xp; static SEXP dot_Internal; static SEXP inspect; + static SEXP row_names; }; } // namespace r } // namespace arrow @@ -172,9 +173,9 @@ inline std::shared_ptr extract(SEXP x) { #include #include #include +#include #include #include -#include RCPP_EXPOSED_ENUM_NODECL(arrow::Type::type) RCPP_EXPOSED_ENUM_NODECL(arrow::DateUnit) diff --git a/r/src/datatype.cpp b/r/src/datatype.cpp index 0ab881dd6c6..18920f22713 100644 --- a/r/src/datatype.cpp +++ b/r/src/datatype.cpp @@ -269,4 +269,16 @@ bool DictionaryType__ordered(const std::shared_ptr& type) return type->ordered(); } +// [[arrow::export]] +std::shared_ptr StructType__GetFieldByName( + const std::shared_ptr& type, const std::string& name) { + return type->GetFieldByName(name); +} + +// [[arrow::export]] +int StructType__GetFieldIndex(const std::shared_ptr& type, + const std::string& name) { + return type->GetFieldIndex(name); +} + #endif diff --git a/r/src/symbols.cpp b/r/src/symbols.cpp index de3fcf90131..828033bf82d 100644 --- a/r/src/symbols.cpp +++ b/r/src/symbols.cpp @@ -23,6 +23,7 @@ SEXP symbols::units = Rf_install("units"); SEXP symbols::xp = Rf_install(".:xp:."); SEXP symbols::dot_Internal = Rf_install(".Internal"); SEXP symbols::inspect = Rf_install("inspect"); +SEXP symbols::row_names = Rf_install("row.names"); void inspect(SEXP obj) { Rcpp::Shield call_inspect(Rf_lang2(symbols::inspect, obj)); diff --git a/r/tests/testthat/test-DataType.R b/r/tests/testthat/test-DataType.R index 5faf7214649..6f77b3b87e3 100644 --- a/r/tests/testthat/test-DataType.R +++ b/r/tests/testthat/test-DataType.R @@ -311,6 +311,13 @@ test_that("struct type works as expected", { x$children(), list(field("x", int32()), field("y", boolean())) ) + expect_equal(x$GetFieldIndex("x"), 0L) + expect_equal(x$GetFieldIndex("y"), 1L) + expect_equal(x$GetFieldIndex("z"), -1L) + + expect_equal(x$GetFieldByName("x"), field("x", int32())) + expect_equal(x$GetFieldByName("y"), field("y", boolean())) + expect_null(x$GetFieldByName("z")) }) test_that("DictionaryType works as expected (ARROW-3355)", { diff --git a/r/tests/testthat/test-json.R b/r/tests/testthat/test-json.R index 0321fb4d35d..38b20a84f4c 100644 --- a/r/tests/testthat/test-json.R +++ b/r/tests/testthat/test-json.R @@ -75,10 +75,12 @@ test_that("read_json_arrow() converts to tibble", { test_that("Can read json file with nested columns (ARROW-5503)", { tf <- tempfile() writeLines(' - { "hello": 3.5, "world": false, "yo": "thing", "arr": [1, 2, 3], "nuf": {} } - { "hello": 3.25, "world": null, "arr": [2], "nuf": null } - { "hello": 3.125, "world": null, "yo": "\u5fcd", "arr": [], "nuf": { "ps": 78 } } - { "hello": 0.0, "world": true, "yo": null, "arr": null, "nuf": { "ps": 90 } } + { "nuf": {} } + { "nuf": null } + { "nuf": { "ps": 78.0, "hello": "hi" } } + { "nuf": { "ps": 90.0, "hello": "bonjour" } } + { "nuf": { "hello": "ciao" } } + { "nuf": { "ps": 19 } } ', tf) tab1 <- read_json_arrow(tf, as_tibble = FALSE) @@ -91,13 +93,21 @@ test_that("Can read json file with nested columns (ARROW-5503)", { expect_equal( tab1$schema, schema( - hello = float64(), - world = boolean(), - yo = utf8(), - arr = list_of(int64()), - nuf = struct(ps = int64()) + nuf = struct(ps = float64(), hello = utf8()) ) ) + + struct_array <- tab1$column(0)$data()$chunk(0) + ps <- array(c(NA, NA, 78, 90, NA, 19)) + hello <- array(c(NA, NA, "hi", "bonjour", "ciao", NA)) + expect_equal(struct_array$field(0L), ps) + expect_equal(struct_array$GetFieldByName("ps"), ps) + expect_equal(struct_array$Flatten(), list(ps, hello)) + expect_equal( + struct_array$as_vector(), + data.frame(ps = ps$as_vector(), hello = hello$as_vector(), stringsAsFactors = FALSE) + ) + # cannot yet test list and struct types in R api # tib <- as.data.frame(tab1)