diff --git a/r/DESCRIPTION b/r/DESCRIPTION index ee5302833ec..ffc473f365d 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -85,6 +85,7 @@ Collate: 'record-batch-writer.R' 'reexports-bit64.R' 'reexports-tidyselect.R' + 'scalar.R' 'schema.R' 'struct.R' 'util.R' diff --git a/r/NAMESPACE b/r/NAMESPACE index c8657c853d2..6a8c0f9b3ea 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -22,6 +22,7 @@ S3method(as.list,Table) S3method(as.raw,Buffer) S3method(as.vector,Array) S3method(as.vector,ChunkedArray) +S3method(as.vector,Scalar) S3method(as.vector,array_expression) S3method(c,Dataset) S3method(dim,Dataset) @@ -37,9 +38,11 @@ S3method(head,Table) S3method(is.na,Array) S3method(is.na,ChunkedArray) S3method(is.na,Expression) +S3method(is.na,Scalar) S3method(is.na,array_expression) S3method(length,Array) S3method(length,ChunkedArray) +S3method(length,Scalar) S3method(length,Schema) S3method(names,Dataset) S3method(names,RecordBatch) @@ -120,6 +123,7 @@ export(RecordBatchFileWriter) export(RecordBatchStreamReader) export(RecordBatchStreamWriter) export(S3FileSystem) +export(Scalar) export(Scanner) export(ScannerBuilder) export(Schema) diff --git a/r/R/array.R b/r/R/array.R index ceb29941646..4cce6114646 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -128,17 +128,19 @@ Array <- R6Class("Array", i <- Array$create(i) } if (inherits(i, "ChunkedArray")) { + # Invalid: Kernel does not support chunked array arguments + # so use the old method return(shared_ptr(ChunkedArray, Array__TakeChunked(self, i))) } assert_is(i, "Array") - Array$create(Array__Take(self, i)) + Array$create(call_function("take", self, i)) }, Filter = function(i, keep_na = TRUE) { if (is.logical(i)) { i <- Array$create(i) } assert_is(i, "Array") - Array$create(Array__Filter(self, i, keep_na)) + Array$create(call_function("filter", self, i, options = list(keep_na = keep_na))) }, RangeEquals = function(other, start_idx, end_idx, other_start_idx = 0L) { assert_is(other, "Array") diff --git a/r/R/arrowExports.R 
b/r/R/arrowExports.R index 6d84bb925d4..9830673eb18 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -272,10 +272,6 @@ Table__cast <- function(table, schema, options){ .Call(`_arrow_Table__cast` , table, schema, options) } -Array__Take <- function(values, indices){ - .Call(`_arrow_Array__Take` , values, indices) -} - Array__TakeChunked <- function(values, indices){ .Call(`_arrow_Array__TakeChunked` , values, indices) } @@ -300,10 +296,6 @@ Table__TakeChunked <- function(table, indices){ .Call(`_arrow_Table__TakeChunked` , table, indices) } -Array__Filter <- function(values, filter, keep_na){ - .Call(`_arrow_Array__Filter` , values, filter, keep_na) -} - RecordBatch__Filter <- function(batch, filter, keep_na){ .Call(`_arrow_RecordBatch__Filter` , batch, filter, keep_na) } @@ -324,6 +316,10 @@ Table__FilterChunked <- function(table, filter, keep_na){ .Call(`_arrow_Table__FilterChunked` , table, filter, keep_na) } +compute__CallFunction <- function(func_name, args, options){ + .Call(`_arrow_compute__CallFunction` , func_name, args, options) +} + csv___ReadOptions__initialize <- function(options){ .Call(`_arrow_csv___ReadOptions__initialize` , options) } @@ -1380,6 +1376,30 @@ ipc___RecordBatchStreamWriter__Open <- function(stream, schema, use_legacy_forma .Call(`_arrow_ipc___RecordBatchStreamWriter__Open` , stream, schema, use_legacy_format) } +Array__GetScalar <- function(x, i){ + .Call(`_arrow_Array__GetScalar` , x, i) +} + +Scalar__ToString <- function(s){ + .Call(`_arrow_Scalar__ToString` , s) +} + +Scalar__CastTo <- function(s, t){ + .Call(`_arrow_Scalar__CastTo` , s, t) +} + +Scalar__as_vector <- function(scalar){ + .Call(`_arrow_Scalar__as_vector` , scalar) +} + +Scalar__is_valid <- function(s){ + .Call(`_arrow_Scalar__is_valid` , s) +} + +Scalar__type <- function(s){ + .Call(`_arrow_Scalar__type` , s) +} + schema_ <- function(fields){ .Call(`_arrow_schema_` , fields) } diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R index 
f352705e13c..34e383c0835 100644 --- a/r/R/chunked-array.R +++ b/r/R/chunked-array.R @@ -75,6 +75,8 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = ArrowObject, if (is.integer(i)) { i <- Array$create(i) } + # Invalid: Kernel does not support chunked array arguments + # so use the old method for both cases if (inherits(i, "ChunkedArray")) { return(shared_ptr(ChunkedArray, ChunkedArray__TakeChunked(self, i))) } diff --git a/r/R/compute.R b/r/R/compute.R index f9f871dff0f..000c5c86f07 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -17,6 +17,11 @@ #' @include array.R +call_function <- function(function_name, ..., options = list()) { + assert_that(is.string(function_name)) + compute__CallFunction(function_name, list(...), options) +} + CastOptions <- R6Class("CastOptions", inherit = ArrowObject) #' Cast options diff --git a/r/R/expression.R b/r/R/expression.R index 80c71b270c7..338e15260a2 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -83,8 +83,7 @@ Expression$field_ref <- function(name) { shared_ptr(Expression, dataset___expr__field_ref(name)) } Expression$scalar <- function(x) { - stopifnot(vec_size(x) == 1L || is.null(x)) - shared_ptr(Expression, dataset___expr__scalar(x)) + shared_ptr(Expression, dataset___expr__scalar(Scalar$create(x))) } Expression$compare <- function(OP, e1, e2) { comp_func <- comparison_function_map[[OP]] diff --git a/r/R/record-batch.R b/r/R/record-batch.R index a2a25d314d4..a75c5cbd78e 100644 --- a/r/R/record-batch.R +++ b/r/R/record-batch.R @@ -119,6 +119,8 @@ RecordBatch <- R6Class("RecordBatch", inherit = ArrowObject, i <- Array$create(i) } assert_is(i, "Array") + # Invalid: Tried executing function with non-value type: RecordBatch + # so use old methods shared_ptr(RecordBatch, RecordBatch__Take(self, i)) }, Filter = function(i, keep_na = TRUE) { diff --git a/r/R/scalar.R b/r/R/scalar.R new file mode 100644 index 00000000000..df06f7b3a38 --- /dev/null +++ b/r/R/scalar.R @@ -0,0 +1,64 @@ +# Licensed to the Apache Software 
# Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#' @include arrow-package.R

#' @title Arrow scalars
#' @usage NULL
#' @format NULL
#' @docType class
#'
#' @description A `Scalar` holds a single value of an Arrow type.
#'
#' @name Scalar
#' @rdname Scalar
#' @export
Scalar <- R6Class("Scalar", inherit = ArrowObject,
  # TODO: surface the method documentation below in the roxygen docs
  public = list(
    # String representation of the scalar, produced by Scalar__ToString()
    ToString = function() Scalar__ToString(self),
    # Cast to `target_type` (passed through as_type() first) and wrap the
    # result of Scalar__CastTo() in a new Scalar
    cast = function(target_type) {
      Scalar$create(Scalar__CastTo(self, as_type(target_type)))
    },
    # Convert the value back to an R vector via Scalar__as_vector()
    as_vector = function() Scalar__as_vector(self)
  ),
  active = list(
    # Validity flag reported by Scalar__is_valid(); is.na() is its negation
    is_valid = function() Scalar__is_valid(self),
    # The scalar's type, wrapped with DataType$create()
    type = function() DataType$create(Scalar__type(self))
  )
)

# Create a Scalar.
#
# `x` may be:
#   * an external pointer, which is wrapped as-is (`type` is not consulted);
#   * an existing Scalar, which is returned unchanged, or cast when `type`
#     is supplied;
#   * any other R object, which is converted with Array$create(x, type = type)
#     and whose element at index 0 is extracted.
# `type` is an optional target type forwarded to Array$create().
# Returns a Scalar.
Scalar$create <- function(x, type = NULL) {
  if (inherits(x, "Scalar")) {
    # Already a Scalar: do not push the R6 object through Array$create(),
    # which only knows how to convert R vectors.
    if (is.null(type)) {
      return(x)
    }
    return(x$cast(type))
  }
  if (!inherits(x, "externalptr")) {
    if (is.null(x)) {
      # NULL is represented by a length-one `unspecified` vector
      x <- vctrs::unspecified(1)
    } else if (length(x) != 1 && !is.data.frame(x)) {
      # Wrap in a list type
      x <- list(x)
    }
    # NOTE(review): a data.frame is never wrapped, and only index 0 is
    # extracted below -- presumably any further rows are dropped; confirm
    # that this is intended.
    x <- Array__GetScalar(Array$create(x, type = type), 0)
  }
  shared_ptr(Scalar, x)
}

# A Scalar always has length one.
#' @export
length.Scalar <- function(x) 1L

# A Scalar is NA exactly when it is not valid.
#' @export
is.na.Scalar <- function(x) !x$is_valid

# `mode` is accepted for compatibility with the as.vector() generic but is
# not used.
#' @export
as.vector.Scalar <- function(x, mode) x$as_vector()
00000000000..2ef5b02ccbe --- /dev/null +++ b/r/man/Scalar.Rd @@ -0,0 +1,9 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/scalar.R +\docType{class} +\name{Scalar} +\alias{Scalar} +\title{Arrow scalars} +\description{ +A \code{Scalar} holds a single value of an Arrow type. +} diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 1918b860c80..42ca64910b2 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1069,22 +1069,6 @@ RcppExport SEXP _arrow_Table__cast(SEXP table_sexp, SEXP schema_sexp, SEXP optio } #endif -// compute.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr Array__Take(const std::shared_ptr& values, const std::shared_ptr& indices); -RcppExport SEXP _arrow_Array__Take(SEXP values_sexp, SEXP indices_sexp){ -BEGIN_RCPP - Rcpp::traits::input_parameter&>::type values(values_sexp); - Rcpp::traits::input_parameter&>::type indices(indices_sexp); - return Rcpp::wrap(Array__Take(values, indices)); -END_RCPP -} -#else -RcppExport SEXP _arrow_Array__Take(SEXP values_sexp, SEXP indices_sexp){ - Rf_error("Cannot call Array__Take(). Please use arrow::install_arrow() to install required runtime libraries. 
"); -} -#endif - // compute.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr Array__TakeChunked(const std::shared_ptr& values, const std::shared_ptr& indices); @@ -1181,23 +1165,6 @@ RcppExport SEXP _arrow_Table__TakeChunked(SEXP table_sexp, SEXP indices_sexp){ } #endif -// compute.cpp -#if defined(ARROW_R_WITH_ARROW) -std::shared_ptr Array__Filter(const std::shared_ptr& values, const std::shared_ptr& filter, bool keep_na); -RcppExport SEXP _arrow_Array__Filter(SEXP values_sexp, SEXP filter_sexp, SEXP keep_na_sexp){ -BEGIN_RCPP - Rcpp::traits::input_parameter&>::type values(values_sexp); - Rcpp::traits::input_parameter&>::type filter(filter_sexp); - Rcpp::traits::input_parameter::type keep_na(keep_na_sexp); - return Rcpp::wrap(Array__Filter(values, filter, keep_na)); -END_RCPP -} -#else -RcppExport SEXP _arrow_Array__Filter(SEXP values_sexp, SEXP filter_sexp, SEXP keep_na_sexp){ - Rf_error("Cannot call Array__Filter(). Please use arrow::install_arrow() to install required runtime libraries. "); -} -#endif - // compute.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr RecordBatch__Filter(const std::shared_ptr& batch, const std::shared_ptr& filter, bool keep_na); @@ -1283,6 +1250,23 @@ RcppExport SEXP _arrow_Table__FilterChunked(SEXP table_sexp, SEXP filter_sexp, S } #endif +// compute.cpp +#if defined(ARROW_R_WITH_ARROW) +SEXP compute__CallFunction(std::string func_name, List_ args, List_ options); +RcppExport SEXP _arrow_compute__CallFunction(SEXP func_name_sexp, SEXP args_sexp, SEXP options_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter::type func_name(func_name_sexp); + Rcpp::traits::input_parameter::type args(args_sexp); + Rcpp::traits::input_parameter::type options(options_sexp); + return Rcpp::wrap(compute__CallFunction(func_name, args, options)); +END_RCPP +} +#else +RcppExport SEXP _arrow_compute__CallFunction(SEXP func_name_sexp, SEXP args_sexp, SEXP options_sexp){ + Rf_error("Cannot call compute__CallFunction(). 
Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + // csv.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr csv___ReadOptions__initialize(List_ options); @@ -2778,10 +2762,10 @@ RcppExport SEXP _arrow_dataset___expr__is_valid(SEXP lhs_sexp){ // expression.cpp #if defined(ARROW_R_WITH_ARROW) -std::shared_ptr dataset___expr__scalar(SEXP x); +std::shared_ptr dataset___expr__scalar(const std::shared_ptr& x); RcppExport SEXP _arrow_dataset___expr__scalar(SEXP x_sexp){ BEGIN_RCPP - Rcpp::traits::input_parameter::type x(x_sexp); + Rcpp::traits::input_parameter&>::type x(x_sexp); return Rcpp::wrap(dataset___expr__scalar(x)); END_RCPP } @@ -5409,6 +5393,98 @@ RcppExport SEXP _arrow_ipc___RecordBatchStreamWriter__Open(SEXP stream_sexp, SEX } #endif +// scalar.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr Array__GetScalar(const std::shared_ptr& x, int64_t i); +RcppExport SEXP _arrow_Array__GetScalar(SEXP x_sexp, SEXP i_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type x(x_sexp); + Rcpp::traits::input_parameter::type i(i_sexp); + return Rcpp::wrap(Array__GetScalar(x, i)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Array__GetScalar(SEXP x_sexp, SEXP i_sexp){ + Rf_error("Cannot call Array__GetScalar(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + +// scalar.cpp +#if defined(ARROW_R_WITH_ARROW) +std::string Scalar__ToString(const std::shared_ptr& s); +RcppExport SEXP _arrow_Scalar__ToString(SEXP s_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type s(s_sexp); + return Rcpp::wrap(Scalar__ToString(s)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Scalar__ToString(SEXP s_sexp){ + Rf_error("Cannot call Scalar__ToString(). Please use arrow::install_arrow() to install required runtime libraries. 
"); +} +#endif + +// scalar.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, const std::shared_ptr& t); +RcppExport SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type s(s_sexp); + Rcpp::traits::input_parameter&>::type t(t_sexp); + return Rcpp::wrap(Scalar__CastTo(s, t)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Scalar__CastTo(SEXP s_sexp, SEXP t_sexp){ + Rf_error("Cannot call Scalar__CastTo(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + +// scalar.cpp +#if defined(ARROW_R_WITH_ARROW) +SEXP Scalar__as_vector(const std::shared_ptr& scalar); +RcppExport SEXP _arrow_Scalar__as_vector(SEXP scalar_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type scalar(scalar_sexp); + return Rcpp::wrap(Scalar__as_vector(scalar)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Scalar__as_vector(SEXP scalar_sexp){ + Rf_error("Cannot call Scalar__as_vector(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + +// scalar.cpp +#if defined(ARROW_R_WITH_ARROW) +bool Scalar__is_valid(const std::shared_ptr& s); +RcppExport SEXP _arrow_Scalar__is_valid(SEXP s_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type s(s_sexp); + return Rcpp::wrap(Scalar__is_valid(s)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Scalar__is_valid(SEXP s_sexp){ + Rf_error("Cannot call Scalar__is_valid(). Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + +// scalar.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr Scalar__type(const std::shared_ptr& s); +RcppExport SEXP _arrow_Scalar__type(SEXP s_sexp){ +BEGIN_RCPP + Rcpp::traits::input_parameter&>::type s(s_sexp); + return Rcpp::wrap(Scalar__type(s)); +END_RCPP +} +#else +RcppExport SEXP _arrow_Scalar__type(SEXP s_sexp){ + Rf_error("Cannot call Scalar__type(). 
Please use arrow::install_arrow() to install required runtime libraries. "); +} +#endif + // schema.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr schema_(Rcpp::List fields); @@ -5997,19 +6073,18 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ChunkedArray__cast", (DL_FUNC) &_arrow_ChunkedArray__cast, 3}, { "_arrow_RecordBatch__cast", (DL_FUNC) &_arrow_RecordBatch__cast, 3}, { "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3}, - { "_arrow_Array__Take", (DL_FUNC) &_arrow_Array__Take, 2}, { "_arrow_Array__TakeChunked", (DL_FUNC) &_arrow_Array__TakeChunked, 2}, { "_arrow_RecordBatch__Take", (DL_FUNC) &_arrow_RecordBatch__Take, 2}, { "_arrow_ChunkedArray__Take", (DL_FUNC) &_arrow_ChunkedArray__Take, 2}, { "_arrow_ChunkedArray__TakeChunked", (DL_FUNC) &_arrow_ChunkedArray__TakeChunked, 2}, { "_arrow_Table__Take", (DL_FUNC) &_arrow_Table__Take, 2}, { "_arrow_Table__TakeChunked", (DL_FUNC) &_arrow_Table__TakeChunked, 2}, - { "_arrow_Array__Filter", (DL_FUNC) &_arrow_Array__Filter, 3}, { "_arrow_RecordBatch__Filter", (DL_FUNC) &_arrow_RecordBatch__Filter, 3}, { "_arrow_ChunkedArray__Filter", (DL_FUNC) &_arrow_ChunkedArray__Filter, 3}, { "_arrow_ChunkedArray__FilterChunked", (DL_FUNC) &_arrow_ChunkedArray__FilterChunked, 3}, { "_arrow_Table__Filter", (DL_FUNC) &_arrow_Table__Filter, 3}, { "_arrow_Table__FilterChunked", (DL_FUNC) &_arrow_Table__FilterChunked, 3}, + { "_arrow_compute__CallFunction", (DL_FUNC) &_arrow_compute__CallFunction, 3}, { "_arrow_csv___ReadOptions__initialize", (DL_FUNC) &_arrow_csv___ReadOptions__initialize, 1}, { "_arrow_csv___ParseOptions__initialize", (DL_FUNC) &_arrow_csv___ParseOptions__initialize, 1}, { "_arrow_csv___ConvertOptions__initialize", (DL_FUNC) &_arrow_csv___ConvertOptions__initialize, 1}, @@ -6274,6 +6349,12 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchWriter__Close", (DL_FUNC) &_arrow_ipc___RecordBatchWriter__Close, 1}, { "_arrow_ipc___RecordBatchFileWriter__Open", (DL_FUNC) 
&_arrow_ipc___RecordBatchFileWriter__Open, 3}, { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 3}, + { "_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, + { "_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, + { "_arrow_Scalar__CastTo", (DL_FUNC) &_arrow_Scalar__CastTo, 2}, + { "_arrow_Scalar__as_vector", (DL_FUNC) &_arrow_Scalar__as_vector, 1}, + { "_arrow_Scalar__is_valid", (DL_FUNC) &_arrow_Scalar__is_valid, 1}, + { "_arrow_Scalar__type", (DL_FUNC) &_arrow_Scalar__type, 1}, { "_arrow_schema_", (DL_FUNC) &_arrow_schema_, 1}, { "_arrow_Schema__ToString", (DL_FUNC) &_arrow_Schema__ToString, 1}, { "_arrow_Schema__num_fields", (DL_FUNC) &_arrow_Schema__num_fields, 1}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 0a18cf33522..a30b53b98e4 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -20,6 +20,8 @@ #if defined(ARROW_R_WITH_ARROW) #include +using Rcpp::List_; + // [[arrow::export]] std::shared_ptr compute___CastOptions__initialize( bool allow_int_overflow, bool allow_time_truncate, bool allow_float_truncate) { @@ -78,13 +80,6 @@ std::shared_ptr Table__cast( return arrow::Table::Make(schema, std::move(columns), table->num_rows()); } -// [[arrow::export]] -std::shared_ptr Array__Take(const std::shared_ptr& values, - const std::shared_ptr& indices) { - arrow::compute::TakeOptions options; - return ValueOrStop(arrow::compute::Take(*values, *indices, options)); -} - // [[arrow::export]] std::shared_ptr Array__TakeChunked( const std::shared_ptr& values, @@ -132,19 +127,6 @@ std::shared_ptr Table__TakeChunked( return ValueOrStop(arrow::compute::Take(*table, *indices, options)); } -// [[arrow::export]] -std::shared_ptr Array__Filter(const std::shared_ptr& values, - const std::shared_ptr& filter, - bool keep_na) { - // Use the EMIT_NULL filter option to match R's behavior in [ - arrow::compute::FilterOptions options; - if (keep_na) { - options.null_selection_behavior = 
arrow::compute::FilterOptions::EMIT_NULL; - } - arrow::Datum out = ValueOrStop(arrow::compute::Filter(values, filter, options)); - return out.make_array(); -} - // [[arrow::export]] std::shared_ptr RecordBatch__Filter( const std::shared_ptr& batch, @@ -223,4 +205,96 @@ std::shared_ptr Table__FilterChunked( } return tab; } + +template +std::shared_ptr MaybeUnbox(const char* class_name, SEXP x) { + if (Rf_inherits(x, "ArrowObject") && Rf_inherits(x, class_name)) { + Rcpp::ConstReferenceSmartPtrInputParameter> obj(x); + return static_cast>(obj); + } + return nullptr; +} + +arrow::Datum to_datum(SEXP x) { + if (auto array = MaybeUnbox("Array", x)) { + return array; + } + + if (auto chunked_array = MaybeUnbox("ChunkedArray", x)) { + return chunked_array; + } + + if (auto batch = MaybeUnbox("RecordBatch", x)) { + return batch; + } + + if (auto table = MaybeUnbox("Table", x)) { + return table; + } + + if (auto scalar = MaybeUnbox("Scalar", x)) { + return scalar; + } + + // This assumes that R objects have already been converted to Arrow objects; + // that seems right but should we do the wrapping here too/instead? 
+ Rcpp::stop("to_datum: Not implemented for type %s", Rf_type2char(TYPEOF(x))); +} + +SEXP from_datum(arrow::Datum datum) { + switch (datum.kind()) { + case arrow::Datum::SCALAR: + return Rcpp::wrap(datum.scalar()); + + case arrow::Datum::ARRAY: + return Rcpp::wrap(datum.make_array()); + + case arrow::Datum::CHUNKED_ARRAY: + return Rcpp::wrap(datum.chunked_array()); + + case arrow::Datum::RECORD_BATCH: + return Rcpp::wrap(datum.record_batch()); + + case arrow::Datum::TABLE: + return Rcpp::wrap(datum.table()); + + default: + break; + } + + auto str = datum.ToString(); + Rcpp::stop("from_datum: Not implemented for Datum %s", str.c_str()); +} + +std::shared_ptr make_compute_options( + std::string func_name, List_ options) { + if (func_name == "filter") { + auto out = std::make_shared( + arrow::compute::FilterOptions::Defaults()); + if (!Rf_isNull(options["keep_na"]) && options["keep_na"]) { + out->null_selection_behavior = arrow::compute::FilterOptions::EMIT_NULL; + } + return out; + } + + if (func_name == "take") { + auto out = std::make_shared( + arrow::compute::TakeOptions::Defaults()); + return out; + } + + return nullptr; +} + +// [[arrow::export]] +SEXP compute__CallFunction(std::string func_name, List_ args, List_ options) { + auto opts = make_compute_options(func_name, options); + std::vector datum_args; + for (auto arg : args) { + datum_args.push_back(to_datum(arg)); + } + auto out = ValueOrStop(arrow::compute::CallFunction(func_name, datum_args, opts.get())); + return from_datum(out); +} + #endif diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 857c6a0e081..575dddc7da7 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -103,58 +103,9 @@ std::shared_ptr dataset___expr__is_valid( } // [[arrow::export]] -std::shared_ptr dataset___expr__scalar(SEXP x) { - switch (TYPEOF(x)) { - case NILSXP: - return ds::scalar(std::make_shared()); - case LGLSXP: - return ds::scalar(Rf_asLogical(x)); - case REALSXP: - if (Rf_inherits(x, "Date")) { - 
return ds::scalar(std::make_shared(REAL(x)[0])); - } else if (Rf_inherits(x, "POSIXct")) { - return ds::scalar(std::make_shared( - REAL(x)[0], arrow::timestamp(arrow::TimeUnit::SECOND))); - } else if (Rf_inherits(x, "integer64")) { - int64_t value = *reinterpret_cast(REAL(x)); - return ds::scalar(value); - } else if (Rf_inherits(x, "difftime")) { - int multiplier = 0; - // TODO: shared with TimeConverter<> in array_from_vector.cpp - std::string unit(CHAR(STRING_ELT(Rf_getAttrib(x, arrow::r::symbols::units), 0))); - if (unit == "secs") { - multiplier = 1; - } else if (unit == "mins") { - multiplier = 60; - } else if (unit == "hours") { - multiplier = 3600; - } else if (unit == "days") { - multiplier = 86400; - } else if (unit == "weeks") { - multiplier = 604800; - } else { - Rcpp::stop("unknown difftime unit"); - } - return ds::scalar(std::make_shared( - static_cast(REAL(x)[0] * multiplier), - arrow::time32(arrow::TimeUnit::SECOND))); - } - return ds::scalar(Rf_asReal(x)); - case INTSXP: - if (Rf_inherits(x, "factor")) { - // TODO: This does not use the actual value, just the levels - auto type = arrow::r::InferArrowTypeFromFactor(x); - return ds::scalar(std::make_shared(type)); - } - return ds::scalar(Rf_asInteger(x)); - case STRSXP: - return ds::scalar(CHAR(STRING_ELT(x, 0))); - default: - Rcpp::stop( - tfm::format("R object of type %s not supported", Rf_type2char(TYPEOF(x)))); - } - - return nullptr; +std::shared_ptr dataset___expr__scalar( + const std::shared_ptr& x) { + return ds::scalar(x); } // [[arrow::export]] diff --git a/r/src/scalar.cpp b/r/src/scalar.cpp new file mode 100644 index 00000000000..9c7d726c0b8 --- /dev/null +++ b/r/src/scalar.cpp @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "./arrow_types.h" + +#if defined(ARROW_R_WITH_ARROW) + +#include +#include + +// [[arrow::export]] +std::shared_ptr Array__GetScalar(const std::shared_ptr& x, + int64_t i) { + return ValueOrStop(x->GetScalar(i)); +} + +// [[arrow::export]] +std::string Scalar__ToString(const std::shared_ptr& s) { + return s->ToString(); +} + +// [[arrow::export]] +std::shared_ptr Scalar__CastTo(const std::shared_ptr& s, + const std::shared_ptr& t) { + return ValueOrStop(s->CastTo(t)); +} + +// [[arrow::export]] +SEXP Scalar__as_vector(const std::shared_ptr& scalar) { + auto array = ValueOrStop(arrow::MakeArrayFromScalar(*scalar, 1)); + + // defined in array_to_vector.cpp + SEXP Array__as_vector(const std::shared_ptr& array); + return Array__as_vector(array); +} + +// [[arrow::export]] +bool Scalar__is_valid(const std::shared_ptr& s) { return s->is_valid; } + +// [[arrow::export]] +std::shared_ptr Scalar__type(const std::shared_ptr& s) { + return s->type; +} + +#endif diff --git a/r/tests/testthat/test-Array.R b/r/tests/testthat/test-Array.R index ad8068cc8b8..6d61d9ac5fc 100644 --- a/r/tests/testthat/test-Array.R +++ b/r/tests/testthat/test-Array.R @@ -517,6 +517,7 @@ test_that("[ method on Array", { expect_vector(a[5:9], vec[5:9]) expect_vector(a[c(9, 3, 5)], vec[c(9, 3, 5)]) expect_vector(a[rep(c(TRUE, FALSE), 5)], vec[c(1, 3, 5, 7, 9)]) + expect_vector(a[rep(c(TRUE, FALSE, NA, FALSE, 
TRUE), 2)], c(11, NA, 15, 16, NA, 20)) expect_vector(a[-4], vec[-4]) expect_vector(a[-1], vec[-1]) }) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R new file mode 100644 index 00000000000..1d0d23a788d --- /dev/null +++ b/r/tests/testthat/test-compute.R @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +context("compute") + +test_that("Bad input handling of call_function", { + expect_error(call_function("sum", 2, 3), "to_datum: Not implemented for type double") +}) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index d19274af213..df5b87ca890 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -58,7 +58,7 @@ test_that("C++ expressions", { expect_is(f == i64, "Expression") expect_is(f == time, "Expression") expect_is(f == dict, "Expression") - # can't seem to make this work right now + # can't seem to make this work right now because of R Ops.method dispatch # expect_is(f == as.Date("2020-01-15"), "Expression") expect_is(f == ts, "Expression") expect_is(f <= 2L, "Expression") @@ -72,6 +72,6 @@ test_that("C++ expressions", { 'Expression\n(f > 4:double)', fixed = TRUE ) - - expect_error(f == c(1L, 2L)) + # Interprets that as a list type + expect_is(f == c(1L, 2L), "Expression") }) diff --git a/r/tests/testthat/test-scalar.R b/r/tests/testthat/test-scalar.R new file mode 100644 index 00000000000..570fc802334 --- /dev/null +++ b/r/tests/testthat/test-scalar.R @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 

context("Scalar")

# Build a Scalar from `x`, then check its class, its type and its length.
# For non-nested types, also check missingness and the conversion back to an
# R vector.
expect_scalar_roundtrip <- function(x, type) {
  scalar <- Scalar$create(x)
  expect_is(scalar, "Scalar")
  expect_type_equal(scalar$type, type)
  expect_identical(length(scalar), 1L)

  if (inherits(type, "NestedType")) {
    # Nested types stop here.
    # Open question: should a nested scalar be missing if all of its
    # elements are missing?
    # expect_identical(is.na(scalar), all(is.na(x)))
    # The as.vector() check is skipped too:
    # MakeArrayFromScalar not implemented for list types
    return(NULL)
  }

  expect_identical(is.na(scalar), is.na(x))
  expect_equal(as.vector(scalar), x)
}

test_that("Scalar object roundtrip", {
  expect_scalar_roundtrip(x = 2, type = float64())
  expect_scalar_roundtrip(x = 2L, type = int32())
  # Vectors whose length is not 1 become list scalars
  expect_scalar_roundtrip(x = c(2, 4), type = list_of(float64()))
  expect_scalar_roundtrip(x = c(NA, NA), type = list_of(bool()))
  # A data.frame becomes a struct scalar
  expect_scalar_roundtrip(
    x = data.frame(a = 2, b = 4L),
    type = struct(a = double(), b = int32())
  )
})

test_that("Scalar print", {
  four <- Scalar$create(4)
  expect_output(print(four), "Scalar\n4")
})

test_that("Creating Scalars of a different type and casting them", {
  typed <- Scalar$create(4L, int8())
  expect_type_equal(typed$type, int8())

  casted <- Scalar$create(4L)$cast(float32())
  expect_type_equal(casted$type, float32())
})