diff --git a/r/NAMESPACE b/r/NAMESPACE index 10677b43f85..cc5961e5ba1 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -12,12 +12,15 @@ S3method(FixedSizeBufferWriter,"arrow::Buffer") S3method(FixedSizeBufferWriter,default) S3method(MessageReader,"arrow::io::InputStream") S3method(MessageReader,default) +S3method(RecordBatchFileReader,"arrow::Buffer") S3method(RecordBatchFileReader,"arrow::io::RandomAccessFile") S3method(RecordBatchFileReader,character) S3method(RecordBatchFileReader,fs_path) +S3method(RecordBatchFileReader,raw) S3method(RecordBatchFileWriter,"arrow::io::OutputStream") S3method(RecordBatchFileWriter,character) S3method(RecordBatchFileWriter,fs_path) +S3method(RecordBatchStreamReader,"arrow::Buffer") S3method(RecordBatchStreamReader,"arrow::io::InputStream") S3method(RecordBatchStreamReader,raw) S3method(RecordBatchStreamWriter,"arrow::io::OutputStream") diff --git a/r/R/RecordBatchReader.R b/r/R/RecordBatchReader.R index 222f05586c1..6dab2d1ff76 100644 --- a/r/R/RecordBatchReader.R +++ b/r/R/RecordBatchReader.R @@ -31,10 +31,12 @@ #' @name arrow__RecordBatchReader `arrow::RecordBatchReader` <- R6Class("arrow::RecordBatchReader", inherit = `arrow::Object`, public = list( - schema = function() shared_ptr(`arrow::Schema`, RecordBatchReader__schema(self)), - ReadNext = function() { + read_next_batch = function() { shared_ptr(`arrow::RecordBatch`, RecordBatchReader__ReadNext(self)) } + ), + active = list( + schema = function() shared_ptr(`arrow::Schema`, RecordBatchReader__schema(self)) ) ) @@ -70,11 +72,13 @@ #' @name arrow__ipc__RecordBatchFileReader `arrow::ipc::RecordBatchFileReader` <- R6Class("arrow::ipc::RecordBatchFileReader", inherit = `arrow::Object`, public = list( - schema = function() shared_ptr(`arrow::Schema`, ipc___RecordBatchFileReader__schema(self)), - num_record_batches = function() ipc___RecordBatchFileReader__num_record_batches(self), - ReadRecordBatch = function(i) shared_ptr(`arrow::RecordBatch`, ipc___RecordBatchFileReader__ReadRecordBatch(self, i)), + get_batch = function(i) shared_ptr(`arrow::RecordBatch`, ipc___RecordBatchFileReader__ReadRecordBatch(self, i)), batches = function() map(ipc___RecordBatchFileReader__batches(self), shared_ptr, class = `arrow::RecordBatch`) + ), + active = list( + num_record_batches = function() ipc___RecordBatchFileReader__num_record_batches(self), + schema = function() shared_ptr(`arrow::Schema`, ipc___RecordBatchFileReader__schema(self)) ) ) @@ -97,6 +101,11 @@ RecordBatchStreamReader <- function(stream){ RecordBatchStreamReader(BufferReader(stream)) } +#' @export +`RecordBatchStreamReader.arrow::Buffer` <- function(stream) { + RecordBatchStreamReader(BufferReader(stream)) +} + #' Create an [arrow::ipc::RecordBatchFileReader][arrow__ipc__RecordBatchFileReader] from a file #' @@ -122,3 +131,13 @@ RecordBatchFileReader <- function(file) { `RecordBatchFileReader.fs_path` <- function(file) { RecordBatchFileReader(ReadableFile(file)) } + +#' @export +`RecordBatchFileReader.arrow::Buffer` <- function(file) { + RecordBatchFileReader(BufferReader(file)) +} + +#' @export +`RecordBatchFileReader.raw` <- function(file) { + RecordBatchFileReader(BufferReader(file)) +} diff --git a/r/tests/testthat/test-recordbatchreader.R b/r/tests/testthat/test-recordbatchreader.R new file mode 100644 index 00000000000..d2b6a09c37b --- /dev/null +++ b/r/tests/testthat/test-recordbatchreader.R @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +context("arrow::RecordBatch.*(Reader|Writer)") + +test_that("RecordBatchStreamReader / Writer", { + batch <- record_batch(tibble::tibble( + x = 1:10, + y = letters[1:10] + )) + + sink <- BufferOutputStream() + writer <- RecordBatchStreamWriter(sink, batch$schema) + expect_is(writer, "arrow::ipc::RecordBatchStreamWriter") + writer$write_batch(batch) + writer$close() + + buf <- sink$getvalue() + expect_is(buf, "arrow::Buffer") + + reader <- RecordBatchStreamReader(buf) + expect_is(reader, "arrow::ipc::RecordBatchStreamReader") + + batch1 <- reader$read_next_batch() + expect_is(batch1, "arrow::RecordBatch") + expect_equal(batch, batch1) + + expect_null(reader$read_next_batch()) +}) + +test_that("RecordBatchFileReader / Writer", { + batch <- record_batch(tibble::tibble( + x = 1:10, + y = letters[1:10] + )) + + sink <- BufferOutputStream() + writer <- RecordBatchFileWriter(sink, batch$schema) + expect_is(writer, "arrow::ipc::RecordBatchFileWriter") + writer$write_batch(batch) + writer$close() + + buf <- sink$getvalue() + expect_is(buf, "arrow::Buffer") + + reader <- RecordBatchFileReader(buf) + expect_is(reader, "arrow::ipc::RecordBatchFileReader") + + batch1 <- reader$get_batch(0L) + expect_is(batch1, "arrow::RecordBatch") + expect_equal(batch, batch1) + + expect_equal(reader$num_record_batches, 1L) +})