From c80315da05cae8a0be0fe48476428e9ce377cd18 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 26 Mar 2024 18:28:40 -0400 Subject: [PATCH 01/41] GH-40078: [C++] Import/Export ArrowDeviceArrayStream --- cpp/src/arrow/c/bridge.cc | 265 ++++++++++++++++++++++++-------- cpp/src/arrow/c/bridge.h | 41 +++++ cpp/src/arrow/c/helpers.h | 47 ++++++ cpp/src/arrow/c/util_internal.h | 20 +++ cpp/src/arrow/record_batch.cc | 4 + cpp/src/arrow/record_batch.h | 16 ++ 6 files changed, 325 insertions(+), 68 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 8a530b3798d..02dba542721 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2041,6 +2041,19 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i } } +Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t, + struct ArrowDeviceArray* out_array) { + std::shared_ptr batch; + RETURN_NOT_OK(src->ReadNext(&batch)); + if (batch == nullptr) { + // End of stream + ArrowArrayMarkReleased(&out_array->array); + return Status::OK(); + } else { + return ExportDeviceRecordBatch(*batch, batch->GetSyncEvent(), out_array); + } +} + Status ExportStreamNext(const std::shared_ptr& src, int64_t i, struct ArrowArray* out_array) { if (i >= src->num_chunks()) { @@ -2052,7 +2065,19 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i, } } -template +Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t i, + std::shared_ptr sync, + struct ArrowDeviceArray* out_array) { + if (i >= src->num_chunks()) { + // End of stream + ArrowArrayMarkReleased(&out_array->array); + return Status::OK(); + } else { + return ExportDeviceArray(*src->chunk(static_cast(i)), sync, out_array); + } +} + +template class ExportedArrayStream { public: struct PrivateData { @@ -2067,14 +2092,20 @@ class ExportedArrayStream { ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData); }; - explicit ExportedArrayStream(struct ArrowArrayStream* stream) : stream_(stream) {} + explicit ExportedArrayStream(StreamType* stream) : stream_(stream) {} Status GetSchema(struct ArrowSchema* out_schema) { return ExportStreamSchema(reader(), out_schema); } - Status GetNext(struct ArrowArray* out_array) { - return ExportStreamNext(reader(), next_batch_num(), out_array); + Status GetNext(ArrayType* out_array) { + if constexpr (std::is_same_v) { + return ExportStreamNext(reader(), next_batch_num(), out_array); + } else if constexpr (std::is_same_v) { + return ExportDeviceStreamNext(reader(), next_batch_num(), nullptr, out_array); + } else { + return ExportDeviceStreamNext(reader(), next_batch_num(), out_array); + } } const char* GetLastError() { @@ -2083,38 +2114,44 @@ class ExportedArrayStream { } void Release() { - if (ArrowArrayStreamIsReleased(stream_)) { - return; + if constexpr (std::is_same_v) { + if (ArrowDeviceArrayStreamIsReleased(stream_)) { + return; + } + } else { + if (ArrowArrayStreamIsReleased(stream_)) { + return; + } } DCHECK_NE(private_data(), nullptr); delete private_data(); - ArrowArrayStreamMarkReleased(stream_); + if constexpr (std::is_same_v) { + ArrowDeviceArrayStreamMarkReleased(stream_); + } else { + ArrowArrayStreamMarkReleased(stream_); + } } // C-compatible callbacks - static int StaticGetSchema(struct ArrowArrayStream* stream, - struct ArrowSchema* out_schema) { + static int StaticGetSchema(StreamType* stream, struct ArrowSchema* out_schema) { ExportedArrayStream self{stream}; return self.ToCError(self.GetSchema(out_schema)); } - static int StaticGetNext(struct ArrowArrayStream* stream, - struct ArrowArray* out_array) { + static int StaticGetNext(StreamType* stream, ArrayType* out_array) { ExportedArrayStream self{stream}; return self.ToCError(self.GetNext(out_array)); } - static void StaticRelease(struct ArrowArrayStream* stream) { - ExportedArrayStream{stream}.Release(); - } + static void StaticRelease(StreamType* stream) { ExportedArrayStream{stream}.Release(); } - static const char* StaticGetLastError(struct ArrowArrayStream* stream) { + static const char* StaticGetLastError(StreamType* stream) { return ExportedArrayStream{stream}.GetLastError(); } - static Status Make(std::shared_ptr reader, struct ArrowArrayStream* out) { + static Status Make(std::shared_ptr reader, StreamType* out) { out->get_schema = ExportedArrayStream::StaticGetSchema; out->get_next = ExportedArrayStream::StaticGetNext; out->get_last_error = ExportedArrayStream::StaticGetLastError; @@ -2150,19 +2187,37 @@ class ExportedArrayStream { int64_t next_batch_num() { return private_data()->batch_num_++; } - struct ArrowArrayStream* stream_; + StreamType* stream_; }; } // namespace Status ExportRecordBatchReader(std::shared_ptr reader, struct ArrowArrayStream* out) { - return ExportedArrayStream::Make(std::move(reader), out); + return ExportedArrayStream::Make(std::move(reader), out); } Status ExportChunkedArray(std::shared_ptr chunked_array, struct ArrowArrayStream* out) { - return ExportedArrayStream::Make(std::move(chunked_array), out); + return ExportedArrayStream::Make(std::move(chunked_array), out); +} + +Status ExportDeviceRecordBatchReader(std::shared_ptr reader, + struct ArrowDeviceArrayStream* out) { + out->device_type = static_cast(reader->device_type()); + return ExportedArrayStream::Make(std::move(reader), out); +} + +Status ExportDeviceChunkedArray(std::shared_ptr chunked_array, + DeviceAllocationType device_type, + struct ArrowDeviceArrayStream* out) { + out->device_type = static_cast(device_type); + return ExportedArrayStream::Make(std::move(chunked_array), + out); } ////////////////////////////////////////////////////////////////////////// @@ -2170,33 +2225,58 @@ Status ExportChunkedArray(std::shared_ptr chunked_array, namespace { +template class ArrayStreamReader { + protected: + using StreamType = typename StreamTraits::CType; + using ArrayType = typename ArrayTraits::CType; + public: - explicit ArrayStreamReader(struct ArrowArrayStream* stream) { - ArrowArrayStreamMove(stream, &stream_); - DCHECK(!ArrowArrayStreamIsReleased(&stream_)); + explicit ArrayStreamReader(StreamType* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMapper) + : mapper_{mapper} { + StreamTraits::MoveFunc(stream, &stream_); + DCHECK(!StreamTraits::IsReleasedFunc(&stream_)); } ~ArrayStreamReader() { ReleaseStream(); } void ReleaseStream() { - if (!ArrowArrayStreamIsReleased(&stream_)) { - ArrowArrayStreamRelease(&stream_); - } - DCHECK(ArrowArrayStreamIsReleased(&stream_)); + // all our trait release funcs check IsReleased so we don't + // need to repeat it here + StreamTraits::ReleaseFunc(&stream_); + DCHECK(StreamTraits::IsReleasedFunc(&stream_)); } protected: - Status ReadNextArrayInternal(struct ArrowArray* array) { - ArrowArrayMarkReleased(array); + Status ReadNextArrayInternal(ArrayType* array) { + ArrayTraits::MarkReleased(array); Status status = StatusFromCError(stream_.get_next(&stream_, array)); - if (!status.ok() && !ArrowArrayIsReleased(array)) { - ArrowArrayRelease(array); + if (!status.ok() && !ArrayTraits::IsReleasedFunc(array)) { + ArrayTraits::ReleaseFunc(array); } return status; } + Result> ImportRecordBatchInternal( + ArrayType* array, std::shared_ptr schema) { + if constexpr (std::is_same_v) { + return ImportDeviceRecordBatch(array, schema, mapper_); + } else { + return ImportRecordBatch(array, schema); + } + } + + Result> ImportArrayInternal( + ArrayType* array, std::shared_ptr type) { + if constexpr (std::is_same_v) { + return ImportDeviceArray(array, type, mapper_); + } else { + return ImportArray(array, type); + } + } + Result> ReadSchema() { struct ArrowSchema c_schema = {}; ARROW_RETURN_NOT_OK( @@ -2214,19 +2294,19 @@ class ArrayStreamReader { } Status CheckNotReleased() { - if (ArrowArrayStreamIsReleased(&stream_)) { + if (StreamTraits::IsReleasedFunc(&stream_)) { return Status::Invalid( "Attempt to read from a stream that has already been closed"); - } else { - return Status::OK(); } + + return Status::OK(); } Status StatusFromCError(int errno_like) const { return StatusFromCError(&stream_, errno_like); } - static Status StatusFromCError(struct ArrowArrayStream* stream, int errno_like) { + static Status StatusFromCError(StreamType* stream, int errno_like) { if (ARROW_PREDICT_TRUE(errno_like == 0)) { return Status::OK(); } @@ -2250,70 +2330,90 @@ class ArrayStreamReader { return {code, last_error ? std::string(last_error) : ""}; } + DeviceAllocationType get_device_type() const { + if constexpr (std::is_same_v) { + return static_cast(stream_.device_type); + } else { + return DeviceAllocationType::kCPU; + } + } + private: - mutable struct ArrowArrayStream stream_; + mutable StreamType stream_; + const DeviceMemoryMapper& mapper_; }; -class ArrayStreamBatchReader : public RecordBatchReader, public ArrayStreamReader { +template +class ArrayStreamBatchReader : public RecordBatchReader, + public ArrayStreamReader { + using StreamType = typename StreamTraits::CType; + using ArrayType = typename ArrayTraits::CType; + public: - explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream) - : ArrayStreamReader(stream) {} + explicit ArrayStreamBatchReader(StreamType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMapper) + : ArrayStreamReader(stream, mapper) {} Status Init() { - ARROW_ASSIGN_OR_RAISE(schema_, ReadSchema()); + ARROW_ASSIGN_OR_RAISE(schema_, this->ReadSchema()); return Status::OK(); } std::shared_ptr schema() const override { return schema_; } Status ReadNext(std::shared_ptr* batch) override { - ARROW_RETURN_NOT_OK(CheckNotReleased()); + ARROW_RETURN_NOT_OK(this->CheckNotReleased()); - struct ArrowArray c_array; - ARROW_RETURN_NOT_OK(ReadNextArrayInternal(&c_array)); + ArrayType c_array; + ARROW_RETURN_NOT_OK(this->ReadNextArrayInternal(&c_array)); - if (ArrowArrayIsReleased(&c_array)) { + if (ArrayTraits::IsReleasedFunc(&c_array)) { // End of stream batch->reset(); return Status::OK(); } else { - return ImportRecordBatch(&c_array, schema_).Value(batch); + return this->ImportRecordBatchInternal(&c_array, schema_).Value(batch); } } Status Close() override { - ReleaseStream(); + this->ReleaseStream(); return Status::OK(); } + DeviceAllocationType device_type() const override { return this->get_device_type(); } + private: std::shared_ptr schema_; }; -class ArrayStreamArrayReader : public ArrayStreamReader { +template +class ArrayStreamArrayReader : public ArrayStreamReader { + using StreamType = typename StreamTraits::CType; + using ArrayType = typename ArrayTraits::CType; + public: - explicit ArrayStreamArrayReader(struct ArrowArrayStream* stream) - : ArrayStreamReader(stream) {} + explicit ArrayStreamArrayReader(StreamType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMapper) + : ArrayStreamReader(stream, mapper) {} Status Init() { - ARROW_ASSIGN_OR_RAISE(field_, ReadField()); + ARROW_ASSIGN_OR_RAISE(field_, this->ReadField()); return Status::OK(); } std::shared_ptr data_type() const { return field_->type(); } Status ReadNext(std::shared_ptr* array) { - ARROW_RETURN_NOT_OK(CheckNotReleased()); + ARROW_RETURN_NOT_OK(this->CheckNotReleased()); - struct ArrowArray c_array; - ARROW_RETURN_NOT_OK(ReadNextArrayInternal(&c_array)); + ArrayType c_array; + ARROW_RETURN_NOT_OK(this->ReadNextArrayInternal(&c_array)); - if (ArrowArrayIsReleased(&c_array)) { + if (ArrayTraits::IsReleasedFunc(&c_array)) { // End of stream array->reset(); return Status::OK(); } else { - return ImportArray(&c_array, field_->type()).Value(array); + return this->ImportArrayInternal(&c_array, field_->type()).Value(array); } } @@ -2321,30 +2421,33 @@ class ArrayStreamArrayReader : public ArrayStreamReader { std::shared_ptr field_; }; -} // namespace - -Result> ImportRecordBatchReader( - struct ArrowArrayStream* stream) { - if (ArrowArrayStreamIsReleased(stream)) { - return Status::Invalid("Cannot import released ArrowArrayStream"); +template +Result> ImportReader( + typename StreamTraits::CType* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMapper) { + if (StreamTraits::IsReleasedFunc(stream)) { + return Status::Invalid("Cannot import released Arrow Stream"); } - auto reader = std::make_shared(stream); + auto reader = + std::make_shared>(stream, mapper); ARROW_RETURN_NOT_OK(reader->Init()); return reader; } -Result> ImportChunkedArray( - struct ArrowArrayStream* stream) { - if (ArrowArrayStreamIsReleased(stream)) { - return Status::Invalid("Cannot import released ArrowArrayStream"); +template +Result> ImportChunked( + typename StreamTraits::CType* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMapper) { + if (StreamTraits::IsReleasedFunc(stream)) { + return Status::Invalid("Cannot import released Arrow Stream"); } - auto reader = std::make_shared(stream); + auto reader = + std::make_shared>(stream, mapper); ARROW_RETURN_NOT_OK(reader->Init()); - std::shared_ptr data_type = reader->data_type(); - + auto data_type = reader->data_type(); ArrayVector chunks; std::shared_ptr chunk; while (true) { @@ -2360,4 +2463,30 @@ Result> ImportChunkedArray( return ChunkedArray::Make(std::move(chunks), std::move(data_type)); } +} // namespace + +Result> ImportRecordBatchReader( + struct ArrowArrayStream* stream) { + return ImportReader( + stream); +} + +Result> ImportDeviceRecordBatchReader( + struct ArrowDeviceArrayStream* stream, const DeviceMemoryMapper& mapper) { + return ImportReader(stream, mapper); +} + +Result> ImportChunkedArray( + struct ArrowArrayStream* stream) { + return ImportChunked( + stream); +} + +Result> ImportDeviceChunkedArray( + struct ArrowDeviceArrayStream* stream, const DeviceMemoryMapper& mapper) { + return ImportChunked(stream, mapper); +} + } // namespace arrow diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h index 74a302be4c2..4bc415baf12 100644 --- a/cpp/src/arrow/c/bridge.h +++ b/cpp/src/arrow/c/bridge.h @@ -321,6 +321,37 @@ ARROW_EXPORT Status ExportChunkedArray(std::shared_ptr chunked_array, struct ArrowArrayStream* out); +/// \brief Export C++ RecordBatchReader using the C device stream interface +/// +/// The resulting ArrowDeviceArrayStream struct keeps the record batch reader +/// alive until its release callback is called by the consumer. The device +/// type is determined by calling device_type() on the RecordBatchReader. +/// +/// \note it is assumed that the output pointer has already be zeroed out before +/// calling this function. +/// +/// \param[in] reader RecordBatchReader object to export +/// \param[out] out C struct to export the stream to +ARROW_EXPORT +Status ExportDeviceRecordBatchReader(std::shared_ptr reader, + struct ArrowDeviceArrayStream* out); + +/// \brief Export C++ ChunkedArray using the c device data interface format. +/// +/// The resulting ArrowDeviceArrayStream keeps the chunked array data and buffers +/// alive until its release callback is called by the consumer. +/// +/// \note it is assumed that the output pointer has already been zeroed before +/// calling this function. +/// +/// \param[in] chunked_array ChunkedArray object to export +/// \param[in] device_type the device type the data is located on +/// \param[out] out C struct to export the stream to +ARROW_EXPORT +Status ExportDeviceChunkedArray(std::shared_ptr chunked_array, + DeviceAllocationType device_type, + struct ArrowDeviceArrayStream* out); + /// \brief Import C++ RecordBatchReader from the C stream interface. /// /// The ArrowArrayStream struct has its contents moved to a private object @@ -343,6 +374,16 @@ Result> ImportRecordBatchReader( ARROW_EXPORT Result> ImportChunkedArray(struct ArrowArrayStream* stream); +ARROW_EXPORT +Result> ImportDeviceRecordBatchReader( + struct ArrowDeviceArrayStream* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMapper); + +ARROW_EXPORT +Result> ImportDeviceChunkedArray( + struct ArrowDeviceArrayStream* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMapper); + /// @} } // namespace arrow diff --git a/cpp/src/arrow/c/helpers.h b/cpp/src/arrow/c/helpers.h index a24f272feac..4368507630f 100644 --- a/cpp/src/arrow/c/helpers.h +++ b/cpp/src/arrow/c/helpers.h @@ -70,9 +70,17 @@ inline int ArrowArrayIsReleased(const struct ArrowArray* array) { return array->release == NULL; } +inline int ArrowDeviceArrayIsReleased(const struct ArrowDeviceArray* array) { + return ArrowArrayIsReleased(&array->array); +} + /// Mark the C array released (for use in release callbacks) inline void ArrowArrayMarkReleased(struct ArrowArray* array) { array->release = NULL; } +inline void ArrowDeviceArrayMarkReleased(struct ArrowDeviceArray* array) { + ArrowArrayMarkReleased(&array->array); +} + /// Move the C array from `src` to `dest` /// /// Note `dest` must *not* point to a valid array already, otherwise there @@ -84,6 +92,13 @@ inline void ArrowArrayMove(struct ArrowArray* src, struct ArrowArray* dest) { ArrowArrayMarkReleased(src); } +inline void ArrowDeviceArrayMove(struct ArrowDeviceArray* src, struct ArrowDeviceArray* dest) { + assert(dest != src); + assert(!ArrowDeviceArrayIsReleased(src)); + memcpy(dest, src, sizeof(struct ArrowDeviceArray)); + ArrowDeviceArrayMarkReleased(src); +} + /// Release the C array, if necessary, by calling its release callback inline void ArrowArrayRelease(struct ArrowArray* array) { if (!ArrowArrayIsReleased(array)) { @@ -93,16 +108,32 @@ inline void ArrowArrayRelease(struct ArrowArray* array) { } } +inline void ArrowDeviceArrayRelease(struct ArrowDeviceArray* array) { + if (!ArrowDeviceArrayIsReleased(array)) { + array->array.release(&array->array); + ARROW_C_ASSERT(ArrowDeviceArrayIsReleased(array), + "ArrowDeviceArrayRelease did not cleanup release callback"); + } +} + /// Query whether the C array stream is released inline int ArrowArrayStreamIsReleased(const struct ArrowArrayStream* stream) { return stream->release == NULL; } +inline int ArrowDeviceArrayStreamIsReleased(const struct ArrowDeviceArrayStream* stream) { + return stream->release == NULL; +} + /// Mark the C array stream released (for use in release callbacks) inline void ArrowArrayStreamMarkReleased(struct ArrowArrayStream* stream) { stream->release = NULL; } +inline void ArrowDeviceArrayStreamMarkReleased(struct ArrowDeviceArrayStream* stream) { + stream->release = NULL; +} + /// Move the C array stream from `src` to `dest` /// /// Note `dest` must *not* point to a valid stream already, otherwise there @@ -115,6 +146,14 @@ inline void ArrowArrayStreamMove(struct ArrowArrayStream* src, ArrowArrayStreamMarkReleased(src); } +inline void ArrowDeviceArrayStreamMove(struct ArrowDeviceArrayStream* src, + struct ArrowDeviceArrayStream* dest) { + assert(dest != src); + assert(!ArrowDeviceArrayStreamIsReleased(src)); + memcpy(dest, src, sizeof(struct ArrowDeviceArrayStream)); + ArrowDeviceArrayStreamMarkReleased(src); +} + /// Release the C array stream, if necessary, by calling its release callback inline void ArrowArrayStreamRelease(struct ArrowArrayStream* stream) { if (!ArrowArrayStreamIsReleased(stream)) { @@ -124,6 +163,14 @@ inline void ArrowArrayStreamRelease(struct ArrowArrayStream* stream) { } } +inline void ArrowDeviceArrayStreamRelease(struct ArrowDeviceArrayStream* stream) { + if (!ArrowDeviceArrayStreamIsReleased(stream)) { + stream->release(stream); + ARROW_C_ASSERT(ArrowDeviceArrayStreamIsReleased(stream), + "ArrowDeviceArrayStreamRelease did not cleanup release callback"); + } +} + #ifdef __cplusplus } #endif diff --git a/cpp/src/arrow/c/util_internal.h b/cpp/src/arrow/c/util_internal.h index 6a33be9b0da..8f292b06249 100644 --- a/cpp/src/arrow/c/util_internal.h +++ b/cpp/src/arrow/c/util_internal.h @@ -32,12 +32,32 @@ struct ArrayExportTraits { typedef struct ArrowArray CType; static constexpr auto IsReleasedFunc = &ArrowArrayIsReleased; static constexpr auto ReleaseFunc = &ArrowArrayRelease; + static constexpr auto MoveFunc = &ArrowArrayMove; + static constexpr auto MarkReleased = &ArrowArrayMarkReleased; +}; + +struct ArrayDeviceExportTraits { + typedef struct ArrowDeviceArray CType; + static constexpr auto IsReleasedFunc = &ArrowDeviceArrayIsReleased; + static constexpr auto ReleaseFunc = &ArrowDeviceArrayRelease; + static constexpr auto MoveFunc = &ArrowDeviceArrayMove; + static constexpr auto MarkReleased = &ArrowDeviceArrayMarkReleased; }; struct ArrayStreamExportTraits { typedef struct ArrowArrayStream CType; static constexpr auto IsReleasedFunc = &ArrowArrayStreamIsReleased; static constexpr auto ReleaseFunc = &ArrowArrayStreamRelease; + static constexpr auto MoveFunc = &ArrowArrayStreamMove; + static constexpr auto MarkReleased = &ArrowArrayStreamMarkReleased; +}; + +struct ArrayDeviceStreamExportTraits { + typedef struct ArrowDeviceArrayStream CType; + static constexpr auto IsReleasedFunc = &ArrowDeviceArrayStreamIsReleased; + static constexpr auto ReleaseFunc = &ArrowDeviceArrayStreamRelease; + static constexpr auto MoveFunc = &ArrowDeviceArrayStreamMove; + static constexpr auto MarkReleased = &ArrowDeviceArrayStreamMarkReleased; }; // A RAII-style object to release a C Array / Schema struct at block scope exit. diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 8521d500f5c..d3522c625a0 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -623,6 +623,10 @@ Status RecordBatch::ValidateFull() const { return ValidateBatch(*this, /*full_validation=*/true); } +std::shared_ptr RecordBatch::GetSyncEvent() { + return nullptr; +} + // ---------------------------------------------------------------------- // Base record batch reader diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index cd647a88abd..073f3fd3aa7 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -23,6 +23,7 @@ #include #include "arrow/compare.h" +#include "arrow/device.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type_fwd.h" @@ -260,6 +261,16 @@ class ARROW_EXPORT RecordBatch { /// \return Status virtual Status ValidateFull() const; + /// \brief Return a top-level sync event object for this record batch + /// + /// If all of the data for this record batch is in host memory, then this + /// should return null (the default impl). If the data for this batch is + /// on a device, then if synchronization is needed before accessing the + /// data the returned sync event will allow for it. + /// + /// \return null or a Device::SyncEvent + virtual std::shared_ptr GetSyncEvent(); + protected: RecordBatch(const std::shared_ptr& schema, int64_t num_rows); @@ -306,6 +317,11 @@ class ARROW_EXPORT RecordBatchReader { /// \brief finalize reader virtual Status Close() { return Status::OK(); } + /// \brief Get the device type for record batches this reader produces + /// + /// default implementation is to return ARROW_DEVICE_CPU + virtual DeviceAllocationType device_type() const { return DeviceAllocationType::kCPU; } + class RecordBatchReaderIterator { public: using iterator_category = std::input_iterator_tag; From ce18e5db00de85467ba7e9ed8d9ef823420bd227 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 26 Mar 2024 19:07:10 -0400 Subject: [PATCH 02/41] fixing format lint --- cpp/src/arrow/c/bridge.cc | 6 ++++-- cpp/src/arrow/c/helpers.h | 9 +++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 02dba542721..7e7d39915e8 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2350,7 +2350,8 @@ class ArrayStreamBatchReader : public RecordBatchReader, using ArrayType = typename ArrayTraits::CType; public: - explicit ArrayStreamBatchReader(StreamType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMapper) + explicit ArrayStreamBatchReader(StreamType* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMapper) : ArrayStreamReader(stream, mapper) {} Status Init() { @@ -2392,7 +2393,8 @@ class ArrayStreamArrayReader : public ArrayStreamReader(stream, mapper) {} Status Init() { diff --git a/cpp/src/arrow/c/helpers.h b/cpp/src/arrow/c/helpers.h index 4368507630f..043195011e0 100644 --- a/cpp/src/arrow/c/helpers.h +++ b/cpp/src/arrow/c/helpers.h @@ -92,11 +92,12 @@ inline void ArrowArrayMove(struct ArrowArray* src, struct ArrowArray* dest) { ArrowArrayMarkReleased(src); } -inline void ArrowDeviceArrayMove(struct ArrowDeviceArray* src, struct ArrowDeviceArray* dest) { +inline void ArrowDeviceArrayMove(struct ArrowDeviceArray* src, + struct ArrowDeviceArray* dest) { assert(dest != src); assert(!ArrowDeviceArrayIsReleased(src)); memcpy(dest, src, sizeof(struct ArrowDeviceArray)); - ArrowDeviceArrayMarkReleased(src); + ArrowDeviceArrayMarkReleased(src); } /// Release the C array, if necessary, by calling its release callback @@ -146,8 +147,8 @@ inline void ArrowArrayStreamMove(struct ArrowArrayStream* src, ArrowArrayStreamMarkReleased(src); } -inline void ArrowDeviceArrayStreamMove(struct ArrowDeviceArrayStream* src, - struct ArrowDeviceArrayStream* dest) { +inline void ArrowDeviceArrayStreamMove(struct ArrowDeviceArrayStream* src, + struct ArrowDeviceArrayStream* dest) { assert(dest != src); assert(!ArrowDeviceArrayStreamIsReleased(src)); memcpy(dest, src, sizeof(struct ArrowDeviceArrayStream)); From efff4fd2fd56fc98c3b64a6380c1a72902c4cff8 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 26 Mar 2024 19:07:33 -0400 Subject: [PATCH 03/41] one more lint fix --- cpp/src/arrow/record_batch.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index d3522c625a0..a9677934258 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -623,9 +623,7 @@ Status RecordBatch::ValidateFull() const { return ValidateBatch(*this, /*full_validation=*/true); } -std::shared_ptr RecordBatch::GetSyncEvent() { - return nullptr; -} +std::shared_ptr RecordBatch::GetSyncEvent() { return nullptr; } // ---------------------------------------------------------------------- // Base record batch reader From 49190c99ef917b39d719bfcdb5f189beb60d67be Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 27 Mar 2024 12:30:28 -0400 Subject: [PATCH 04/41] updates from feedback --- cpp/src/arrow/c/bridge.cc | 30 ++++++++++++++++++------------ cpp/src/arrow/c/bridge.h | 32 ++++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 7e7d39915e8..4ae460699c6 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2194,18 +2194,21 @@ class ExportedArrayStream { Status ExportRecordBatchReader(std::shared_ptr reader, struct ArrowArrayStream* out) { + memset(out, 0, sizeof(struct ArrowArrayStream)); return ExportedArrayStream::Make(std::move(reader), out); } Status ExportChunkedArray(std::shared_ptr chunked_array, struct ArrowArrayStream* out) { + memset(out, 0, sizeof(struct ArrowArrayStream)); return ExportedArrayStream::Make(std::move(chunked_array), out); } Status ExportDeviceRecordBatchReader(std::shared_ptr reader, struct ArrowDeviceArrayStream* out) { + memset(out, 0, sizeof(struct ArrowDeviceArrayStream)); out->device_type = static_cast(reader->device_type()); return ExportedArrayStream::Make(std::move(reader), out); @@ -2214,6 +2217,7 @@ Status ExportDeviceRecordBatchReader(std::shared_ptr reader, Status ExportDeviceChunkedArray(std::shared_ptr chunked_array, DeviceAllocationType device_type, struct ArrowDeviceArrayStream* out) { + memset(out, 0, sizeof(struct ArrowDeviceArrayStream)); out->device_type = static_cast(device_type); return ExportedArrayStream::Make(std::move(chunked_array), @@ -2260,21 +2264,23 @@ class ArrayStreamReader { } Result> ImportRecordBatchInternal( - ArrayType* array, std::shared_ptr schema) { - if constexpr (std::is_same_v) { - return ImportDeviceRecordBatch(array, schema, mapper_); - } else { - return ImportRecordBatch(array, schema); - } + struct ArrowArray* array, std::shared_ptr schema) { + return ImportRecordBatch(array, schema); + } + + Result> ImportRecordBatchInternal( + struct ArrowDeviceArray* array, std::shared_ptr schema) { + return ImportDeviceRecordBatch(array, schema, mapper_); } Result> ImportArrayInternal( - ArrayType* array, std::shared_ptr type) { - if constexpr (std::is_same_v) { - return ImportDeviceArray(array, type, mapper_); - } else { - return ImportArray(array, type); - } + struct ArrowArray* array, std::shared_ptr type) { + return ImportArray(array, type); + } + + Result> ImportArrayInternal( + struct ArrowDeviceArray* array, std::shared_ptr type) { + return ImportDeviceArray(array, type, mapper_); } Result> ReadSchema() { diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h index 4bc415baf12..42a44fc5ec7 100644 --- a/cpp/src/arrow/c/bridge.h +++ b/cpp/src/arrow/c/bridge.h @@ -327,9 +327,6 @@ Status ExportChunkedArray(std::shared_ptr chunked_array, /// alive until its release callback is called by the consumer. The device /// type is determined by calling device_type() on the RecordBatchReader. /// -/// \note it is assumed that the output pointer has already be zeroed out before -/// calling this function. -/// /// \param[in] reader RecordBatchReader object to export /// \param[out] out C struct to export the stream to ARROW_EXPORT @@ -341,9 +338,6 @@ Status ExportDeviceRecordBatchReader(std::shared_ptr reader, /// The resulting ArrowDeviceArrayStream keeps the chunked array data and buffers /// alive until its release callback is called by the consumer. /// -/// \note it is assumed that the output pointer has already been zeroed before -/// calling this function. -/// /// \param[in] chunked_array ChunkedArray object to export /// \param[in] device_type the device type the data is located on /// \param[out] out C struct to export the stream to @@ -374,11 +368,37 @@ Result> ImportRecordBatchReader( ARROW_EXPORT Result> ImportChunkedArray(struct ArrowArrayStream* stream); +/// \brief Import C++ RecordBatchReader from the C device stream interface +/// +/// The ArrowDeviceArrayStream struct has its contents moved to a private object +/// held alive by the resulting record batch reader. +/// +/// \note If there was a required sync event, sync events are accessible by individual +/// buffers of columns. We are not yet bubbling the sync events from the buffers up to +/// the `GetSyncEvent` method of an imported RecordBatch. This will be added in a future +/// update. +/// +/// \param[in,out] stream C device stream interface struct +/// \param[in] mapper mapping from device type and ID to memory manager +/// \return Imported RecordBatchReader object ARROW_EXPORT Result> ImportDeviceRecordBatchReader( struct ArrowDeviceArrayStream* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMapper); +/// \brief Import C++ ChunkedArray from the C device stream interface +/// +/// The ArrowDeviceArrayStream struct has its contents moved to a private object, +/// is consumed in its entirety, and released before returning all chunks as a +/// ChunkedArray. +/// +/// \note Any chunks that require synchronization for their device memory will have +/// the SyncEvent objects available by checking the individual buffers of each chunk. +/// These SyncEvents should be checked before accessing the data in those buffers. +/// +/// \param[in,out] stream C device stream interface struct +/// \param[in] mapper mapping from device type and ID to memory manager +/// \return Imported ChunkedArray object ARROW_EXPORT Result> ImportDeviceChunkedArray( struct ArrowDeviceArrayStream* stream, From 858e8d5072401c2e0241ca1204365559b92ec054 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 27 Mar 2024 12:39:59 -0400 Subject: [PATCH 05/41] linting --- cpp/src/arrow/c/bridge.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h index 42a44fc5ec7..af188a7b4ec 100644 --- a/cpp/src/arrow/c/bridge.h +++ b/cpp/src/arrow/c/bridge.h @@ -374,8 +374,8 @@ Result> ImportChunkedArray(struct ArrowArrayStream /// held alive by the resulting record batch reader. /// /// \note If there was a required sync event, sync events are accessible by individual -/// buffers of columns. We are not yet bubbling the sync events from the buffers up to -/// the `GetSyncEvent` method of an imported RecordBatch. This will be added in a future +/// buffers of columns. We are not yet bubbling the sync events from the buffers up to +/// the `GetSyncEvent` method of an imported RecordBatch. This will be added in a future /// update. /// /// \param[in,out] stream C device stream interface struct From 5c36beecf7b93fffdae273ebcee558fabbac7949 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 27 Mar 2024 12:43:13 -0400 Subject: [PATCH 06/41] rebase updates --- cpp/src/arrow/c/bridge.cc | 10 +++++----- cpp/src/arrow/c/bridge.h | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 4ae460699c6..e9dd396bc7f 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2237,7 +2237,7 @@ class ArrayStreamReader { public: explicit ArrayStreamReader(StreamType* stream, - const DeviceMemoryMapper& mapper = DefaultDeviceMapper) + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) : mapper_{mapper} { StreamTraits::MoveFunc(stream, &stream_); DCHECK(!StreamTraits::IsReleasedFunc(&stream_)); @@ -2356,8 +2356,8 @@ class ArrayStreamBatchReader : public RecordBatchReader, using ArrayType = typename ArrayTraits::CType; public: - explicit ArrayStreamBatchReader(StreamType* stream, - const DeviceMemoryMapper& mapper = DefaultDeviceMapper) + explicit ArrayStreamBatchReader( + StreamType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) : ArrayStreamReader(stream, mapper) {} Status Init() { @@ -2432,7 +2432,7 @@ class ArrayStreamArrayReader : public ArrayStreamReader Result> ImportReader( typename StreamTraits::CType* stream, - const DeviceMemoryMapper& mapper = DefaultDeviceMapper) { + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) { if (StreamTraits::IsReleasedFunc(stream)) { return Status::Invalid("Cannot import released Arrow Stream"); } @@ -2446,7 +2446,7 @@ Result> ImportReader( template Result> ImportChunked( typename StreamTraits::CType* stream, - const DeviceMemoryMapper& mapper = DefaultDeviceMapper) { + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) { if (StreamTraits::IsReleasedFunc(stream)) { return Status::Invalid("Cannot import released Arrow Stream"); } diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h index af188a7b4ec..f2090ee8a6a 100644 --- a/cpp/src/arrow/c/bridge.h +++ b/cpp/src/arrow/c/bridge.h @@ -384,7 +384,7 @@ Result> ImportChunkedArray(struct ArrowArrayStream ARROW_EXPORT Result> ImportDeviceRecordBatchReader( struct ArrowDeviceArrayStream* stream, - const DeviceMemoryMapper& mapper = DefaultDeviceMapper); + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); /// \brief Import C++ ChunkedArray from the C device stream interface /// @@ -402,7 +402,7 @@ Result> ImportDeviceRecordBatchReader( ARROW_EXPORT Result> ImportDeviceChunkedArray( struct ArrowDeviceArrayStream* stream, - const DeviceMemoryMapper& mapper = DefaultDeviceMapper); + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); /// @} From e4d1699b96208c948d288d4365af2e22070827a7 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 27 Mar 2024 12:45:29 -0400 Subject: [PATCH 07/41] missed a spot --- cpp/src/arrow/c/bridge.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index e9dd396bc7f..39e0b6e3967 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2400,7 +2400,7 @@ class ArrayStreamArrayReader : public ArrayStreamReader(stream, mapper) {} Status Init() { From 469ac85ee658eefe24b0b1945735369220d2ed57 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 27 Mar 2024 12:55:56 -0400 Subject: [PATCH 08/41] linting --- cpp/src/arrow/c/bridge.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 39e0b6e3967..8b70d15207d 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2399,8 +2399,8 @@ class ArrayStreamArrayReader : public ArrayStreamReader(stream, mapper) {} Status Init() { From fda7737304d776f2b60e897a8ecb8457e5d179f3 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 28 Mar 2024 15:19:43 -0400 Subject: [PATCH 09/41] update error string --- python/pyarrow/tests/test_cffi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py index 5bf41c3c14b..45a3db9b66f 100644 --- a/python/pyarrow/tests/test_cffi.py +++ b/python/pyarrow/tests/test_cffi.py @@ -45,7 +45,7 @@ ValueError, match="Cannot import released ArrowArray") assert_stream_released = pytest.raises( - ValueError, match="Cannot import released ArrowArrayStream") + ValueError, match="Cannot import released Arrow Stream") def PyCapsule_IsValid(capsule, name): From 958688339fb41f51da4aaac55a91d5cff594b6e7 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 5 Apr 2024 11:08:15 -0400 Subject: [PATCH 10/41] Update cpp/src/arrow/record_batch.h Co-authored-by: Antoine Pitrou --- cpp/src/arrow/record_batch.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 073f3fd3aa7..fcb08b8d044 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -263,7 +263,7 @@ class ARROW_EXPORT RecordBatch { /// \brief Return a top-level sync event object for this record batch /// - /// If all of the data for this record batch is in host memory, then this + /// If all of the data for this record batch is in CPU memory, then this /// should return null (the default impl). If the data for this batch is /// on a device, then if synchronization is needed before accessing the /// data the returned sync event will allow for it. From ee531cfa9f5eb49ae0a8445c2d3014fa101ed97c Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 5 Apr 2024 11:09:54 -0400 Subject: [PATCH 11/41] Update cpp/src/arrow/record_batch.h Co-authored-by: Antoine Pitrou --- cpp/src/arrow/record_batch.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index fcb08b8d044..6a262a2c7ee 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -264,7 +264,7 @@ class ARROW_EXPORT RecordBatch { /// \brief Return a top-level sync event object for this record batch /// /// If all of the data for this record batch is in CPU memory, then this - /// should return null (the default impl). If the data for this batch is + /// will return null. If the data for this batch is /// on a device, then if synchronization is needed before accessing the /// data the returned sync event will allow for it. /// From 889193cd5529737d2e66535a1bce43698092a2a0 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 5 Apr 2024 11:43:45 -0400 Subject: [PATCH 12/41] additions from feedback --- cpp/src/arrow/c/bridge.cc | 2 +- cpp/src/arrow/c/helpers.h | 1 + cpp/src/arrow/record_batch.h | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 8b70d15207d..969ede1e360 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2073,7 +2073,7 @@ Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t ArrowArrayMarkReleased(&out_array->array); return Status::OK(); } else { - return ExportDeviceArray(*src->chunk(static_cast(i)), sync, out_array); + return ExportDeviceArray(*src->chunk(static_cast(i)), std::move(sync), out_array); } } diff --git a/cpp/src/arrow/c/helpers.h b/cpp/src/arrow/c/helpers.h index 043195011e0..6e4df17f43e 100644 --- a/cpp/src/arrow/c/helpers.h +++ b/cpp/src/arrow/c/helpers.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 6a262a2c7ee..cf3ca6f68f2 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -261,7 +261,7 @@ class ARROW_EXPORT RecordBatch { /// \return Status virtual Status ValidateFull() const; - /// \brief Return a top-level sync event object for this record batch + /// \brief EXPERIMENTAL: Return a top-level sync event object for this record batch /// /// If all of the data for this record batch is in CPU memory, then this /// will return null. If the data for this batch is @@ -317,9 +317,9 @@ class ARROW_EXPORT RecordBatchReader { /// \brief finalize reader virtual Status Close() { return Status::OK(); } - /// \brief Get the device type for record batches this reader produces + /// \brief EXPERIMENTAL: Get the device type for record batches this reader produces /// - /// default implementation is to return ARROW_DEVICE_CPU + /// default implementation is to return DeviceAllocationType::kCPU virtual DeviceAllocationType device_type() const { return DeviceAllocationType::kCPU; } class RecordBatchReaderIterator { From d80464f19dacacb38177d38f34d7756097b137ea Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 5 Apr 2024 13:02:01 -0400 Subject: [PATCH 13/41] lint --- cpp/src/arrow/c/bridge.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 969ede1e360..c598252cd69 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2073,7 +2073,8 @@ Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t ArrowArrayMarkReleased(&out_array->array); return Status::OK(); } else { - return ExportDeviceArray(*src->chunk(static_cast(i)), std::move(sync), out_array); + return ExportDeviceArray(*src->chunk(static_cast(i)), std::move(sync), + out_array); } } From 6dcc07da6b00cd87e2d912a8231b20d1aee6768f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 10 Apr 2024 17:46:35 -0400 Subject: [PATCH 14/41] propagating device_type and sync_event --- cpp/src/arrow/array/array_base.h | 2 + cpp/src/arrow/array/data.cc | 22 +++++++ cpp/src/arrow/array/data.h | 2 + cpp/src/arrow/c/bridge.cc | 13 ++-- cpp/src/arrow/record_batch.cc | 102 +++++++++++++++++++++++-------- cpp/src/arrow/record_batch.h | 21 ++++--- 6 files changed, 125 insertions(+), 37 deletions(-) diff --git a/cpp/src/arrow/array/array_base.h b/cpp/src/arrow/array/array_base.h index 6411aebf804..6a7ee492e40 100644 --- a/cpp/src/arrow/array/array_base.h +++ b/cpp/src/arrow/array/array_base.h @@ -224,6 +224,8 @@ class ARROW_EXPORT Array { /// \return Status Status ValidateFull() const; + DeviceAllocationType device_type() const { return data_->device_type(); } + protected: Array() = default; ARROW_DEFAULT_MOVE_AND_ASSIGN(Array); diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index ac828a9c35c..a247811ed2d 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -224,6 +224,28 @@ int64_t ArrayData::ComputeLogicalNullCount() const { return ArraySpan(*this).ComputeLogicalNullCount(); } +DeviceAllocationType ArrayData::device_type() const { + int type = 0; + for (const auto& buf : buffers) { + if (!buf) continue; + if (type == 0) { + type = static_cast(buf->device_type()); + } else { + DCHECK_EQ(type, static_cast(buf->device_type())); + } + } + + for (const auto& child : child_data) { + if (type == 0) { + type = static_cast(child->device_type()); + } else { + DCHECK_EQ(type, static_cast(child->device_type())); + } + } + + return type == 0 ? DeviceAllocationType::kCPU : static_cast(type); +} + // ---------------------------------------------------------------------- // Methods for ArraySpan diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index beec29789ad..55b8e7c9049 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -358,6 +358,8 @@ struct ARROW_EXPORT ArrayData { /// \see GetNullCount int64_t ComputeLogicalNullCount() const; + DeviceAllocationType device_type() const; + std::shared_ptr type; int64_t length = 0; mutable std::atomic null_count{0}; diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index c598252cd69..1080c58220a 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -1448,6 +1448,7 @@ namespace { // The ArrowArray is released on destruction. struct ImportedArrayData { struct ArrowArray array_; + DeviceAllocationType device_type_; std::shared_ptr device_sync_; ImportedArrayData() { @@ -1495,7 +1496,7 @@ struct ArrayImporter { Status Import(struct ArrowDeviceArray* src, const DeviceMemoryMapper& mapper) { ARROW_ASSIGN_OR_RAISE(memory_mgr_, mapper(src->device_type, src->device_id)); - device_type_ = static_cast(src->device_type); + device_type_ = static_cast(src->device_type); RETURN_NOT_OK(Import(&src->array)); if (src->sync_event != nullptr) { ARROW_ASSIGN_OR_RAISE(import_->device_sync_, memory_mgr_->WrapDeviceSyncEvent( @@ -1514,6 +1515,7 @@ struct ArrayImporter { recursion_level_ = 0; import_ = std::make_shared(); c_struct_ = &import_->array_; + import_->device_type_ = device_type_; ArrowArrayMove(src, c_struct_); return DoImport(); } @@ -1541,7 +1543,8 @@ struct ArrayImporter { "cannot be imported as RecordBatch"); } return RecordBatch::Make(std::move(schema), data_->length, - std::move(data_->child_data)); + std::move(data_->child_data), import_->device_type_, + import_->device_sync_); } Status ImportChild(const ArrayImporter* parent, struct ArrowArray* src) { @@ -2238,8 +2241,8 @@ class ArrayStreamReader { public: explicit ArrayStreamReader(StreamType* stream, - const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) - : mapper_{mapper} { + const DeviceMemoryMapper mapper = DefaultDeviceMemoryMapper) + : mapper_{std::move(mapper)} { StreamTraits::MoveFunc(stream, &stream_); DCHECK(!StreamTraits::IsReleasedFunc(&stream_)); } @@ -2347,7 +2350,7 @@ class ArrayStreamReader { private: mutable StreamType stream_; - const DeviceMemoryMapper& mapper_; + const DeviceMemoryMapper mapper_; }; template diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index a9677934258..9bd3e2ae871 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -59,17 +59,31 @@ int RecordBatch::num_columns() const { return schema_->num_fields(); } class SimpleRecordBatch : public RecordBatch { public: SimpleRecordBatch(std::shared_ptr schema, int64_t num_rows, - std::vector> columns) - : RecordBatch(std::move(schema), num_rows), boxed_columns_(std::move(columns)) { + std::vector> columns, + std::shared_ptr sync_event = nullptr) + : RecordBatch(std::move(schema), num_rows), + boxed_columns_(std::move(columns)), + device_type_(DeviceAllocationType::kCPU), + sync_event_(std::move(sync_event)) { + if (boxed_columns_.size() > 0) { + device_type_ = boxed_columns_[0]->device_type(); + } + columns_.resize(boxed_columns_.size()); for (size_t i = 0; i < columns_.size(); ++i) { columns_[i] = boxed_columns_[i]->data(); + DCHECK_EQ(device_type_, columns_[i]->device_type()); } } SimpleRecordBatch(const std::shared_ptr& schema, int64_t num_rows, - std::vector> columns) - : RecordBatch(std::move(schema), num_rows), columns_(std::move(columns)) { + std::vector> columns, + DeviceAllocationType device_type = DeviceAllocationType::kCPU, + std::shared_ptr sync_event = nullptr) + : RecordBatch(std::move(schema), num_rows), + columns_(std::move(columns)), + device_type_(device_type), + sync_event_(std::move(sync_event)) { boxed_columns_.resize(schema_->num_fields()); } @@ -99,6 +113,7 @@ class SimpleRecordBatch : public RecordBatch { const std::shared_ptr& column) const override { ARROW_CHECK(field != nullptr); ARROW_CHECK(column != nullptr); + ARROW_CHECK(column->device_type() == device_type_); if (!field->type()->Equals(column->type())) { return Status::TypeError("Column data type ", field->type()->name(), @@ -113,7 +128,8 @@ class SimpleRecordBatch : public RecordBatch { ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->AddField(i, field)); return RecordBatch::Make(std::move(new_schema), num_rows_, - internal::AddVectorElement(columns_, i, column->data())); + internal::AddVectorElement(columns_, i, column->data()), + device_type_, sync_event_); } Result> SetColumn( @@ -121,6 +137,7 @@ class SimpleRecordBatch : public RecordBatch { const std::shared_ptr& column) const override { ARROW_CHECK(field != nullptr); ARROW_CHECK(column != nullptr); + ARROW_CHECK(column->device_type() == device_type_); if (!field->type()->Equals(column->type())) { return Status::TypeError("Column data type ", field->type()->name(), @@ -135,19 +152,22 @@ class SimpleRecordBatch : public RecordBatch { ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->SetField(i, field)); return RecordBatch::Make(std::move(new_schema), num_rows_, - internal::ReplaceVectorElement(columns_, i, column->data())); + internal::ReplaceVectorElement(columns_, i, column->data()), + device_type_, sync_event_); } Result> RemoveColumn(int i) const override { ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->RemoveField(i)); return RecordBatch::Make(std::move(new_schema), num_rows_, - internal::DeleteVectorElement(columns_, i)); + internal::DeleteVectorElement(columns_, i), device_type_, + sync_event_); } std::shared_ptr ReplaceSchemaMetadata( const std::shared_ptr& metadata) const override { auto new_schema = schema_->WithMetadata(metadata); - return RecordBatch::Make(std::move(new_schema), num_rows_, columns_); + return RecordBatch::Make(std::move(new_schema), num_rows_, columns_, device_type_, + sync_event_); } std::shared_ptr Slice(int64_t offset, int64_t length) const override { @@ -157,7 +177,8 @@ class SimpleRecordBatch : public RecordBatch { arrays.emplace_back(field->Slice(offset, length)); } int64_t num_rows = std::min(num_rows_ - offset, length); - return std::make_shared(schema_, num_rows, std::move(arrays)); + return std::make_shared(schema_, num_rows, std::move(arrays), + device_type_, sync_event_); } Status Validate() const override { @@ -167,11 +188,20 @@ class SimpleRecordBatch : public RecordBatch { return RecordBatch::Validate(); } + std::shared_ptr GetSyncEvent() const override { return sync_event_; } + + DeviceAllocationType device_type() const override { return device_type_; } + private: std::vector> columns_; // Caching boxed array data mutable std::vector> boxed_columns_; + + // the type of device that the buffers for columns are allocated on. + // all columns should be on the same type of device. + DeviceAllocationType device_type_; + std::shared_ptr sync_event_; }; RecordBatch::RecordBatch(const std::shared_ptr& schema, int64_t num_rows) @@ -179,18 +209,20 @@ RecordBatch::RecordBatch(const std::shared_ptr& schema, int64_t num_rows std::shared_ptr RecordBatch::Make( std::shared_ptr schema, int64_t num_rows, - std::vector> columns) { + std::vector> columns, + std::shared_ptr sync_event) { DCHECK_EQ(schema->num_fields(), static_cast(columns.size())); return std::make_shared(std::move(schema), num_rows, - std::move(columns)); + std::move(columns), sync_event); } std::shared_ptr RecordBatch::Make( std::shared_ptr schema, int64_t num_rows, - std::vector> columns) { + std::vector> columns, DeviceAllocationType device_type, + std::shared_ptr sync_event) { DCHECK_EQ(schema->num_fields(), static_cast(columns.size())); return std::make_shared(std::move(schema), num_rows, - std::move(columns)); + std::move(columns), device_type, sync_event); } Result> RecordBatch::MakeEmpty( @@ -505,7 +537,7 @@ Result> RecordBatch::ReplaceSchema( ", did not match new schema field type: ", replace_type->ToString()); } } - return RecordBatch::Make(std::move(schema), num_rows(), columns()); + return RecordBatch::Make(std::move(schema), num_rows(), columns(), GetSyncEvent()); } std::vector RecordBatch::ColumnNames() const { @@ -534,7 +566,7 @@ Result> RecordBatch::RenameColumns( } return RecordBatch::Make(::arrow::schema(std::move(fields)), num_rows(), - std::move(columns)); + std::move(columns), GetSyncEvent()); } Result> RecordBatch::SelectColumns( @@ -555,7 +587,8 @@ Result> RecordBatch::SelectColumns( auto new_schema = std::make_shared(std::move(fields), schema()->metadata()); - return RecordBatch::Make(std::move(new_schema), num_rows(), std::move(columns)); + return RecordBatch::Make(std::move(new_schema), num_rows(), std::move(columns), + GetSyncEvent()); } std::shared_ptr RecordBatch::Slice(int64_t offset) const { @@ -623,7 +656,11 @@ Status RecordBatch::ValidateFull() const { return ValidateBatch(*this, /*full_validation=*/true); } -std::shared_ptr RecordBatch::GetSyncEvent() { return nullptr; } +std::shared_ptr RecordBatch::GetSyncEvent() const { return nullptr; } + +DeviceAllocationType RecordBatch::device_type() const { + return DeviceAllocationType::kCPU; +} // ---------------------------------------------------------------------- // Base record batch reader @@ -649,12 +686,16 @@ Result> RecordBatchReader::ToTable() { class SimpleRecordBatchReader : public RecordBatchReader { public: SimpleRecordBatchReader(Iterator> it, - std::shared_ptr schema) - : schema_(std::move(schema)), it_(std::move(it)) {} + std::shared_ptr schema, + DeviceAllocationType device_type = DeviceAllocationType::kCPU) + : schema_(std::move(schema)), it_(std::move(it)), device_type_(device_type) {} SimpleRecordBatchReader(std::vector> batches, - std::shared_ptr schema) - : schema_(std::move(schema)), it_(MakeVectorIterator(std::move(batches))) {} + std::shared_ptr schema, + DeviceAllocationType device_type = DeviceAllocationType::kCPU) + : schema_(std::move(schema)), + it_(MakeVectorIterator(std::move(batches))), + device_type_(device_type) {} Status ReadNext(std::shared_ptr* batch) override { return it_.Next().Value(batch); @@ -662,13 +703,17 @@ class SimpleRecordBatchReader : public RecordBatchReader { std::shared_ptr schema() const override { return schema_; } + DeviceAllocationType device_type() const override { return device_type_; } + protected: std::shared_ptr schema_; Iterator> it_; + DeviceAllocationType device_type_; }; Result> RecordBatchReader::Make( - std::vector> batches, std::shared_ptr schema) { + std::vector> batches, std::shared_ptr schema, + DeviceAllocationType device_type) { if (schema == nullptr) { if (batches.size() == 0 || batches[0] == nullptr) { return Status::Invalid("Cannot infer schema from empty vector or nullptr"); @@ -677,16 +722,19 @@ Result> RecordBatchReader::Make( schema = batches[0]->schema(); } - return std::make_shared(std::move(batches), std::move(schema)); + return std::make_shared(std::move(batches), std::move(schema), + device_type); } Result> RecordBatchReader::MakeFromIterator( - Iterator> batches, std::shared_ptr schema) { + Iterator> batches, std::shared_ptr schema, + DeviceAllocationType device_type) { if (schema == nullptr) { return Status::Invalid("Schema cannot be nullptr"); } - return std::make_shared(std::move(batches), std::move(schema)); + return std::make_shared(std::move(batches), std::move(schema), + device_type); } RecordBatchReader::~RecordBatchReader() { @@ -703,6 +751,10 @@ Result> ConcatenateRecordBatches( int cols = batches[0]->num_columns(); auto schema = batches[0]->schema(); for (size_t i = 0; i < batches.size(); ++i) { + if (auto sync = batches[i]->GetSyncEvent()) { + ARROW_RETURN_NOT_OK(sync->Wait()); + } + length += batches[i]->num_rows(); if (!schema->Equals(batches[i]->schema())) { return Status::Invalid( diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index cf3ca6f68f2..8123b9486d3 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -46,9 +46,10 @@ class ARROW_EXPORT RecordBatch { /// \param[in] num_rows length of fields in the record batch. Each array /// should have the same length as num_rows /// \param[in] columns the record batch fields as vector of arrays - static std::shared_ptr Make(std::shared_ptr schema, - int64_t num_rows, - std::vector> columns); + static std::shared_ptr Make( + std::shared_ptr schema, int64_t num_rows, + std::vector> columns, + std::shared_ptr sync_event = nullptr); /// \brief Construct record batch from vector of internal data structures /// \since 0.5.0 @@ -61,7 +62,9 @@ class ARROW_EXPORT RecordBatch { /// \param columns the data for the batch's columns static std::shared_ptr Make( std::shared_ptr schema, int64_t num_rows, - std::vector> columns); + std::vector> columns, + DeviceAllocationType device_type = DeviceAllocationType::kCPU, + std::shared_ptr sync_event = nullptr); /// \brief Create an empty RecordBatch of a given schema /// @@ -269,7 +272,9 @@ class ARROW_EXPORT RecordBatch { /// data the returned sync event will allow for it. /// /// \return null or a Device::SyncEvent - virtual std::shared_ptr GetSyncEvent(); + virtual std::shared_ptr GetSyncEvent() const; + + virtual DeviceAllocationType device_type() const; protected: RecordBatch(const std::shared_ptr& schema, int64_t num_rows); @@ -396,14 +401,16 @@ class ARROW_EXPORT RecordBatchReader { /// \param[in] schema schema to conform to. Will be inferred from the first /// element if not provided. static Result> Make( - RecordBatchVector batches, std::shared_ptr schema = NULLPTR); + RecordBatchVector batches, std::shared_ptr schema = NULLPTR, + DeviceAllocationType device_type = DeviceAllocationType::kCPU); /// \brief Create a RecordBatchReader from an Iterator of RecordBatch. /// /// \param[in] batches an iterator of RecordBatch to read from. /// \param[in] schema schema that each record batch in iterator will conform to. static Result> MakeFromIterator( - Iterator> batches, std::shared_ptr schema); + Iterator> batches, std::shared_ptr schema, + DeviceAllocationType device_type = DeviceAllocationType::kCPU); }; /// \brief Concatenate record batches From ba5847773b53cbdd16fcb1c629ca6f1e5ea75aa7 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 10 Apr 2024 17:51:22 -0400 Subject: [PATCH 15/41] linting --- cpp/src/arrow/c/bridge.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 1080c58220a..b3d9823572e 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -1496,7 +1496,7 @@ struct ArrayImporter { Status Import(struct ArrowDeviceArray* src, const DeviceMemoryMapper& mapper) { ARROW_ASSIGN_OR_RAISE(memory_mgr_, mapper(src->device_type, src->device_id)); - device_type_ = static_cast(src->device_type); + device_type_ = static_cast(src->device_type); RETURN_NOT_OK(Import(&src->array)); if (src->sync_event != nullptr) { ARROW_ASSIGN_OR_RAISE(import_->device_sync_, memory_mgr_->WrapDeviceSyncEvent( From 79b3600929ac4c581d5ec728fd5ec8daf6574655 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 11 Apr 2024 16:19:48 -0400 Subject: [PATCH 16/41] fix for lint --- cpp/src/arrow/record_batch.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 8123b9486d3..6d4273ce4a3 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -49,7 +49,7 @@ class ARROW_EXPORT RecordBatch { static std::shared_ptr Make( std::shared_ptr schema, int64_t num_rows, std::vector> columns, - std::shared_ptr sync_event = nullptr); + std::shared_ptr sync_event = NULLPTR); /// \brief Construct record batch from vector of internal data structures /// \since 0.5.0 @@ -64,7 +64,7 @@ class ARROW_EXPORT RecordBatch { std::shared_ptr schema, int64_t num_rows, std::vector> columns, DeviceAllocationType device_type = DeviceAllocationType::kCPU, - std::shared_ptr sync_event = nullptr); + std::shared_ptr sync_event = NULLPTR); /// \brief Create an empty RecordBatch of a given schema /// From f8c2b597d523fd4363706ade0dddf2085f5243f7 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 16 Apr 2024 10:32:32 -0400 Subject: [PATCH 17/41] more tests --- cpp/src/arrow/array/array_test.cc | 7 +- cpp/src/arrow/array/data.cc | 8 + cpp/src/arrow/c/bridge_test.cc | 333 ++++++++++++++++++++++++++++++ cpp/src/arrow/c/util_internal.h | 2 + cpp/src/arrow/record_batch.cc | 8 + 5 files changed, 357 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 7e25ad61fa2..fbbeb212f6b 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -478,6 +478,7 @@ TEST_F(TestArray, TestMakeArrayOfNull) { ASSERT_EQ(array->type(), type); ASSERT_OK(array->ValidateFull()); ASSERT_EQ(array->length(), length); + ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU); if (is_union(type->id())) { ASSERT_EQ(array->null_count(), 0); ASSERT_EQ(array->ComputeLogicalNullCount(), length); @@ -719,6 +720,7 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { ASSERT_OK(array->ValidateFull()); ASSERT_EQ(array->length(), length); ASSERT_EQ(array->null_count(), 0); + ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU); // test case for ARROW-13321 for (int64_t i : {int64_t{0}, length / 2, length - 1}) { @@ -744,6 +746,7 @@ TEST_F(TestArray, TestMakeArrayFromScalarSliced) { auto sliced = array->Slice(1, 4); ASSERT_EQ(sliced->length(), 4); ASSERT_EQ(sliced->null_count(), 0); + ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU); ARROW_EXPECT_OK(sliced->ValidateFull()); } } @@ -758,7 +761,8 @@ TEST_F(TestArray, TestMakeArrayFromDictionaryScalar) { ASSERT_OK(array->ValidateFull()); ASSERT_EQ(array->length(), 4); ASSERT_EQ(array->null_count(), 0); - + ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU); + for (int i = 0; i < 4; i++) { ASSERT_OK_AND_ASSIGN(auto item, array->GetScalar(i)); ASSERT_TRUE(item->Equals(scalar)); @@ -797,6 +801,7 @@ TEST_F(TestArray, TestMakeEmptyArray) { ASSERT_OK_AND_ASSIGN(auto array, MakeEmptyArray(type)); ASSERT_OK(array->ValidateFull()); ASSERT_EQ(array->length(), 0); + CheckSpanRoundTrip(*array); } } diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index a247811ed2d..72cea0f142f 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -243,6 +243,14 @@ DeviceAllocationType ArrayData::device_type() const { } } + if (dictionary) { + if (type == 0) { + type = static_cast(dictionary->device_type()); + } else { + DCHECK_EQ(type, static_cast(dictionary->device_type())); + } + } + return type == 0 ? DeviceAllocationType::kCPU : static_cast(type); } diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index d64fe67accd..f36c6994ddb 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -53,11 +53,15 @@ namespace arrow { +using internal::ArrayDeviceExportTraits; +using internal::ArrayDeviceStreamExportTraits; using internal::ArrayExportGuard; using internal::ArrayExportTraits; using internal::ArrayStreamExportGuard; using internal::ArrayStreamExportTraits; using internal::checked_cast; +using internal::DeviceArrayExportGuard; +using internal::DeviceArrayStreamExportGuard; using internal::SchemaExportGuard; using internal::SchemaExportTraits; using internal::Zip; @@ -4746,4 +4750,333 @@ TEST_F(TestArrayStreamRoundtrip, ChunkedArrayRoundtripEmpty) { }); } +//////////////////////////////////////////////////////////////////////////// +// Array device stream export tests + +class TestArrayDeviceStreamExport : public BaseArrayStreamTest { + public: + void AssertStreamSchema(struct ArrowDeviceArrayStream* c_stream, + const Schema& expected) { + struct ArrowSchema c_schema; + ASSERT_EQ(0, c_stream->get_schema(c_stream, &c_schema)); + + SchemaExportGuard schema_guard(&c_schema); + ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema)); + ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema)); + AssertSchemaEqual(expected, *schema, /*check_metadata=*/true); + } + + void AssertStreamEnd(struct ArrowDeviceArrayStream* c_stream) { + struct ArrowDeviceArray c_array; + ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array)); + + DeviceArrayExportGuard guard(&c_array); + ASSERT_TRUE(ArrowDeviceArrayIsReleased(&c_array)); + } + + void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream, + const RecordBatch& expected) { + struct ArrowDeviceArray c_array; + ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array)); + + DeviceArrayExportGuard guard(&c_array); + ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array)); + + ASSERT_OK_AND_ASSIGN(auto batch, + ImportDeviceRecordBatch(&c_array, expected.schema(), + TestDeviceArrayRoundtrip::DeviceMapper)); + AssertBatchesEqual(expected, *batch); + } + + void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream, const Array& expected) { + struct ArrowDeviceArray c_array; + ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array)); + + DeviceArrayExportGuard guard(&c_array); + ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array)); + + ASSERT_OK_AND_ASSIGN(auto array, + ImportDeviceArray(&c_array, expected.type(), + TestDeviceArrayRoundtrip::DeviceMapper)); + AssertArraysEqual(expected, *array); + } + + static Result> ToDeviceData( + const std::shared_ptr& mm, const ArrayData& data) { + arrow::BufferVector buffers; + for (const auto& buf : data.buffers) { + if (buf) { + ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm)); + buffers.push_back(dest); + } else { + buffers.push_back(nullptr); + } + } + + arrow::ArrayDataVector children; + for (const auto& child : data.child_data) { + ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child)); + children.push_back(dest); + } + + return ArrayData::Make(data.type, data.length, buffers, children, data.null_count, + data.offset); + } + + static Result> ToDevice(const std::shared_ptr& mm, + const ArrayData& data) { + ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data)); + return MakeArray(result); + } +}; + +TEST_F(TestArrayDeviceStreamExport, Empty) { + auto schema = arrow::schema({field("ints", int32())}); + auto batches = MakeBatches(schema, {}); + ASSERT_OK_AND_ASSIGN( + auto reader, + RecordBatchReader::Make(batches, schema, + static_cast(kMyDeviceType))); + + struct ArrowDeviceArrayStream c_stream; + ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream)); + DeviceArrayStreamExportGuard guard(&c_stream); + + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + ASSERT_EQ(kMyDeviceType, c_stream.device_type); + AssertStreamSchema(&c_stream, *schema); + AssertStreamEnd(&c_stream); + AssertStreamEnd(&c_stream); +} + +TEST_F(TestArrayDeviceStreamExport, Simple) { + std::shared_ptr device = std::make_shared(1); + auto mm = device->default_memory_manager(); + + ASSERT_OK_AND_ASSIGN(auto arr1, + ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data())); + ASSERT_EQ(device->device_type(), arr1->device_type()); + ASSERT_OK_AND_ASSIGN(auto arr2, + ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data())); + ASSERT_EQ(device->device_type(), arr2->device_type()); + auto schema = arrow::schema({field("ints", int32())}); + auto batches = MakeBatches(schema, {arr1, arr2}); + ASSERT_OK_AND_ASSIGN(auto reader, + RecordBatchReader::Make(batches, schema, device->device_type())); + + struct ArrowDeviceArrayStream c_stream; + + ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream)); + DeviceArrayStreamExportGuard guard(&c_stream); + + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + AssertStreamSchema(&c_stream, *schema); + ASSERT_EQ(kMyDeviceType, c_stream.device_type); + AssertStreamNext(&c_stream, *batches[0]); + AssertStreamNext(&c_stream, *batches[1]); + AssertStreamEnd(&c_stream); + AssertStreamEnd(&c_stream); +} + +TEST_F(TestArrayDeviceStreamExport, ArrayLifetime) { + std::shared_ptr device = std::make_shared(1); + auto mm = device->default_memory_manager(); + + ASSERT_OK_AND_ASSIGN(auto arr1, + ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data())); + ASSERT_EQ(device->device_type(), arr1->device_type()); + ASSERT_OK_AND_ASSIGN(auto arr2, + ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data())); + ASSERT_EQ(device->device_type(), arr2->device_type()); + auto schema = arrow::schema({field("ints", int32())}); + auto batches = MakeBatches(schema, {arr1, arr2}); + ASSERT_OK_AND_ASSIGN(auto reader, + RecordBatchReader::Make(batches, schema, device->device_type())); + + struct ArrowDeviceArrayStream c_stream; + struct ArrowSchema c_schema; + struct ArrowDeviceArray c_array0, c_array1; + + ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream)); + { + DeviceArrayStreamExportGuard guard(&c_stream); + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + ASSERT_EQ(kMyDeviceType, c_stream.device_type); + + ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema)); + ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0)); + ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1)); + AssertStreamEnd(&c_stream); + } + + DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1); + + { + SchemaExportGuard schema_guard(&c_schema); + ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema)); + AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true); + } + + ASSERT_EQ(kMyDeviceType, c_array0.device_type); + ASSERT_EQ(kMyDeviceType, c_array1.device_type); + + ASSERT_GT(pool_->bytes_allocated(), orig_allocated_); + ASSERT_OK_AND_ASSIGN( + auto batch, + ImportDeviceRecordBatch(&c_array1, schema, TestDeviceArrayRoundtrip::DeviceMapper)); + AssertBatchesEqual(*batches[1], *batch); + ASSERT_EQ(device->device_type(), batch->device_type()); + ASSERT_OK_AND_ASSIGN( + batch, + ImportDeviceRecordBatch(&c_array0, schema, TestDeviceArrayRoundtrip::DeviceMapper)); + AssertBatchesEqual(*batches[0], *batch); + ASSERT_EQ(device->device_type(), batch->device_type()); +} + +TEST_F(TestArrayDeviceStreamExport, Errors) { + auto reader = + std::make_shared(Status::Invalid("some example error")); + + struct ArrowDeviceArrayStream c_stream; + + ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream)); + DeviceArrayStreamExportGuard guard(&c_stream); + + struct ArrowSchema c_schema; + ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema)); + ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema)); + { + SchemaExportGuard schema_guard(&c_schema); + ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema)); + AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true); + } + + struct ArrowDeviceArray c_array; + ASSERT_EQ(EINVAL, c_stream.get_next(&c_stream, &c_array)); +} + +TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExportEmpty) { + ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({}, int32())); + + struct ArrowDeviceArrayStream c_stream; + struct ArrowSchema c_schema; + + ASSERT_OK(ExportDeviceChunkedArray( + chunked_array, static_cast(kMyDeviceType), &c_stream)); + DeviceArrayStreamExportGuard guard(&c_stream); + + { + DeviceArrayStreamExportGuard guard(&c_stream); + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + + ASSERT_EQ(kMyDeviceType, c_stream.device_type); + ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema)); + AssertStreamEnd(&c_stream); + } + + { + SchemaExportGuard schema_guard(&c_schema); + ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema)); + AssertTypeEqual(*chunked_array->type(), *got_type); + } +} + +TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExport) { + std::shared_ptr device = std::make_shared(1); + auto mm = device->default_memory_manager(); + + ASSERT_OK_AND_ASSIGN(auto arr1, + ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data())); + ASSERT_EQ(device->device_type(), arr1->device_type()); + ASSERT_OK_AND_ASSIGN(auto arr2, + ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data())); + ASSERT_EQ(device->device_type(), arr2->device_type()); + + ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({arr1, arr2})); + + struct ArrowDeviceArrayStream c_stream; + struct ArrowSchema c_schema; + struct ArrowDeviceArray c_array0, c_array1; + + ASSERT_OK(ExportDeviceChunkedArray(chunked_array, device->device_type(), &c_stream)); + DeviceArrayStreamExportGuard guard(&c_stream); + + { + DeviceArrayStreamExportGuard guard(&c_stream); + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + ASSERT_EQ(kMyDeviceType, c_stream.device_type); + + ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema)); + ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0)); + ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1)); + AssertStreamEnd(&c_stream); + } + + DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1); + + { + SchemaExportGuard schema_guard(&c_schema); + ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema)); + AssertTypeEqual(*chunked_array->type(), *got_type); + } + + ASSERT_EQ(kMyDeviceType, c_array0.device_type); + ASSERT_EQ(kMyDeviceType, c_array1.device_type); + + ASSERT_GT(pool_->bytes_allocated(), orig_allocated_); + ASSERT_OK_AND_ASSIGN(auto array, + ImportDeviceArray(&c_array0, chunked_array->type(), + TestDeviceArrayRoundtrip::DeviceMapper)); + ASSERT_EQ(device->device_type(), array->device_type()); + AssertArraysEqual(*chunked_array->chunk(0), *array); + ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array1, chunked_array->type(), + TestDeviceArrayRoundtrip::DeviceMapper)); + ASSERT_EQ(device->device_type(), array->device_type()); + AssertArraysEqual(*chunked_array->chunk(1), *array); +} + +//////////////////////////////////////////////////////////////////////////// +// Array device stream roundtrip tests + +class TestArrayDeviceStreamRoundtrip : public BaseArrayStreamTest { + public: + void Roundtrip(std::shared_ptr* reader, + struct ArrowDeviceArrayStream* c_stream) { + ASSERT_OK(ExportDeviceRecordBatchReader(*reader, c_stream)); + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(c_stream)); + + ASSERT_OK_AND_ASSIGN( + auto got_reader, + ImportDeviceRecordBatchReader(c_stream, TestDeviceArrayRoundtrip::DeviceMapper)); + *reader = std::move(got_reader); + } + + void Roundtrip( + std::shared_ptr reader, + std::function&)> check_func) { + ArrowDeviceArrayStream c_stream; + + // NOTE: ReleaseCallback<> is not immediately usable with ArrowDeviceArayStream + // because get_next and get_schema need the original private_data. + std::weak_ptr weak_reader(reader); + ASSERT_EQ(weak_reader.use_count(), 1); // Expiration check will fail otherwise + + ASSERT_OK(ExportDeviceRecordBatchReader(std::move(reader), &c_stream)); + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + + { + ASSERT_OK_AND_ASSIGN(auto new_reader, + ImportDeviceRecordBatchReader( + &c_stream, TestDeviceArrayRoundtrip::DeviceMapper)); + // stream was moved + ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + ASSERT_FALSE(weak_reader.expired()); + + check_func(new_reader); + } + // Stream was released when `new_reader` was destroyed + ASSERT_TRUE(weak_reader.expired()); + } +}; + } // namespace arrow diff --git a/cpp/src/arrow/c/util_internal.h b/cpp/src/arrow/c/util_internal.h index 8f292b06249..dc0e25710e9 100644 --- a/cpp/src/arrow/c/util_internal.h +++ b/cpp/src/arrow/c/util_internal.h @@ -99,7 +99,9 @@ class ExportGuard { using SchemaExportGuard = ExportGuard; using ArrayExportGuard = ExportGuard; +using DeviceArrayExportGuard = ExportGuard; using ArrayStreamExportGuard = ExportGuard; +using DeviceArrayStreamExportGuard = ExportGuard; } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 9bd3e2ae871..84ebbd2f219 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -498,6 +498,10 @@ bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata, return false; } + if (device_type() != other.device_type()) { + return false; + } + for (int i = 0; i < num_columns(); ++i) { if (!column(i)->Equals(other.column(i), opts)) { return false; @@ -512,6 +516,10 @@ bool RecordBatch::ApproxEquals(const RecordBatch& other, const EqualOptions& opt return false; } + if (device_type() != other.device_type()) { + return false; + } + for (int i = 0; i < num_columns(); ++i) { if (!column(i)->ApproxEquals(other.column(i), opts)) { return false; From fcd8dc9137b632b27cc60151b97857eac90ba054 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 16 Apr 2024 16:52:34 -0400 Subject: [PATCH 18/41] added tests --- cpp/src/arrow/c/bridge_test.cc | 197 +++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index f36c6994ddb..78ba50cd27f 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -5040,6 +5040,34 @@ TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExport) { class TestArrayDeviceStreamRoundtrip : public BaseArrayStreamTest { public: + static Result> ToDeviceData( + const std::shared_ptr& mm, const ArrayData& data) { + arrow::BufferVector buffers; + for (const auto& buf : data.buffers) { + if (buf) { + ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm)); + buffers.push_back(dest); + } else { + buffers.push_back(nullptr); + } + } + + arrow::ArrayDataVector children; + for (const auto& child : data.child_data) { + ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child)); + children.push_back(dest); + } + + return ArrayData::Make(data.type, data.length, buffers, children, data.null_count, + data.offset); + } + + static Result> ToDevice(const std::shared_ptr& mm, + const ArrayData& data) { + ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data)); + return MakeArray(result); + } + void Roundtrip(std::shared_ptr* reader, struct ArrowDeviceArrayStream* c_stream) { ASSERT_OK(ExportDeviceRecordBatchReader(*reader, c_stream)); @@ -5077,6 +5105,175 @@ class TestArrayDeviceStreamRoundtrip : public BaseArrayStreamTest { // Stream was released when `new_reader` was destroyed ASSERT_TRUE(weak_reader.expired()); } + + void Roundtrip(std::shared_ptr src, + std::function&)> check_func) { + ArrowDeviceArrayStream c_stream; + + // One original copy to compare the result, one copy held by the stream + std::weak_ptr weak_src(src); + int64_t initial_use_count = weak_src.use_count(); + + ASSERT_OK(ExportDeviceChunkedArray( + std::move(src), static_cast(kMyDeviceType), &c_stream)); + ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + ASSERT_EQ(kMyDeviceType, c_stream.device_type); + + { + ASSERT_OK_AND_ASSIGN( + auto dst, + ImportDeviceChunkedArray(&c_stream, TestDeviceArrayRoundtrip::DeviceMapper)); + // Stream was moved, consumed, and released + ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream)); + + // Stream was released by ImportDeviceChunkedArray but original copy remains + ASSERT_EQ(weak_src.use_count(), initial_use_count - 1); + + check_func(dst); + } + } + + void AssertReaderNext(const std::shared_ptr& reader, + const RecordBatch& expected) { + ASSERT_OK_AND_ASSIGN(auto batch, reader->Next()); + ASSERT_NE(batch, nullptr); + ASSERT_EQ(static_cast(kMyDeviceType), batch->device_type()); + AssertBatchesEqual(expected, *batch); + } + + void AssertReaderEnd(const std::shared_ptr& reader) { + ASSERT_OK_AND_ASSIGN(auto batch, reader->Next()); + ASSERT_EQ(batch, nullptr); + } + + void AssertReaderClosed(const std::shared_ptr& reader) { + ASSERT_THAT(reader->Next(), + Raises(StatusCode::Invalid, ::testing::HasSubstr("already been closed"))); + } + + void AssertReaderClose(const std::shared_ptr& reader) { + ASSERT_OK(reader->Close()); + AssertReaderClosed(reader); + } }; +TEST_F(TestArrayDeviceStreamRoundtrip, Simple) { + std::shared_ptr device = std::make_shared(1); + auto mm = device->default_memory_manager(); + + ASSERT_OK_AND_ASSIGN(auto arr1, + ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data())); + ASSERT_EQ(device->device_type(), arr1->device_type()); + ASSERT_OK_AND_ASSIGN(auto arr2, + ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data())); + ASSERT_EQ(device->device_type(), arr2->device_type()); + auto orig_schema = arrow::schema({field("ints", int32())}); + auto batches = MakeBatches(orig_schema, {arr1, arr2}); + ASSERT_OK_AND_ASSIGN( + auto reader, RecordBatchReader::Make(batches, orig_schema, device->device_type())); + + Roundtrip(std::move(reader), [&](const std::shared_ptr& reader) { + AssertSchemaEqual(*orig_schema, *reader->schema(), /*check_metadata=*/true); + AssertReaderNext(reader, *batches[0]); + AssertReaderNext(reader, *batches[1]); + AssertReaderEnd(reader); + AssertReaderEnd(reader); + AssertReaderClose(reader); + }); +} + +TEST_F(TestArrayDeviceStreamRoundtrip, CloseEarly) { + std::shared_ptr device = std::make_shared(1); + auto mm = device->default_memory_manager(); + + ASSERT_OK_AND_ASSIGN(auto arr1, + ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data())); + ASSERT_EQ(device->device_type(), arr1->device_type()); + ASSERT_OK_AND_ASSIGN(auto arr2, + ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data())); + ASSERT_EQ(device->device_type(), arr2->device_type()); + auto orig_schema = arrow::schema({field("ints", int32())}); + auto batches = MakeBatches(orig_schema, {arr1, arr2}); + ASSERT_OK_AND_ASSIGN( + auto reader, RecordBatchReader::Make(batches, orig_schema, device->device_type())); + + Roundtrip(std::move(reader), [&](const std::shared_ptr& reader) { + AssertReaderNext(reader, *batches[0]); + AssertReaderClose(reader); + }); +} + +TEST_F(TestArrayDeviceStreamRoundtrip, Errors) { + auto reader = std::make_shared( + Status::Invalid("roundtrip error example")); + + Roundtrip(std::move(reader), [&](const std::shared_ptr& reader) { + auto status = reader->Next().status(); + ASSERT_RAISES(Invalid, status); + ASSERT_THAT(status.message(), ::testing::HasSubstr("roundtrip error example")); + }); +} + +TEST_F(TestArrayDeviceStreamRoundtrip, SchemaError) { + struct StreamState { + bool released = false; + + static const char* GetLastError(struct ArrowDeviceArrayStream* stream) { + return "Expected error"; + } + + static int GetSchema(struct ArrowDeviceArrayStream* stream, + struct ArrowSchema* schema) { + return EIO; + } + + static int GetNext(struct ArrowDeviceArrayStream* stream, struct ArrowDeviceArray* array) { + return EINVAL; + } + + static void Release(struct ArrowDeviceArrayStream* stream) { + reinterpret_cast(stream->private_data)->released = true; + std::memset(stream, 0, sizeof(*stream)); + } + } state; + struct ArrowDeviceArrayStream stream = {}; + stream.get_last_error = &StreamState::GetLastError; + stream.get_schema = &StreamState::GetSchema; + stream.get_next = &StreamState::GetNext; + stream.release = &StreamState::Release; + stream.private_data = &state; + + EXPECT_RAISES_WITH_MESSAGE_THAT(IOError, ::testing::HasSubstr("Expected error"), + ImportDeviceRecordBatchReader(&stream)); + ASSERT_TRUE(state.released); +} + +TEST_F(TestArrayDeviceStreamRoundtrip, ChunkedArrayRoundtrip) { + std::shared_ptr device = std::make_shared(1); + auto mm = device->default_memory_manager(); + + ASSERT_OK_AND_ASSIGN(auto arr1, + ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data())); + ASSERT_EQ(device->device_type(), arr1->device_type()); + ASSERT_OK_AND_ASSIGN(auto arr2, + ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data())); + ASSERT_EQ(device->device_type(), arr2->device_type()); + + ASSERT_OK_AND_ASSIGN(auto src, ChunkedArray::Make({arr1, arr2})); + + Roundtrip(src, [&](const std::shared_ptr& dst) { + AssertTypeEqual(*dst->type(), *src->type()); + AssertChunkedEqual(*dst, *src); + }); +} + +TEST_F(TestArrayDeviceStreamRoundtrip, ChunkedArrayRoundtripEmpty) { + ASSERT_OK_AND_ASSIGN(auto src, ChunkedArray::Make({}, int32())); + + Roundtrip(src, [&](const std::shared_ptr& dst) { + AssertTypeEqual(*dst->type(), *src->type()); + AssertChunkedEqual(*dst, *src); + }); +} + } // namespace arrow From dca4503405d2f2eb77b71c0b49824e28bab18e8e Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 25 Apr 2024 16:39:37 -0400 Subject: [PATCH 19/41] fix lint --- cpp/src/arrow/array/array_test.cc | 4 ++-- cpp/src/arrow/c/bridge_test.cc | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index fbbeb212f6b..32806d9d2ed 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -762,7 +762,7 @@ TEST_F(TestArray, TestMakeArrayFromDictionaryScalar) { ASSERT_EQ(array->length(), 4); ASSERT_EQ(array->null_count(), 0); ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU); - + for (int i = 0; i < 4; i++) { ASSERT_OK_AND_ASSIGN(auto item, array->GetScalar(i)); ASSERT_TRUE(item->Equals(scalar)); @@ -801,7 +801,7 @@ TEST_F(TestArray, TestMakeEmptyArray) { ASSERT_OK_AND_ASSIGN(auto array, MakeEmptyArray(type)); ASSERT_OK(array->ValidateFull()); ASSERT_EQ(array->length(), 0); - + CheckSpanRoundTrip(*array); } } diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 78ba50cd27f..6f435b5bb9c 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -5227,7 +5227,8 @@ TEST_F(TestArrayDeviceStreamRoundtrip, SchemaError) { return EIO; } - static int GetNext(struct ArrowDeviceArrayStream* stream, struct ArrowDeviceArray* array) { + static int GetNext(struct ArrowDeviceArrayStream* stream, + struct ArrowDeviceArray* array) { return EINVAL; } From 3211ef325ddaf07abc27f7fab8b75bcb1436be8a Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 25 Apr 2024 17:19:35 -0400 Subject: [PATCH 20/41] fix docs --- cpp/src/arrow/record_batch.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 6d4273ce4a3..d49c6201418 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -46,6 +46,8 @@ class ARROW_EXPORT RecordBatch { /// \param[in] num_rows length of fields in the record batch. Each array /// should have the same length as num_rows /// \param[in] columns the record batch fields as vector of arrays + /// \param[in] sync_event optional synchronization event for non-CPU device + /// memory used by buffers static std::shared_ptr Make( std::shared_ptr schema, int64_t num_rows, std::vector> columns, @@ -60,6 +62,8 @@ class ARROW_EXPORT RecordBatch { /// \param num_rows the number of semantic rows in the record batch. This /// should be equal to the length of each field /// \param columns the data for the batch's columns + /// \param[in] sync_event optional synchronization event for non-CPU device + /// memory used by buffers static std::shared_ptr Make( std::shared_ptr schema, int64_t num_rows, std::vector> columns, From ff521ab25f135868942e425d705cd227b936c8b1 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 25 Apr 2024 17:39:31 -0400 Subject: [PATCH 21/41] more doxygen fixes --- cpp/src/arrow/record_batch.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index d49c6201418..33d2665a9ae 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -62,7 +62,9 @@ class ARROW_EXPORT RecordBatch { /// \param num_rows the number of semantic rows in the record batch. This /// should be equal to the length of each field /// \param columns the data for the batch's columns - /// \param[in] sync_event optional synchronization event for non-CPU device + /// \param device_type the type of the device that the Arrow columns are + /// allocated on + /// \param sync_event optional synchronization event for non-CPU device /// memory used by buffers static std::shared_ptr Make( std::shared_ptr schema, int64_t num_rows, @@ -404,6 +406,7 @@ class ARROW_EXPORT RecordBatchReader { /// \param[in] batches the vector of RecordBatch to read from /// \param[in] schema schema to conform to. Will be inferred from the first /// element if not provided. + /// \param[in] device_type the type of device that the batches are allocated on static Result> Make( RecordBatchVector batches, std::shared_ptr schema = NULLPTR, DeviceAllocationType device_type = DeviceAllocationType::kCPU); @@ -412,6 +415,7 @@ class ARROW_EXPORT RecordBatchReader { /// /// \param[in] batches an iterator of RecordBatch to read from. /// \param[in] schema schema that each record batch in iterator will conform to. + /// \param[in] device_type the type of device that the batches are allocated on static Result> MakeFromIterator( Iterator> batches, std::shared_ptr schema, DeviceAllocationType device_type = DeviceAllocationType::kCPU); From aeff09f83862edcc47decae3a2e62dfd2a64a202 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 26 Apr 2024 11:19:02 -0400 Subject: [PATCH 22/41] updates from feeback --- cpp/src/arrow/c/bridge.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index b3d9823572e..d4f36317f7f 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2044,7 +2044,11 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i } } -Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t, +// the int64_t i input here is unused, but exists simply to allow utilizing the +// overload of this with the version for ChunkedArrays. If we removed the int64_t +// from the signature despite it being unused, we wouldn't be able to leverage the +// overloading in the templated exporters. +Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t i, struct ArrowDeviceArray* out_array) { std::shared_ptr batch; RETURN_NOT_OK(src->ReadNext(&batch)); @@ -2069,15 +2073,13 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i, } Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t i, - std::shared_ptr sync, struct ArrowDeviceArray* out_array) { if (i >= src->num_chunks()) { // End of stream ArrowArrayMarkReleased(&out_array->array); return Status::OK(); } else { - return ExportDeviceArray(*src->chunk(static_cast(i)), std::move(sync), - out_array); + return ExportDeviceArray(*src->chunk(static_cast(i)), nullptr, out_array); } } @@ -2105,8 +2107,6 @@ class ExportedArrayStream { Status GetNext(ArrayType* out_array) { if constexpr (std::is_same_v) { return ExportStreamNext(reader(), next_batch_num(), out_array); - } else if constexpr (std::is_same_v) { - return ExportDeviceStreamNext(reader(), next_batch_num(), nullptr, out_array); } else { return ExportDeviceStreamNext(reader(), next_batch_num(), out_array); } From 35ef670cd79645a8d1c8ec55d7e0f650de123b36 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 26 Apr 2024 11:45:50 -0400 Subject: [PATCH 23/41] add comment about device types --- cpp/src/arrow/array/data.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 72cea0f142f..ada2668630e 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -225,6 +225,11 @@ int64_t ArrayData::ComputeLogicalNullCount() const { } DeviceAllocationType ArrayData::device_type() const { + // we're using 0 as a sentinel value for NOT YET ASSIGNED + // there is explicitly no constant DeviceAllocationType to represent + // the "UNASSIGNED" case as it is invalid for data to not have an + // assigned device type. If it's still 0 at the end, then we return + // CPU as the allocation device type int type = 0; for (const auto& buf : buffers) { if (!buf) continue; From 8443150700998bd8116a168357d9859c872f3360 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 26 Apr 2024 11:49:22 -0400 Subject: [PATCH 24/41] lint --- cpp/src/arrow/record_batch.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 33d2665a9ae..f2f92e6bdcf 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -62,7 +62,7 @@ class ARROW_EXPORT RecordBatch { /// \param num_rows the number of semantic rows in the record batch. This /// should be equal to the length of each field /// \param columns the data for the batch's columns - /// \param device_type the type of the device that the Arrow columns are + /// \param device_type the type of the device that the Arrow columns are /// allocated on /// \param sync_event optional synchronization event for non-CPU device /// memory used by buffers From 406ecf1ca887981b88ffa01de0d39e9c93b495b9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 29 Apr 2024 17:49:21 -0400 Subject: [PATCH 25/41] Update cpp/src/arrow/c/bridge.h Co-authored-by: Felipe Oliveira Carvalho --- cpp/src/arrow/c/bridge.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h index f2090ee8a6a..45367e4f930 100644 --- a/cpp/src/arrow/c/bridge.h +++ b/cpp/src/arrow/c/bridge.h @@ -333,7 +333,7 @@ ARROW_EXPORT Status ExportDeviceRecordBatchReader(std::shared_ptr reader, struct ArrowDeviceArrayStream* out); -/// \brief Export C++ ChunkedArray using the c device data interface format. +/// \brief Export C++ ChunkedArray using the C device data interface format. /// /// The resulting ArrowDeviceArrayStream keeps the chunked array data and buffers /// alive until its release callback is called by the consumer. From c22768aae1a7614e7e1e470927ead89cb1c15db9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 29 Apr 2024 17:54:21 -0400 Subject: [PATCH 26/41] update from feedback --- cpp/src/arrow/record_batch.cc | 8 ++++---- cpp/src/arrow/record_batch.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 84ebbd2f219..f6ebea1e11c 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -188,7 +188,7 @@ class SimpleRecordBatch : public RecordBatch { return RecordBatch::Validate(); } - std::shared_ptr GetSyncEvent() const override { return sync_event_; } + const std::shared_ptr& GetSyncEvent() const override { return sync_event_; } DeviceAllocationType device_type() const override { return device_type_; } @@ -213,7 +213,7 @@ std::shared_ptr RecordBatch::Make( std::shared_ptr sync_event) { DCHECK_EQ(schema->num_fields(), static_cast(columns.size())); return std::make_shared(std::move(schema), num_rows, - std::move(columns), sync_event); + std::move(columns), std::move(sync_event)); } std::shared_ptr RecordBatch::Make( @@ -222,7 +222,7 @@ std::shared_ptr RecordBatch::Make( std::shared_ptr sync_event) { DCHECK_EQ(schema->num_fields(), static_cast(columns.size())); return std::make_shared(std::move(schema), num_rows, - std::move(columns), device_type, sync_event); + std::move(columns), device_type, std::move(sync_event)); } Result> RecordBatch::MakeEmpty( @@ -664,7 +664,7 @@ Status RecordBatch::ValidateFull() const { return ValidateBatch(*this, /*full_validation=*/true); } -std::shared_ptr RecordBatch::GetSyncEvent() const { return nullptr; } +const std::shared_ptr& RecordBatch::GetSyncEvent() const { return nullptr; } DeviceAllocationType RecordBatch::device_type() const { return DeviceAllocationType::kCPU; diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index f2f92e6bdcf..49b40356dfe 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -278,7 +278,7 @@ class ARROW_EXPORT RecordBatch { /// data the returned sync event will allow for it. /// /// \return null or a Device::SyncEvent - virtual std::shared_ptr GetSyncEvent() const; + virtual const std::shared_ptr& GetSyncEvent() const; virtual DeviceAllocationType device_type() const; From bca54e248eb95467c8e4749feda57277ff66ec75 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 29 Apr 2024 17:57:21 -0400 Subject: [PATCH 27/41] linting --- cpp/src/arrow/record_batch.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index f6ebea1e11c..6dbe42a35d4 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -188,7 +188,9 @@ class SimpleRecordBatch : public RecordBatch { return RecordBatch::Validate(); } - const std::shared_ptr& GetSyncEvent() const override { return sync_event_; } + const std::shared_ptr& GetSyncEvent() const override { + return sync_event_; + } DeviceAllocationType device_type() const override { return device_type_; } @@ -222,7 +224,8 @@ std::shared_ptr RecordBatch::Make( std::shared_ptr sync_event) { DCHECK_EQ(schema->num_fields(), static_cast(columns.size())); return std::make_shared(std::move(schema), num_rows, - std::move(columns), device_type, std::move(sync_event)); + std::move(columns), device_type, + std::move(sync_event)); } Result> RecordBatch::MakeEmpty( @@ -664,7 +667,9 @@ Status RecordBatch::ValidateFull() const { return ValidateBatch(*this, /*full_validation=*/true); } -const std::shared_ptr& RecordBatch::GetSyncEvent() const { return nullptr; } +const std::shared_ptr& RecordBatch::GetSyncEvent() const { + return nullptr; +} DeviceAllocationType RecordBatch::device_type() const { return DeviceAllocationType::kCPU; From d5a6332cebbc4d20381e6e1ab4067c669833bc81 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 29 Apr 2024 18:19:56 -0400 Subject: [PATCH 28/41] use static const null_sync_event --- cpp/src/arrow/record_batch.cc | 4 +++- cpp/src/arrow/record_batch.h | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 6dbe42a35d4..e8513732599 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -668,13 +668,15 @@ Status RecordBatch::ValidateFull() const { } const std::shared_ptr& RecordBatch::GetSyncEvent() const { - return nullptr; + return null_sync_event_; } DeviceAllocationType RecordBatch::device_type() const { return DeviceAllocationType::kCPU; } +const std::shared_ptr RecordBatch::null_sync_event_{nullptr}; + // ---------------------------------------------------------------------- // Base record batch reader diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 49b40356dfe..09f82c0b794 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -290,6 +290,8 @@ class ARROW_EXPORT RecordBatch { private: ARROW_DISALLOW_COPY_AND_ASSIGN(RecordBatch); + + static const std::shared_ptr null_sync_event_; }; struct ARROW_EXPORT RecordBatchWithMetadata { From 71e152cd81852bac0124769d0ac6a002e4366957 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 30 Apr 2024 15:11:17 -0400 Subject: [PATCH 29/41] clean up templates --- cpp/src/arrow/c/bridge.cc | 95 +++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 44 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index d4f36317f7f..aac5689a3ca 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2083,8 +2083,16 @@ Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t } } -template +template class ExportedArrayStream { + using StreamTraits = + std::conditional_t; + using StreamType = typename StreamTraits::CType; + using ArrayTraits = std::conditional_t; + using ArrayType = typename ArrayTraits::CType; + public: struct PrivateData { explicit PrivateData(std::shared_ptr reader) @@ -2118,23 +2126,14 @@ class ExportedArrayStream { } void Release() { - if constexpr (std::is_same_v) { - if (ArrowDeviceArrayStreamIsReleased(stream_)) { - return; - } - } else { - if (ArrowArrayStreamIsReleased(stream_)) { - return; - } + if (StreamTraits::IsReleasedFunc(stream_)) { + return; } + DCHECK_NE(private_data(), nullptr); delete private_data(); - if constexpr (std::is_same_v) { - ArrowDeviceArrayStreamMarkReleased(stream_); - } else { - ArrowArrayStreamMarkReleased(stream_); - } + StreamTraits::MarkReleased(stream_); } // C-compatible callbacks @@ -2199,23 +2198,20 @@ class ExportedArrayStream { Status ExportRecordBatchReader(std::shared_ptr reader, struct ArrowArrayStream* out) { memset(out, 0, sizeof(struct ArrowArrayStream)); - return ExportedArrayStream::Make(std::move(reader), out); + return ExportedArrayStream::Make(std::move(reader), out); } Status ExportChunkedArray(std::shared_ptr chunked_array, struct ArrowArrayStream* out) { memset(out, 0, sizeof(struct ArrowArrayStream)); - return ExportedArrayStream::Make(std::move(chunked_array), out); + return ExportedArrayStream::Make(std::move(chunked_array), out); } Status ExportDeviceRecordBatchReader(std::shared_ptr reader, struct ArrowDeviceArrayStream* out) { memset(out, 0, sizeof(struct ArrowDeviceArrayStream)); out->device_type = static_cast(reader->device_type()); - return ExportedArrayStream::Make(std::move(reader), out); + return ExportedArrayStream::Make(std::move(reader), out); } Status ExportDeviceChunkedArray(std::shared_ptr chunked_array, @@ -2223,9 +2219,7 @@ Status ExportDeviceChunkedArray(std::shared_ptr chunked_array, struct ArrowDeviceArrayStream* out) { memset(out, 0, sizeof(struct ArrowDeviceArrayStream)); out->device_type = static_cast(device_type); - return ExportedArrayStream::Make(std::move(chunked_array), - out); + return ExportedArrayStream::Make(std::move(chunked_array), out); } ////////////////////////////////////////////////////////////////////////// @@ -2233,10 +2227,15 @@ Status ExportDeviceChunkedArray(std::shared_ptr chunked_array, namespace { -template +template class ArrayStreamReader { protected: + using StreamTraits = + std::conditional_t; using StreamType = typename StreamTraits::CType; + using ArrayTraits = std::conditional_t; using ArrayType = typename ArrayTraits::CType; public: @@ -2353,16 +2352,21 @@ class ArrayStreamReader { const DeviceMemoryMapper mapper_; }; -template +template class ArrayStreamBatchReader : public RecordBatchReader, - public ArrayStreamReader { + public ArrayStreamReader { + using StreamTraits = + std::conditional_t; using StreamType = typename StreamTraits::CType; + using ArrayTraits = std::conditional_t; using ArrayType = typename ArrayTraits::CType; public: explicit ArrayStreamBatchReader( StreamType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) - : ArrayStreamReader(stream, mapper) {} + : ArrayStreamReader(stream, mapper) {} Status Init() { ARROW_ASSIGN_OR_RAISE(schema_, this->ReadSchema()); @@ -2397,15 +2401,20 @@ class ArrayStreamBatchReader : public RecordBatchReader, std::shared_ptr schema_; }; -template -class ArrayStreamArrayReader : public ArrayStreamReader { +template +class ArrayStreamArrayReader : public ArrayStreamReader { + using StreamTraits = + std::conditional_t; using StreamType = typename StreamTraits::CType; + using ArrayTraits = std::conditional_t; using ArrayType = typename ArrayTraits::CType; public: explicit ArrayStreamArrayReader( StreamType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) - : ArrayStreamReader(stream, mapper) {} + : ArrayStreamReader(stream, mapper) {} Status Init() { ARROW_ASSIGN_OR_RAISE(field_, this->ReadField()); @@ -2433,7 +2442,9 @@ class ArrayStreamArrayReader : public ArrayStreamReader field_; }; -template +template > Result> ImportReader( typename StreamTraits::CType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) { @@ -2441,13 +2452,14 @@ Result> ImportReader( return Status::Invalid("Cannot import released Arrow Stream"); } - auto reader = - std::make_shared>(stream, mapper); + auto reader = std::make_shared>(stream, mapper); ARROW_RETURN_NOT_OK(reader->Init()); return reader; } -template +template > Result> ImportChunked( typename StreamTraits::CType* stream, const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) { @@ -2455,8 +2467,7 @@ Result> ImportChunked( return Status::Invalid("Cannot import released Arrow Stream"); } - auto reader = - std::make_shared>(stream, mapper); + auto reader = std::make_shared>(stream, mapper); ARROW_RETURN_NOT_OK(reader->Init()); auto data_type = reader->data_type(); @@ -2479,26 +2490,22 @@ Result> ImportChunked( Result> ImportRecordBatchReader( struct ArrowArrayStream* stream) { - return ImportReader( - stream); + return ImportReader(stream); } Result> ImportDeviceRecordBatchReader( struct ArrowDeviceArrayStream* stream, const DeviceMemoryMapper& mapper) { - return ImportReader(stream, mapper); + return ImportReader(stream, mapper); } Result> ImportChunkedArray( struct ArrowArrayStream* stream) { - return ImportChunked( - stream); + return ImportChunked(stream); } Result> ImportDeviceChunkedArray( struct ArrowDeviceArrayStream* stream, const DeviceMemoryMapper& mapper) { - return ImportChunked(stream, mapper); + return ImportChunked(stream, mapper); } } // namespace arrow From c9e63fd3e1bf544416cf822674d37ee0c797cb05 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 30 Apr 2024 16:56:30 -0400 Subject: [PATCH 30/41] use overload instead of if constexpr --- cpp/src/arrow/c/bridge.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index aac5689a3ca..57d9a943470 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2048,7 +2048,7 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i // overload of this with the version for ChunkedArrays. If we removed the int64_t // from the signature despite it being unused, we wouldn't be able to leverage the // overloading in the templated exporters. -Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t i, +Status ExportStreamNext(const std::shared_ptr& src, int64_t i, struct ArrowDeviceArray* out_array) { std::shared_ptr batch; RETURN_NOT_OK(src->ReadNext(&batch)); @@ -2072,7 +2072,7 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i, } } -Status ExportDeviceStreamNext(const std::shared_ptr& src, int64_t i, +Status ExportStreamNext(const std::shared_ptr& src, int64_t i, struct ArrowDeviceArray* out_array) { if (i >= src->num_chunks()) { // End of stream @@ -2113,11 +2113,7 @@ class ExportedArrayStream { } Status GetNext(ArrayType* out_array) { - if constexpr (std::is_same_v) { - return ExportStreamNext(reader(), next_batch_num(), out_array); - } else { - return ExportDeviceStreamNext(reader(), next_batch_num(), out_array); - } + return ExportStreamNext(reader(), next_batch_num(), out_array); } const char* GetLastError() { From df4062510f25e8c3b72c4da9b819e08c54d4ff0f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 30 Apr 2024 17:39:58 -0400 Subject: [PATCH 31/41] linting --- cpp/src/arrow/c/bridge.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 57d9a943470..25fe4b376b3 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2049,7 +2049,7 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i // from the signature despite it being unused, we wouldn't be able to leverage the // overloading in the templated exporters. Status ExportStreamNext(const std::shared_ptr& src, int64_t i, - struct ArrowDeviceArray* out_array) { + struct ArrowDeviceArray* out_array) { std::shared_ptr batch; RETURN_NOT_OK(src->ReadNext(&batch)); if (batch == nullptr) { @@ -2073,7 +2073,7 @@ Status ExportStreamNext(const std::shared_ptr& src, int64_t i, } Status ExportStreamNext(const std::shared_ptr& src, int64_t i, - struct ArrowDeviceArray* out_array) { + struct ArrowDeviceArray* out_array) { if (i >= src->num_chunks()) { // End of stream ArrowArrayMarkReleased(&out_array->array); From 9284893aa7195fd8cfa9df545c080f6aa7028f92 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 1 May 2024 10:33:34 -0400 Subject: [PATCH 32/41] Update cpp/src/arrow/c/bridge.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/c/bridge.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 25fe4b376b3..5e58b536e6c 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2255,7 +2255,7 @@ class ArrayStreamReader { Status ReadNextArrayInternal(ArrayType* array) { ArrayTraits::MarkReleased(array); Status status = StatusFromCError(stream_.get_next(&stream_, array)); - if (!status.ok() && !ArrayTraits::IsReleasedFunc(array)) { + if (!status.ok()) { ArrayTraits::ReleaseFunc(array); } From 915b7233b6d4c8bc9c6f2b8c33906364ea0a86ce Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 1 May 2024 10:33:56 -0400 Subject: [PATCH 33/41] Update cpp/src/arrow/c/bridge.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/c/bridge.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 5e58b536e6c..8c5e3637b6e 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -2336,7 +2336,7 @@ class ArrayStreamReader { } DeviceAllocationType get_device_type() const { - if constexpr (std::is_same_v) { + if constexpr (IsDevice) { return static_cast(stream_.device_type); } else { return DeviceAllocationType::kCPU; From 46f436ea45c46d10cafabe8ce2781a6838f0b37f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 1 May 2024 11:32:01 -0400 Subject: [PATCH 34/41] Update cpp/src/arrow/c/bridge_test.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/c/bridge_test.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 6f435b5bb9c..56d6c65427c 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -5208,9 +5208,7 @@ TEST_F(TestArrayDeviceStreamRoundtrip, Errors) { Status::Invalid("roundtrip error example")); Roundtrip(std::move(reader), [&](const std::shared_ptr& reader) { - auto status = reader->Next().status(); - ASSERT_RAISES(Invalid, status); - ASSERT_THAT(status.message(), ::testing::HasSubstr("roundtrip error example")); + EXPECT_THAT(reader->Next(), Raises(StatusCode::Invalid, ::testing::HasSubstr("roundtrip error example"))); }); } From 0a07beefe5d16e440902187f16aacf79f3b1b3d1 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 1 May 2024 11:50:16 -0400 Subject: [PATCH 35/41] Update cpp/src/arrow/c/bridge_test.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/c/bridge_test.cc | 42 +++++++++++++--------------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index 56d6c65427c..edf038bf5bb 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -5213,34 +5213,24 @@ TEST_F(TestArrayDeviceStreamRoundtrip, Errors) { } TEST_F(TestArrayDeviceStreamRoundtrip, SchemaError) { - struct StreamState { - bool released = false; - - static const char* GetLastError(struct ArrowDeviceArrayStream* stream) { - return "Expected error"; - } - - static int GetSchema(struct ArrowDeviceArrayStream* stream, + struct ArrowDeviceArrayStream stream = {}; + stream.get_last_error = [](struct ArrowDeviceArrayStream* stream) { + return "Expected error"; + }; + stream.get_schema = [](struct ArrowDeviceArrayStream* stream, struct ArrowSchema* schema) { - return EIO; - } - - static int GetNext(struct ArrowDeviceArrayStream* stream, + return EIO; + }; + stream.get_next = [](struct ArrowDeviceArrayStream* stream, struct ArrowDeviceArray* array) { - return EINVAL; - } - - static void Release(struct ArrowDeviceArrayStream* stream) { - reinterpret_cast(stream->private_data)->released = true; - std::memset(stream, 0, sizeof(*stream)); - } - } state; - struct ArrowDeviceArrayStream stream = {}; - stream.get_last_error = &StreamState::GetLastError; - stream.get_schema = &StreamState::GetSchema; - stream.get_next = &StreamState::GetNext; - stream.release = &StreamState::Release; - stream.private_data = &state; + return EINVAL; + }; + stream.release = [](struct ArrowDeviceArrayStream* stream) { + *static_cast(stream->private_data) = true; + std::memset(stream, 0, sizeof(*stream)); + }; + bool released = false; + stream.private_data = &released; EXPECT_RAISES_WITH_MESSAGE_THAT(IOError, ::testing::HasSubstr("Expected error"), ImportDeviceRecordBatchReader(&stream)); From 7b89fbfec9b1e850edb61aae1956908915b85c4a Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 1 May 2024 11:52:54 -0400 Subject: [PATCH 36/41] updates from feedback and lint --- cpp/src/arrow/array/array_base.h | 6 ++++++ cpp/src/arrow/array/data.h | 8 ++++++++ cpp/src/arrow/c/bridge_test.cc | 13 +++++-------- cpp/src/arrow/record_batch.cc | 10 ---------- cpp/src/arrow/record_batch.h | 6 ++---- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/cpp/src/arrow/array/array_base.h b/cpp/src/arrow/array/array_base.h index 6a7ee492e40..716ae072206 100644 --- a/cpp/src/arrow/array/array_base.h +++ b/cpp/src/arrow/array/array_base.h @@ -224,6 +224,12 @@ class ARROW_EXPORT Array { /// \return Status Status ValidateFull() const; + /// \brief Return the device_type that this array's data is allocated on + /// + /// This just delegates to calling device_type on the underlying ArrayData + /// object which backs this Array. + /// + /// \return DeviceAllocationType DeviceAllocationType device_type() const { return data_->device_type(); } protected: diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index 55b8e7c9049..7689dbd15b6 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -358,6 +358,14 @@ struct ARROW_EXPORT ArrayData { /// \see GetNullCount int64_t ComputeLogicalNullCount() const; + /// \brief Returns the device_type of the underlying buffers and children + /// + /// If there are no buffers in this ArrayData object, it just returns + /// DeviceAllocationType::kCPU as a default. We also assume that all buffers + /// should be allocated on the same device type and perform DCHECKs to confirm + /// this in debug mode. + /// + /// \return DeviceAllocationType DeviceAllocationType device_type() const; std::shared_ptr type; diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index edf038bf5bb..0ecfb5a9577 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -5208,7 +5208,8 @@ TEST_F(TestArrayDeviceStreamRoundtrip, Errors) { Status::Invalid("roundtrip error example")); Roundtrip(std::move(reader), [&](const std::shared_ptr& reader) { - EXPECT_THAT(reader->Next(), Raises(StatusCode::Invalid, ::testing::HasSubstr("roundtrip error example"))); + EXPECT_THAT(reader->Next(), Raises(StatusCode::Invalid, + ::testing::HasSubstr("roundtrip error example"))); }); } @@ -5218,13 +5219,9 @@ TEST_F(TestArrayDeviceStreamRoundtrip, SchemaError) { return "Expected error"; }; stream.get_schema = [](struct ArrowDeviceArrayStream* stream, - struct ArrowSchema* schema) { - return EIO; - }; + struct ArrowSchema* schema) { return EIO; }; stream.get_next = [](struct ArrowDeviceArrayStream* stream, - struct ArrowDeviceArray* array) { - return EINVAL; - }; + struct ArrowDeviceArray* array) { return EINVAL; }; stream.release = [](struct ArrowDeviceArrayStream* stream) { *static_cast(stream->private_data) = true; std::memset(stream, 0, sizeof(*stream)); @@ -5234,7 +5231,7 @@ TEST_F(TestArrayDeviceStreamRoundtrip, SchemaError) { EXPECT_RAISES_WITH_MESSAGE_THAT(IOError, ::testing::HasSubstr("Expected error"), ImportDeviceRecordBatchReader(&stream)); - ASSERT_TRUE(state.released); + ASSERT_TRUE(released); } TEST_F(TestArrayDeviceStreamRoundtrip, ChunkedArrayRoundtrip) { diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index e8513732599..351f72f5236 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -667,16 +667,6 @@ Status RecordBatch::ValidateFull() const { return ValidateBatch(*this, /*full_validation=*/true); } -const std::shared_ptr& RecordBatch::GetSyncEvent() const { - return null_sync_event_; -} - -DeviceAllocationType RecordBatch::device_type() const { - return DeviceAllocationType::kCPU; -} - -const std::shared_ptr RecordBatch::null_sync_event_{nullptr}; - // ---------------------------------------------------------------------- // Base record batch reader diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 09f82c0b794..b03cbf2251f 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -278,9 +278,9 @@ class ARROW_EXPORT RecordBatch { /// data the returned sync event will allow for it. /// /// \return null or a Device::SyncEvent - virtual const std::shared_ptr& GetSyncEvent() const; + virtual const std::shared_ptr& GetSyncEvent() const = 0; - virtual DeviceAllocationType device_type() const; + virtual DeviceAllocationType device_type() const = 0; protected: RecordBatch(const std::shared_ptr& schema, int64_t num_rows); @@ -290,8 +290,6 @@ class ARROW_EXPORT RecordBatch { private: ARROW_DISALLOW_COPY_AND_ASSIGN(RecordBatch); - - static const std::shared_ptr null_sync_event_; }; struct ARROW_EXPORT RecordBatchWithMetadata { From 4915884e804a924e133e47391d2ffa8d85c5cb9f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 1 May 2024 15:14:22 -0400 Subject: [PATCH 37/41] fix the python failure --- cpp/src/arrow/array/util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index bdba92c9a11..fc4910651a4 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -547,8 +547,8 @@ class NullArrayFactory { return Status::OK(); } - Status Visit(const StructType& type) { - for (int i = 0; i < type_->num_fields(); ++i) { + Status Visit(const StructType& type) { + for (int i = 0; i < type.num_fields(); ++i) { ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, length_)); } return Status::OK(); From 3dc2d78d6edb0b37a6680d143227f708a19a20e4 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 1 May 2024 15:14:51 -0400 Subject: [PATCH 38/41] lint --- cpp/src/arrow/array/util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index fc4910651a4..41cd6a1c0b2 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -547,7 +547,7 @@ class NullArrayFactory { return Status::OK(); } - Status Visit(const StructType& type) { + Status Visit(const StructType& type) { for (int i = 0; i < type.num_fields(); ++i) { ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, length_)); } From 331ce10032f59c77ffecef2a2a9c12a1e16cdb25 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 6 May 2024 11:28:16 -0400 Subject: [PATCH 39/41] add debug check to confirm device_type in ArrayData constructor --- cpp/src/arrow/array/data.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index 7689dbd15b6..ff5702f3a7b 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -101,6 +101,11 @@ struct ARROW_EXPORT ArrayData { int64_t null_count = kUnknownNullCount, int64_t offset = 0) : ArrayData(std::move(type), length, null_count, offset) { this->buffers = std::move(buffers); +#ifndef NDEBUG + // in debug mode, call the `device_type` function to trigger + // the DCHECKs that validate all the buffers are on the same device + ARROW_UNUSED(this->device_type()); +#endif } ArrayData(std::shared_ptr type, int64_t length, @@ -110,6 +115,12 @@ struct ARROW_EXPORT ArrayData { : ArrayData(std::move(type), length, null_count, offset) { this->buffers = std::move(buffers); this->child_data = std::move(child_data); +#ifndef NDEBUG + // in debug mode, call the `device_type` function to trigger + // the DCHECKs that validate all the buffers (including children) + // are on the same device + ARROW_UNUSED(this->device_type()); +#endif } static std::shared_ptr Make(std::shared_ptr type, int64_t length, From 9b14645f9c869004fe35345b669af925d2a3e01e Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 9 May 2024 14:41:21 -0400 Subject: [PATCH 40/41] fix lint --- cpp/src/arrow/array/data.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index ff5702f3a7b..0c49f36229a 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -117,7 +117,7 @@ struct ARROW_EXPORT ArrayData { this->child_data = std::move(child_data); #ifndef NDEBUG // in debug mode, call the `device_type` function to trigger - // the DCHECKs that validate all the buffers (including children) + // the DCHECKs that validate all the buffers (including children) // are on the same device ARROW_UNUSED(this->device_type()); #endif From 381bc7d1d83940ecb9a2d4c68b04b450ede1c8d1 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 9 May 2024 15:17:01 -0400 Subject: [PATCH 41/41] add a null check for children in device_type --- cpp/src/arrow/array/data.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index ada2668630e..76a43521394 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -241,6 +241,7 @@ DeviceAllocationType ArrayData::device_type() const { } for (const auto& child : child_data) { + if (!child) continue; if (type == 0) { type = static_cast(child->device_type()); } else {