diff --git a/cpp/src/jni/dataset/jni_util.cc b/cpp/src/jni/dataset/jni_util.cc index 113669a4cf6..50613088ecb 100644 --- a/cpp/src/jni/dataset/jni_util.cc +++ b/cpp/src/jni/dataset/jni_util.cc @@ -16,18 +16,26 @@ // under the License. #include "jni/dataset/jni_util.h" - +#include "arrow/ipc/metadata_internal.h" +#include "arrow/util/base64.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" +#include #include +#include + namespace arrow { + +namespace flatbuf = org::apache::arrow::flatbuf; + namespace dataset { namespace jni { class ReservationListenableMemoryPool::Impl { public: - explicit Impl(arrow::MemoryPool* pool, std::shared_ptr listener, + explicit Impl(MemoryPool* pool, std::shared_ptr listener, int64_t block_size) : pool_(pool), listener_(listener), @@ -35,17 +43,17 @@ class ReservationListenableMemoryPool::Impl { blocks_reserved_(0), bytes_reserved_(0) {} - arrow::Status Allocate(int64_t size, uint8_t** out) { + Status Allocate(int64_t size, uint8_t** out) { RETURN_NOT_OK(UpdateReservation(size)); - arrow::Status error = pool_->Allocate(size, out); + Status error = pool_->Allocate(size, out); if (!error.ok()) { RETURN_NOT_OK(UpdateReservation(-size)); return error; } - return arrow::Status::OK(); + return Status::OK(); } - arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) { + Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) { bool reserved = false; int64_t diff = new_size - old_size; if (new_size >= old_size) { @@ -54,7 +62,7 @@ class ReservationListenableMemoryPool::Impl { RETURN_NOT_OK(UpdateReservation(diff)); reserved = true; } - arrow::Status error = pool_->Reallocate(old_size, new_size, ptr); + Status error = pool_->Reallocate(old_size, new_size, ptr); if (!error.ok()) { if (reserved) { // roll back reservations on error @@ -66,13 +74,13 @@ class ReservationListenableMemoryPool::Impl { // otherwise (e.g. new_size < old_size), make updates after calling underlying pool RETURN_NOT_OK(UpdateReservation(diff)); } - return arrow::Status::OK(); + return Status::OK(); } void Free(uint8_t* buffer, int64_t size) { pool_->Free(buffer, size); // FIXME: See ARROW-11143, currently method ::Free doesn't allow Status return - arrow::Status s = UpdateReservation(-size); + Status s = UpdateReservation(-size); if (!s.ok()) { ARROW_LOG(FATAL) << "Failed to update reservation while freeing bytes: " << s.message(); @@ -80,17 +88,17 @@ class ReservationListenableMemoryPool::Impl { } } - arrow::Status UpdateReservation(int64_t diff) { + Status UpdateReservation(int64_t diff) { int64_t granted = Reserve(diff); if (granted == 0) { - return arrow::Status::OK(); + return Status::OK(); } if (granted < 0) { RETURN_NOT_OK(listener_->OnRelease(-granted)); - return arrow::Status::OK(); + return Status::OK(); } RETURN_NOT_OK(listener_->OnReservation(granted)); - return arrow::Status::OK(); + return Status::OK(); } int64_t Reserve(int64_t diff) { @@ -117,7 +125,7 @@ class ReservationListenableMemoryPool::Impl { std::shared_ptr get_listener() { return listener_; } private: - arrow::MemoryPool* pool_; + MemoryPool* pool_; std::shared_ptr listener_; int64_t block_size_; int64_t blocks_reserved_; @@ -125,18 +133,40 @@ class ReservationListenableMemoryPool::Impl { std::mutex mutex_; }; +/// \brief Buffer implementation that binds to a +/// Java buffer reference. Java buffer's release +/// method will be called once when being destructed. 
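+///
+/// A minimal construction sketch (illustrative only; the cleaner object and its
+/// zero-argument release method are assumptions of this example, not part of the
+/// class contract):
+///
+///   jobject cleaner = env->NewGlobalRef(local_cleaner);
+///   jmethodID release = env->GetMethodID(env->GetObjectClass(cleaner), "run", "()V");
+///   auto buffer = std::make_shared<JavaAllocatedBuffer>(env, cleaner, release,
+///                                                       data, data_length);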
+class JavaAllocatedBuffer : public Buffer { + public: + JavaAllocatedBuffer(JNIEnv* env, jobject cleaner_ref, jmethodID cleaner_method_ref, + uint8_t* buffer, int32_t len) + : Buffer(buffer, len), + env_(env), + cleaner_ref_(cleaner_ref), + cleaner_method_ref_(cleaner_method_ref) {} + + ~JavaAllocatedBuffer() override { + env_->CallVoidMethod(cleaner_ref_, cleaner_method_ref_); + env_->DeleteGlobalRef(cleaner_ref_); + } + + private: + JNIEnv* env_; + jobject cleaner_ref_; + jmethodID cleaner_method_ref_; +}; + ReservationListenableMemoryPool::ReservationListenableMemoryPool( MemoryPool* pool, std::shared_ptr listener, int64_t block_size) { impl_.reset(new Impl(pool, listener, block_size)); } -arrow::Status ReservationListenableMemoryPool::Allocate(int64_t size, uint8_t** out) { +Status ReservationListenableMemoryPool::Allocate(int64_t size, uint8_t** out) { return impl_->Allocate(size, out); } -arrow::Status ReservationListenableMemoryPool::Reallocate(int64_t old_size, - int64_t new_size, - uint8_t** ptr) { +Status ReservationListenableMemoryPool::Reallocate(int64_t old_size, int64_t new_size, + uint8_t** ptr) { return impl_->Reallocate(old_size, new_size, ptr); } @@ -162,6 +192,15 @@ std::shared_ptr ReservationListenableMemoryPool::get_listen ReservationListenableMemoryPool::~ReservationListenableMemoryPool() {} +Status CheckException(JNIEnv* env) { + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + return Status::Invalid("Error during calling Java code from native code"); + } + return Status::OK(); +} + jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { jclass local_class = env->FindClass(class_name); jclass global_class = (jclass)env->NewGlobalRef(local_class); @@ -169,24 +208,24 @@ jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { return global_class; } -arrow::Result GetMethodID(JNIEnv* env, jclass this_class, const char* name, - const char* sig) { +Result GetMethodID(JNIEnv* env, jclass this_class, const char* name, + const char* sig) { jmethodID ret = env->GetMethodID(this_class, name, sig); if (ret == nullptr) { std::string error_message = "Unable to find method " + std::string(name) + " within signature" + std::string(sig); - return arrow::Status::Invalid(error_message); + return Status::Invalid(error_message); } return ret; } -arrow::Result GetStaticMethodID(JNIEnv* env, jclass this_class, - const char* name, const char* sig) { +Result GetStaticMethodID(JNIEnv* env, jclass this_class, const char* name, + const char* sig) { jmethodID ret = env->GetStaticMethodID(this_class, name, sig); if (ret == nullptr) { std::string error_message = "Unable to find static method " + std::string(name) + " within signature" + std::string(sig); - return arrow::Status::Invalid(error_message); + return Status::Invalid(error_message); } return ret; } @@ -211,11 +250,9 @@ std::vector ToStringVector(JNIEnv* env, jobjectArray& str_array) { return vector; } -arrow::Result ToSchemaByteArray(JNIEnv* env, - std::shared_ptr schema) { - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr buffer, - arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool())) +Result ToSchemaByteArray(JNIEnv* env, std::shared_ptr schema) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr buffer, + ipc::SerializeSchema(*schema, default_memory_pool())) jbyteArray out = env->NewByteArray(buffer->size()); auto src = reinterpret_cast(buffer->data()); @@ -223,20 +260,173 @@ arrow::Result ToSchemaByteArray(JNIEnv* env, return out; } -arrow::Result> FromSchemaByteArray( 
- JNIEnv* env, jbyteArray schemaBytes) { - arrow::ipc::DictionaryMemo in_memo; +Result> FromSchemaByteArray(JNIEnv* env, jbyteArray schemaBytes) { + ipc::DictionaryMemo in_memo; int schemaBytes_len = env->GetArrayLength(schemaBytes); jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr); - auto serialized_schema = std::make_shared( + auto serialized_schema = std::make_shared( reinterpret_cast(schemaBytes_data), schemaBytes_len); - arrow::io::BufferReader buf_reader(serialized_schema); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema, - arrow::ipc::ReadSchema(&buf_reader, &in_memo)) + io::BufferReader buf_reader(serialized_schema); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema, + ipc::ReadSchema(&buf_reader, &in_memo)) env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT); return schema; } +Status SetMetadataForSingleField(std::shared_ptr array_data, + std::vector& nodes_meta, + std::vector& buffers_meta, + std::shared_ptr& custom_metadata) { + nodes_meta.push_back({array_data->length, array_data->null_count, 0L}); + + for (size_t i = 0; i < array_data->buffers.size(); i++) { + auto buffer = array_data->buffers.at(i); + uint8_t* data = nullptr; + int64_t size = 0; + if (buffer != nullptr) { + data = (uint8_t*)buffer->data(); + size = buffer->size(); + } + ipc::internal::BufferMetadata buffer_metadata{}; + buffer_metadata.offset = reinterpret_cast(data); + buffer_metadata.length = size; + // store buffer refs into custom metadata + jlong ref = CreateNativeRef(buffer); + custom_metadata->Append( + "NATIVE_BUFFER_REF_" + std::to_string(i), + util::base64_encode(reinterpret_cast(&ref), sizeof(ref))); + buffers_meta.push_back(buffer_metadata); + } + + auto children_data = array_data->child_data; + for (const auto& child_data : children_data) { + RETURN_NOT_OK( + SetMetadataForSingleField(child_data, nodes_meta, buffers_meta, custom_metadata)); + } + return Status::OK(); +} + +Result> SerializeMetadata(const RecordBatch& batch, + const ipc::IpcWriteOptions& options) { + std::vector nodes; + std::vector buffers; + std::shared_ptr custom_metadata = + std::make_shared(); + for (const auto& column : batch.columns()) { + auto array_data = column->data(); + RETURN_NOT_OK(SetMetadataForSingleField(array_data, nodes, buffers, custom_metadata)); + } + std::shared_ptr meta_buffer; + RETURN_NOT_OK(ipc::internal::WriteRecordBatchMessage( + batch.num_rows(), 0L, custom_metadata, nodes, buffers, options, &meta_buffer)); + // no message body is needed for JNI serialization/deserialization + int32_t meta_length = -1; + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create(1024L)); + RETURN_NOT_OK(ipc::WriteMessage(*meta_buffer, options, stream.get(), &meta_length)); + return stream->Finish(); +} + +Result SerializeUnsafeFromNative(JNIEnv* env, + const std::shared_ptr& batch) { + ARROW_ASSIGN_OR_RAISE(auto meta_buffer, + SerializeMetadata(*batch, ipc::IpcWriteOptions::Defaults())); + + jbyteArray ret = env->NewByteArray(meta_buffer->size()); + auto src = reinterpret_cast(meta_buffer->data()); + env->SetByteArrayRegion(ret, 0, meta_buffer->size(), src); + return ret; +} + +Result> MakeArrayData( + JNIEnv* env, const flatbuf::RecordBatch& batch_meta, + const std::shared_ptr& custom_metadata, + const std::shared_ptr& type, int32_t* field_offset, + int32_t* buffer_offset) { + const org::apache::arrow::flatbuf::FieldNode* field = + batch_meta.nodes()->Get((*field_offset)++); + int32_t own_buffer_size = static_cast(type->layout().buffers.size()); + std::vector> buffers; 
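+  // Reconstruct this field's buffers from Java-managed memory: for flatbuffers buffer
+  // entry i, the address is carried in the offset slot and the size in the length slot,
+  // while custom metadata entries 2 * i and 2 * i + 1 carry the base64-encoded cleaner
+  // refs used to release the Java-side allocation when the native buffer is destructed.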
+  for (int32_t i = *buffer_offset; i < *buffer_offset + own_buffer_size; i++) {
+    const org::apache::arrow::flatbuf::Buffer* java_managed_buffer =
+        batch_meta.buffers()->Get(i);
+    const std::string& cleaner_object_ref_base64 =
+        util::base64_decode(custom_metadata->value(i * 2));
+    const std::string& cleaner_method_ref_base64 =
+        util::base64_decode(custom_metadata->value(i * 2 + 1));
+    const auto* cleaner_object_ref =
+        reinterpret_cast<const int64_t*>(cleaner_object_ref_base64.data());
+    const auto* cleaner_method_ref =
+        reinterpret_cast<const int64_t*>(cleaner_method_ref_base64.data());
+    auto buffer = std::make_shared<JavaAllocatedBuffer>(
+        env, reinterpret_cast<jobject>(*cleaner_object_ref),
+        reinterpret_cast<jmethodID>(*cleaner_method_ref),
+        reinterpret_cast<uint8_t*>(java_managed_buffer->offset()),
+        java_managed_buffer->length());
+    buffers.push_back(buffer);
+  }
+  (*buffer_offset) += own_buffer_size;
+  if (type->num_fields() == 0) {
+    return ArrayData::Make(type, field->length(), buffers, field->null_count());
+  }
+  std::vector<std::shared_ptr<ArrayData>> children_array_data;
+  for (const auto& child_field : type->fields()) {
+    ARROW_ASSIGN_OR_RAISE(auto child_array_data,
+                          MakeArrayData(env, batch_meta, custom_metadata,
+                                        child_field->type(), field_offset, buffer_offset))
+    children_array_data.push_back(child_array_data);
+  }
+  return ArrayData::Make(type, field->length(), buffers, children_array_data,
+                         field->null_count());
+}
+
+Result<std::shared_ptr<RecordBatch>> DeserializeUnsafeFromJava(
+    JNIEnv* env, std::shared_ptr<Schema> schema, jbyteArray byte_array) {
+  int bytes_len = env->GetArrayLength(byte_array);
+  jbyte* byte_data = env->GetByteArrayElements(byte_array, nullptr);
+  io::BufferReader meta_reader(reinterpret_cast<const uint8_t*>(byte_data),
+                               static_cast<int64_t>(bytes_len));
+  ARROW_ASSIGN_OR_RAISE(auto meta_message, ipc::ReadMessage(&meta_reader))
+  auto meta_buffer = meta_message->metadata();
+  auto custom_metadata = meta_message->custom_metadata();
+  const flatbuf::Message* flat_meta = nullptr;
+  RETURN_NOT_OK(
+      ipc::internal::VerifyMessage(meta_buffer->data(), meta_buffer->size(), &flat_meta));
+  auto batch_meta = flat_meta->header_as_RecordBatch();
+
+  // A record batch serialized from Java should carry two ref IDs per buffer: a cleaner
+  // object ref and a cleaner method ref. Both refs are 64-bit integers encoded in
+  // base64.
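+  //
+  // For example, a primitive field with buffers [validity, data] is expected to
+  // contribute four consecutive custom metadata entries:
+  //   value(2 * i)     = base64(cleaner object ref of buffer i)
+  //   value(2 * i + 1) = base64(cleaner method ref of buffer i)
+  // which is what the size check below verifies against the flatbuffers buffer count.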
+ if (custom_metadata->size() != + static_cast(batch_meta->buffers()->size() * 2)) { + return Status::SerializationError( + "Buffer count mismatch between metadata and Java managed refs"); + } + + std::vector> columns_array_data; + int32_t field_offset = 0; + int32_t buffer_offset = 0; + for (int32_t i = 0; i < schema->num_fields(); i++) { + auto field = schema->field(i); + ARROW_ASSIGN_OR_RAISE(auto column_array_data, + MakeArrayData(env, *batch_meta, custom_metadata, field->type(), + &field_offset, &buffer_offset)) + columns_array_data.push_back(column_array_data); + } + if (field_offset != static_cast(batch_meta->nodes()->size())) { + return Status::SerializationError( + "Deserialization failed: Field count is not " + "as expected based on type layout"); + } + if (buffer_offset != static_cast(batch_meta->buffers()->size())) { + return Status::SerializationError( + "Deserialization failed: Buffer count is not " + "as expected based on type layout"); + } + int64_t length = batch_meta->length(); + env->ReleaseByteArrayElements(byte_array, byte_data, JNI_ABORT); + return RecordBatch::Make(schema, length, columns_array_data); +} + } // namespace jni } // namespace dataset } // namespace arrow diff --git a/cpp/src/jni/dataset/jni_util.h b/cpp/src/jni/dataset/jni_util.h index c76033ae633..8c83d5e17a1 100644 --- a/cpp/src/jni/dataset/jni_util.h +++ b/cpp/src/jni/dataset/jni_util.h @@ -30,23 +30,33 @@ namespace arrow { namespace dataset { namespace jni { +Status CheckException(JNIEnv* env); + jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name); -arrow::Result GetMethodID(JNIEnv* env, jclass this_class, const char* name, - const char* sig); +Result GetMethodID(JNIEnv* env, jclass this_class, const char* name, + const char* sig); -arrow::Result GetStaticMethodID(JNIEnv* env, jclass this_class, - const char* name, const char* sig); +Result GetStaticMethodID(JNIEnv* env, jclass this_class, const char* name, + const char* sig); std::string JStringToCString(JNIEnv* env, jstring string); std::vector ToStringVector(JNIEnv* env, jobjectArray& str_array); -arrow::Result ToSchemaByteArray(JNIEnv* env, - std::shared_ptr schema); +Result ToSchemaByteArray(JNIEnv* env, std::shared_ptr schema); + +Result> FromSchemaByteArray(JNIEnv* env, jbyteArray schemaBytes); + +/// \brief Serialize arrow::RecordBatch to jbyteArray (Java byte array byte[]). For +/// letting Java code manage lifecycles of buffers in the input batch, shared pointer IDs +/// pointing to the buffers are serialized into buffer metadata. +Result SerializeUnsafeFromNative(JNIEnv* env, + const std::shared_ptr& batch); -arrow::Result> FromSchemaByteArray(JNIEnv* env, - jbyteArray schemaBytes); +/// \brief Deserialize jbyteArray (Java byte array byte[]) to arrow::RecordBatch. +Result> DeserializeUnsafeFromJava( + JNIEnv* env, std::shared_ptr schema, jbyteArray byte_array); /// \brief Create a new shared_ptr on heap from shared_ptr t to prevent /// the managed object from being garbage-collected. @@ -86,8 +96,8 @@ class ReservationListener { public: virtual ~ReservationListener() = default; - virtual arrow::Status OnReservation(int64_t size) = 0; - virtual arrow::Status OnRelease(int64_t size) = 0; + virtual Status OnReservation(int64_t size) = 0; + virtual Status OnRelease(int64_t size) = 0; protected: ReservationListener() = default; @@ -98,7 +108,7 @@ class ReservationListener { /// have to be subject to another "virtual" resource manager, which just tracks or /// limits number of bytes of application's overall memory usage. 
The underlying /// memory pool will still be responsible for actual malloc/free operations. -class ReservationListenableMemoryPool : public arrow::MemoryPool { +class ReservationListenableMemoryPool : public MemoryPool { public: /// \brief Constructor. /// @@ -111,9 +121,9 @@ class ReservationListenableMemoryPool : public arrow::MemoryPool { ~ReservationListenableMemoryPool(); - arrow::Status Allocate(int64_t size, uint8_t** out) override; + Status Allocate(int64_t size, uint8_t** out) override; - arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override; + Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override; void Free(uint8_t* buffer, int64_t size) override; diff --git a/cpp/src/jni/dataset/jni_wrapper.cc b/cpp/src/jni/dataset/jni_wrapper.cc index d61fb3f964e..8f939e0d74e 100644 --- a/cpp/src/jni/dataset/jni_wrapper.cc +++ b/cpp/src/jni/dataset/jni_wrapper.cc @@ -19,6 +19,7 @@ #include "arrow/array.h" #include "arrow/dataset/api.h" +#include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" #include "arrow/filesystem/localfs.h" #include "arrow/ipc/api.h" @@ -36,14 +37,11 @@ jclass illegal_access_exception_class; jclass illegal_argument_exception_class; jclass runtime_exception_class; -jclass record_batch_handle_class; -jclass record_batch_handle_field_class; -jclass record_batch_handle_buffer_class; +jclass serialized_record_batch_iterator_class; jclass java_reservation_listener_class; -jmethodID record_batch_handle_constructor; -jmethodID record_batch_handle_field_constructor; -jmethodID record_batch_handle_buffer_constructor; +jmethodID serialized_record_batch_iterator_hasNext; +jmethodID serialized_record_batch_iterator_next; jmethodID reserve_memory_method; jmethodID unreserve_memory_method; @@ -99,11 +97,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener { return arrow::Status::Invalid("JNIEnv was not attached to current thread"); } env->CallObjectMethod(java_reservation_listener_, reserve_memory_method, size); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); - return arrow::Status::Invalid("Error calling Java side reservation listener"); - } + RETURN_NOT_OK(arrow::dataset::jni::CheckException(env)); return arrow::Status::OK(); } @@ -113,11 +107,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener { return arrow::Status::Invalid("JNIEnv was not attached to current thread"); } env->CallObjectMethod(java_reservation_listener_, unreserve_memory_method, size); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); - return arrow::Status::Invalid("Error calling Java side reservation listener"); - } + RETURN_NOT_OK(arrow::dataset::jni::CheckException(env)); return arrow::Status::OK(); } @@ -166,6 +156,126 @@ class DisposableScannerAdaptor { } }; +/// \brief Simple scan task implementation that is constructed directly +/// from a record batch iterator (and its corresponding fragment). 
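+///
+/// Note that the task is single-use: Execute() moves the wrapped iterator out,
+/// so a second Execute() call returns Status::Invalid.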
+class SimpleIteratorTask : public arrow::dataset::ScanTask {
+ public:
+  SimpleIteratorTask(std::shared_ptr<arrow::dataset::ScanOptions> options,
+                     std::shared_ptr<arrow::dataset::Fragment> fragment,
+                     arrow::RecordBatchIterator itr)
+      : ScanTask(options, fragment) {
+    this->itr_ = std::move(itr);
+  }
+
+  static arrow::Result<std::shared_ptr<SimpleIteratorTask>> Make(
+      arrow::RecordBatchIterator itr,
+      std::shared_ptr<arrow::dataset::ScanOptions> options,
+      std::shared_ptr<arrow::dataset::Fragment> fragment) {
+    return std::make_shared<SimpleIteratorTask>(options, fragment, std::move(itr));
+  }
+
+  arrow::Result<arrow::RecordBatchIterator> Execute() override {
+    if (used_) {
+      return arrow::Status::Invalid(
+          "SimpleIteratorTask is disposable and "
+          "already scanned");
+    }
+    used_ = true;
+    return std::move(itr_);
+  }
+
+ private:
+  arrow::RecordBatchIterator itr_;
+  bool used_ = false;
+};
+
+/// \brief Simple fragment implementation that is constructed directly
+/// from a record batch iterator.
+class SimpleIteratorFragment : public arrow::dataset::Fragment {
+ public:
+  explicit SimpleIteratorFragment(arrow::RecordBatchIterator itr)
+      : arrow::dataset::Fragment() {
+    itr_ = std::move(itr);
+  }
+
+  static arrow::Result<std::shared_ptr<SimpleIteratorFragment>> Make(
+      arrow::RecordBatchIterator itr) {
+    return std::make_shared<SimpleIteratorFragment>(std::move(itr));
+  }
+
+  arrow::Result<arrow::dataset::RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<arrow::dataset::ScanOptions>& options) override {
+    return arrow::Status::NotImplemented("Async scan not supported");
+  }
+
+  arrow::Result<arrow::dataset::ScanTaskIterator> Scan(
+      std::shared_ptr<arrow::dataset::ScanOptions> options) override {
+    if (used_) {
+      return arrow::Status::Invalid(
+          "SimpleIteratorFragment is disposable and "
+          "already scanned");
+    }
+    used_ = true;
+    ARROW_ASSIGN_OR_RAISE(
+        auto task, SimpleIteratorTask::Make(std::move(itr_), options, shared_from_this()))
+    return arrow::MakeVectorIterator<std::shared_ptr<arrow::dataset::ScanTask>>({task});
+  }
+
+  std::string type_name() const override { return "simple_iterator"; }
+
+  arrow::Result<std::shared_ptr<arrow::Schema>> ReadPhysicalSchemaImpl() override {
+    return arrow::Status::NotImplemented("No physical schema is readable");
+  }
+
+ private:
+  arrow::RecordBatchIterator itr_;
+  bool used_ = false;
+};
+
+arrow::Result<std::shared_ptr<arrow::RecordBatch>> FromBytes(
+    JNIEnv* env, std::shared_ptr<arrow::Schema> schema, jbyteArray bytes) {
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<arrow::RecordBatch> batch,
+      arrow::dataset::jni::DeserializeUnsafeFromJava(env, schema, bytes))
+  return batch;
+}
+
+/// \brief Create a scanner that scans over the Java dataset API's components.
+///
+/// Currently we use a NativeSerializedRecordBatchIterator as the underlying
+/// Java object to do the scanning, which means only a single task will
+/// be produced from C++ code.
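+///
+/// The Java-side iterator is expected to expose the two methods resolved in
+/// JNI_OnLoad below:
+///
+///   boolean hasNext();  // JNI signature "()Z"
+///   byte[] next();      // JNI signature "()[B", one serialized record batch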
+arrow::Result> MakeJavaDatasetScanner( + JavaVM* vm, jobject java_serialized_record_batch_iterator, + std::shared_ptr schema) { + arrow::RecordBatchIterator itr = arrow::MakeFunctionIterator( + [vm, java_serialized_record_batch_iterator, + schema]() -> arrow::Result> { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return arrow::Status::Invalid("JNIEnv was not attached to current thread"); + } + if (!env->CallBooleanMethod(java_serialized_record_batch_iterator, + serialized_record_batch_iterator_hasNext)) { + return nullptr; // stream ended + } + auto bytes = (jbyteArray)env->CallObjectMethod( + java_serialized_record_batch_iterator, serialized_record_batch_iterator_next); + RETURN_NOT_OK(arrow::dataset::jni::CheckException(env)); + ARROW_ASSIGN_OR_RAISE(auto batch, FromBytes(env, schema, bytes)); + return batch; + }); + + ARROW_ASSIGN_OR_RAISE(auto fragment, SimpleIteratorFragment::Make(std::move(itr))) + + arrow::dataset::ScannerBuilder scanner_builder( + std::move(schema), fragment, std::make_shared()); + // Use default memory pool is enough as native allocation is ideally + // not being called during scanning Java-based fragments. + RETURN_NOT_OK(scanner_builder.Pool(arrow::default_memory_pool())); + return scanner_builder.Finish(); +} + } // namespace using arrow::dataset::jni::CreateGlobalClassReference; @@ -205,33 +315,18 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { runtime_exception_class = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); - record_batch_handle_class = - CreateGlobalClassReference(env, - "Lorg/apache/arrow/" - "dataset/jni/NativeRecordBatchHandle;"); - record_batch_handle_field_class = - CreateGlobalClassReference(env, - "Lorg/apache/arrow/" - "dataset/jni/NativeRecordBatchHandle$Field;"); - record_batch_handle_buffer_class = + serialized_record_batch_iterator_class = CreateGlobalClassReference(env, "Lorg/apache/arrow/" - "dataset/jni/NativeRecordBatchHandle$Buffer;"); + "dataset/jni/NativeSerializedRecordBatchIterator;"); java_reservation_listener_class = CreateGlobalClassReference(env, "Lorg/apache/arrow/" "dataset/jni/ReservationListener;"); - - record_batch_handle_constructor = - JniGetOrThrow(GetMethodID(env, record_batch_handle_class, "", - "(J[Lorg/apache/arrow/dataset/" - "jni/NativeRecordBatchHandle$Field;" - "[Lorg/apache/arrow/dataset/" - "jni/NativeRecordBatchHandle$Buffer;)V")); - record_batch_handle_field_constructor = - JniGetOrThrow(GetMethodID(env, record_batch_handle_field_class, "", "(JJ)V")); - record_batch_handle_buffer_constructor = JniGetOrThrow( - GetMethodID(env, record_batch_handle_buffer_class, "", "(JJJJ)V")); + serialized_record_batch_iterator_hasNext = JniGetOrThrow( + GetMethodID(env, serialized_record_batch_iterator_class, "hasNext", "()Z")); + serialized_record_batch_iterator_next = JniGetOrThrow( + GetMethodID(env, serialized_record_batch_iterator_class, "next", "()[B")); reserve_memory_method = JniGetOrThrow(GetMethodID(env, java_reservation_listener_class, "reserve", "(J)V")); unreserve_memory_method = JniGetOrThrow( @@ -249,9 +344,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { env->DeleteGlobalRef(illegal_access_exception_class); env->DeleteGlobalRef(illegal_argument_exception_class); env->DeleteGlobalRef(runtime_exception_class); - env->DeleteGlobalRef(record_batch_handle_class); - env->DeleteGlobalRef(record_batch_handle_field_class); - env->DeleteGlobalRef(record_batch_handle_buffer_class); + env->DeleteGlobalRef(serialized_record_batch_iterator_class); 
env->DeleteGlobalRef(java_reservation_listener_class); default_memory_pool_id = -1L; @@ -458,9 +551,9 @@ Java_org_apache_arrow_dataset_jni_JniWrapper_getSchemaFromScanner(JNIEnv* env, j /* * Class: org_apache_arrow_dataset_jni_JniWrapper * Method: nextRecordBatch - * Signature: (J)Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle; + * Signature: (J)[B */ -JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch( +JNIEXPORT jbyteArray JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch( JNIEnv* env, jobject, jlong scanner_id) { JNI_METHOD_START std::shared_ptr scanner_adaptor = @@ -471,46 +564,7 @@ JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecor if (record_batch == nullptr) { return nullptr; // stream ended } - std::shared_ptr schema = record_batch->schema(); - jobjectArray field_array = - env->NewObjectArray(schema->num_fields(), record_batch_handle_field_class, nullptr); - - std::vector> buffers; - for (int i = 0; i < schema->num_fields(); ++i) { - auto column = record_batch->column(i); - auto dataArray = column->data(); - jobject field = env->NewObject(record_batch_handle_field_class, - record_batch_handle_field_constructor, - column->length(), column->null_count()); - env->SetObjectArrayElement(field_array, i, field); - - for (auto& buffer : dataArray->buffers) { - buffers.push_back(buffer); - } - } - - jobjectArray buffer_array = - env->NewObjectArray(buffers.size(), record_batch_handle_buffer_class, nullptr); - - for (size_t j = 0; j < buffers.size(); ++j) { - auto buffer = buffers[j]; - uint8_t* data = nullptr; - int64_t size = 0; - int64_t capacity = 0; - if (buffer != nullptr) { - data = (uint8_t*)buffer->data(); - size = buffer->size(); - capacity = buffer->capacity(); - } - jobject buffer_handle = env->NewObject(record_batch_handle_buffer_class, - record_batch_handle_buffer_constructor, - CreateNativeRef(buffer), data, size, capacity); - env->SetObjectArrayElement(buffer_array, j, buffer_handle); - } - - jobject ret = env->NewObject(record_batch_handle_class, record_batch_handle_constructor, - record_batch->num_rows(), field_array, buffer_array); - return ret; + return JniGetOrThrow(arrow::dataset::jni::SerializeUnsafeFromNative(env, record_batch)); JNI_METHOD_END(nullptr) } @@ -526,6 +580,38 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_releaseBuffe JNI_METHOD_END() } +/* + * Class: org_apache_arrow_dataset_file_JniWrapper + * Method: newJniGlobalReference + * Signature: (Ljava/lang/Object;)J + */ +JNIEXPORT jlong JNICALL +Java_org_apache_arrow_dataset_file_JniWrapper_newJniGlobalReference(JNIEnv* env, jobject, + jobject referent) { + JNI_METHOD_START + return reinterpret_cast(env->NewGlobalRef(referent)); + JNI_METHOD_END(-1L) +} + +/* + * Class: org_apache_arrow_dataset_file_JniWrapper + * Method: newJniMethodReference + * Signature: (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL +Java_org_apache_arrow_dataset_file_JniWrapper_newJniMethodReference(JNIEnv* env, jobject, + jstring class_sig, + jstring method_name, + jstring method_sig) { + JNI_METHOD_START + jclass clazz = env->FindClass(JStringToCString(env, class_sig).data()); + jmethodID jmethod_id = + env->GetMethodID(clazz, JStringToCString(env, method_name).data(), + JStringToCString(env, method_sig).data()); + return reinterpret_cast(jmethod_id); + JNI_METHOD_END(-1L) +} + /* * Class: org_apache_arrow_dataset_file_JniWrapper * Method: makeFileSystemDatasetFactory @@ -544,3 
+630,40 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory( return CreateNativeRef(d); JNI_METHOD_END(-1L) } + +/* + * Class: org_apache_arrow_dataset_file_JniWrapper + * Method: writeFromScannerToFile + * Signature: + * (Lorg/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator;[BILjava/lang/String;[Ljava/lang/String;ILjava/lang/String;)V + */ +JNIEXPORT void JNICALL +Java_org_apache_arrow_dataset_file_JniWrapper_writeFromScannerToFile( + JNIEnv* env, jobject, jobject itr, jbyteArray schema_bytes, jint file_format_id, + jstring uri, jobjectArray partition_columns, jint max_partitions, + jstring base_name_template) { + JNI_METHOD_START + JavaVM* vm; + if (env->GetJavaVM(&vm) != JNI_OK) { + JniThrow("Unable to get JavaVM instance"); + } + auto schema = JniGetOrThrow(FromSchemaByteArray(env, schema_bytes)); + auto scanner = JniGetOrThrow(MakeJavaDatasetScanner(vm, itr, schema)); + std::shared_ptr file_format = + JniGetOrThrow(GetFileFormat(file_format_id)); + arrow::dataset::FileSystemDatasetWriteOptions options; + std::string output_path; + auto filesystem = JniGetOrThrow( + arrow::fs::FileSystemFromUri(JStringToCString(env, uri), &output_path)); + std::vector partition_column_vector = + ToStringVector(env, partition_columns); + options.file_write_options = file_format->DefaultWriteOptions(); + options.filesystem = filesystem; + options.base_dir = output_path; + options.basename_template = JStringToCString(env, base_name_template); + options.partitioning = std::make_shared( + arrow::dataset::SchemaFromColumnNames(schema, partition_column_vector)); + options.max_partitions = max_partitions; + JniAssertOkOrThrow(arrow::dataset::FileSystemDataset::Write(options, scanner)); + JNI_METHOD_END() +} diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt index 07e2d0ae8fc..ba4dc90f8a4 100644 --- a/java/dataset/CMakeLists.txt +++ b/java/dataset/CMakeLists.txt @@ -30,14 +30,16 @@ include(FindJNI) message("generating headers to ${JNI_HEADERS_DIR}") -add_jar(arrow_dataset_java - src/main/java/org/apache/arrow/dataset/jni/JniLoader.java - src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java - src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java - src/main/java/org/apache/arrow/dataset/file/JniWrapper.java - src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java - src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java - GENERATE_NATIVE_HEADERS - arrow_dataset_java-native - DESTINATION - ${JNI_HEADERS_DIR}) +add_jar( + arrow_dataset_java + src/main/java/org/apache/arrow/dataset/jni/JniLoader.java + src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java + src/main/java/org/apache/arrow/dataset/file/JniWrapper.java + src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java + src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java + src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java + src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java + GENERATE_NATIVE_HEADERS + arrow_dataset_java-native + DESTINATION + ${JNI_HEADERS_DIR}) diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml index c4246a89090..f3501ff2752 100644 --- a/java/dataset/pom.xml +++ b/java/dataset/pom.xml @@ -27,7 +27,7 @@ ../../../cpp/release-build/ 2.5.0 1.11.0 - 1.8.2 + 1.9.1 @@ -38,12 +38,30 @@ compile ${arrow.vector.classifier} + + org.apache.arrow + arrow-format + ${project.version} + compile + org.apache.arrow arrow-memory-core ${project.version} compile + + 
commons-io + commons-io + 2.4 + compile + + + com.google.flatbuffers + flatbuffers-java + ${dep.fbs.version} + compile + org.apache.arrow arrow-memory-netty @@ -56,6 +74,12 @@ ${parquet.version} test + + org.apache.avro + avro + ${avro.version} + test + org.apache.parquet parquet-hadoop @@ -86,18 +110,6 @@ - - org.apache.avro - avro - ${avro.version} - test - - - com.google.guava - guava - ${dep.guava.version} - test - @@ -108,27 +120,6 @@ - - - - org.xolstice.maven.plugins - protobuf-maven-plugin - 0.5.1 - - com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} - - ../../cpp/src/jni/dataset/proto - - - - - compile - test-compile - - - - - diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java new file mode 100644 index 00000000000..ea85b5d0efb --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.file; + +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.util.SchemaUtility; + +/** + * JNI-based utility to write datasets into files. It internally depends on C++ static method + * FileSystemDataset::Write. + */ +public class DatasetFileWriter { + + /** + * Scan over an input {@link Scanner} then write all record batches to file. + * + * @param scanner the source scanner for writing + * @param format target file format + * @param uri target file uri + * @param maxPartitions maximum partitions to be included in written files + * @param partitionColumns columns used to partition output files. Empty to disable partitioning + * @param baseNameTemplate file name template used to make partitions. E.g. "dat_{i}", i is current partition + * ID around all written files. 
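+   *
+   * <p>Usage sketch; the URI and partition column here are illustrative only:
+   * <pre>{@code
+   * DatasetFileWriter.write(scanner, FileFormat.PARQUET, "file:///tmp/dataset-out",
+   *     new String[] {"col"}, 1024, "dat_{i}");
+   * }</pre>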
+ */ + public static void write(Scanner scanner, FileFormat format, String uri, + String[] partitionColumns, int maxPartitions, String baseNameTemplate) { + final NativeScannerAdaptorImpl adaptor = new NativeScannerAdaptorImpl(scanner); + final NativeSerializedRecordBatchIterator itr = adaptor.scan(); + RuntimeException throwableWrapper = null; + try { + JniWrapper.get().writeFromScannerToFile(itr, SchemaUtility.serialize(scanner.schema()), + format.id(), uri, partitionColumns, maxPartitions, baseNameTemplate); + } catch (Throwable t) { + throwableWrapper = new RuntimeException(t); + throw throwableWrapper; + } finally { + try { + AutoCloseables.close(itr); + } catch (Exception e) { + if (throwableWrapper != null) { + throwableWrapper.addSuppressed(e); + } + } + } + } + + /** + * Scan over an input {@link Scanner} then write all record batches to file, with default partitioning settings. + * + * @param scanner the source scanner for writing + * @param format target file format + * @param uri target file uri + */ + public static void write(Scanner scanner, FileFormat format, String uri) { + write(scanner, format, uri, new String[0], 1024, "dat_{i}"); + } +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java index e341d46beac..107fc2f71d2 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java @@ -24,7 +24,7 @@ public enum FileFormat { PARQUET(0), NONE(-1); - private int id; + private final int id; FileFormat(int id) { this.id = id; diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java index 1af307aac38..a503ca31f7e 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java @@ -18,6 +18,7 @@ package org.apache.arrow.dataset.file; import org.apache.arrow.dataset.jni.JniLoader; +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; /** * JniWrapper for filesystem based {@link org.apache.arrow.dataset.source.Dataset} implementations. @@ -34,9 +35,27 @@ private JniWrapper() { JniLoader.get().ensureLoaded(); } + /** + * Create a Jni global reference for the object. + * @param object the input object + * @return the native pointer of global reference object. + */ + public native long newJniGlobalReference(Object object); + + /** + * Create a Jni method reference. + * @param classSignature signature of the class defining the target method + * @param methodName method name + * @param methodSignature signature of the target method + * @return the native pointer of method reference object. + */ + public native long newJniMethodReference(String classSignature, String methodName, + String methodSignature); + /** * Create FileSystemDatasetFactory and return its native pointer. The pointer is pointing to a * intermediate shared_ptr of the factory instance. + * * @param uri file uri to read * @param fileFormat file format ID * @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance. @@ -44,4 +63,20 @@ private JniWrapper() { */ public native long makeFileSystemDatasetFactory(String uri, int fileFormat); + /** + * Write all record batches in a {@link NativeSerializedRecordBatchIterator} into files. 
This internally + * depends on C++ write API: FileSystemDataset::Write. + * + * @param itr iterator to be used for writing + * @param schema serialized schema of output files + * @param fileFormat target file format (ID) + * @param uri target file uri + * @param partitionColumns columns used to partition output files + * @param maxPartitions maximum partitions to be included in written files + * @param baseNameTemplate file name template used to make partitions. E.g. "dat_{i}", i is current partition + * ID around all written files. + */ + public native void writeFromScannerToFile(NativeSerializedRecordBatchIterator itr, byte[] schema, + int fileFormat, String uri, String[] partitionColumns, int maxPartitions, String baseNameTemplate); + } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java new file mode 100644 index 00000000000..b30ab2cafbf --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.file; + +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; + +/** + * A short path comparing to {@link org.apache.arrow.dataset.scanner.Scanner} for being called from C++ side + * via JNI, to minimize JNI call overhead. + */ +public interface NativeScannerAdaptor { + + /** + * Scan with the delegated scanner. + * + * @return a iterator outputting JNI-friendly flatbuffers-serialized + * {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch} instances. + */ + NativeSerializedRecordBatchIterator scan(); +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java new file mode 100644 index 00000000000..8fe9421f7e7 --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.file; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; +import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer; +import org.apache.arrow.dataset.scanner.ScanTask; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; + +/** + * Default implementation of {@link NativeScannerAdaptor}. + */ +public class NativeScannerAdaptorImpl implements NativeScannerAdaptor, AutoCloseable { + + private final Scanner scanner; + + /** + * Constructor. + * + * @param scanner the delegated scanner. + */ + public NativeScannerAdaptorImpl(Scanner scanner) { + this.scanner = scanner; + } + + @Override + public NativeSerializedRecordBatchIterator scan() { + final Iterable tasks = scanner.scan(); + return new IteratorImpl(tasks); + } + + @Override + public void close() throws Exception { + scanner.close(); + } + + private static class IteratorImpl implements NativeSerializedRecordBatchIterator { + + private final Iterator taskIterator; + + private ScanTask currentTask = null; + private ScanTask.BatchIterator currentBatchIterator = null; + + public IteratorImpl(Iterable tasks) { + this.taskIterator = tasks.iterator(); + } + + @Override + public void close() throws Exception { + closeCurrent(); + } + + private void closeCurrent() throws Exception { + if (currentTask == null) { + return; + } + currentTask.close(); + currentBatchIterator.close(); + } + + private boolean advance() { + if (!taskIterator.hasNext()) { + return false; + } + try { + closeCurrent(); + } catch (Exception e) { + throw new RuntimeException(e); + } + currentTask = taskIterator.next(); + currentBatchIterator = currentTask.execute(); + return true; + } + + @Override + public boolean hasNext() { + if (currentTask == null) { + if (!advance()) { + return false; + } + } + if (!currentBatchIterator.hasNext()) { + if (!advance()) { + return false; + } + } + return currentBatchIterator.hasNext(); + } + + @Override + public byte[] next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return serialize(currentBatchIterator.next()); + } + + private byte[] serialize(ArrowRecordBatch batch) { + return UnsafeRecordBatchSerializer.serializeUnsafe(batch); + } + } +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java index 8460841eeee..37c98850819 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java @@ -52,7 +52,7 @@ private JniWrapper() { * Create Dataset from a DatasetFactory and get the native pointer of the Dataset. * * @param datasetFactoryId the native pointer of the arrow::dataset::DatasetFactory instance. - * @param schema the predefined schema of the resulting Dataset. + * @param schema the predefined schema of the resulting Dataset. * @return the native pointer of the arrow::dataset::Dataset instance. */ public native long createDataset(long datasetFactoryId, byte[] schema); @@ -66,9 +66,11 @@ private JniWrapper() { /** * Create Scanner from a Dataset and get the native pointer of the Dataset. - * @param datasetId the native pointer of the arrow::dataset::Dataset instance. - * @param columns desired column names. 
Columns not in this list will not be emitted when performing scan operation. - * @param batchSize batch size of scanned record batches. + * + * @param datasetId the native pointer of the arrow::dataset::Dataset instance. + * @param columns desired column names. Columns not in this list will not be emitted when performing scan + * operation. + * @param batchSize batch size of scanned record batches. * @param memoryPool identifier of memory pool used in the native scanner. * @return the native pointer of the arrow::dataset::Scanner instance. */ @@ -85,19 +87,24 @@ private JniWrapper() { /** * Release the Scanner by destroying its reference held by JNI wrapper. + * * @param scannerId the native pointer of the arrow::dataset::Scanner instance. */ public native void closeScanner(long scannerId); /** * Read next record batch from the specified scanner. + * * @param scannerId the native pointer of the arrow::dataset::Scanner instance. - * @return an instance of {@link NativeRecordBatchHandle} describing the overall layout of the native record batch. + * @return a flatbuffers-serialized + * {@link org.apache.arrow.flatbuf.Message} describing + * the overall layout of the native record batch. */ - public native NativeRecordBatchHandle nextRecordBatch(long scannerId); + public native byte[] nextRecordBatch(long scannerId); /** * Release the Buffer by destroying its reference held by JNI wrapper. + * * @param bufferId the native pointer of the arrow::Buffer instance. */ public native void releaseBuffer(long bufferId); diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java deleted file mode 100644 index dd90fd1c1dd..00000000000 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.arrow.dataset.jni; - -import java.util.Arrays; -import java.util.List; - -/** - * Hold pointers to a Arrow C++ RecordBatch. - */ -public class NativeRecordBatchHandle { - - private final long numRows; - private final List fields; - private final List buffers; - - /** - * Constructor. - * - * @param numRows Total row number of the associated RecordBatch - * @param fields Metadata of fields - * @param buffers Retained Arrow buffers - */ - public NativeRecordBatchHandle(long numRows, Field[] fields, Buffer[] buffers) { - this.numRows = numRows; - this.fields = Arrays.asList(fields); - this.buffers = Arrays.asList(buffers); - } - - /** - * Returns the total row number of the associated RecordBatch. - * @return Total row number of the associated RecordBatch. 
- */ - public long getNumRows() { - return numRows; - } - - /** - * Returns Metadata of fields. - * @return Metadata of fields. - */ - public List getFields() { - return fields; - } - - /** - * Returns the buffers. - * @return Retained Arrow buffers. - */ - public List getBuffers() { - return buffers; - } - - /** - * Field metadata. - */ - public static class Field { - public final long length; - public final long nullCount; - - public Field(long length, long nullCount) { - this.length = length; - this.nullCount = nullCount; - } - } - - /** - * Pointers and metadata of the targeted Arrow buffer. - */ - public static class Buffer { - public final long nativeInstanceId; - public final long memoryAddress; - public final long size; - public final long capacity; - - /** - * Constructor. - * - * @param nativeInstanceId Native instance's id - * @param memoryAddress Memory address of the first byte - * @param size Size (in bytes) - * @param capacity Capacity (in bytes) - */ - public Buffer(long nativeInstanceId, long memoryAddress, long size, long capacity) { - this.nativeInstanceId = nativeInstanceId; - this.memoryAddress = memoryAddress; - this.size = size; - this.capacity = capacity; - } - } -} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java index 24c298067af..ea2c9edf4ec 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java @@ -18,23 +18,15 @@ package org.apache.arrow.dataset.jni; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.NoSuchElementException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.stream.Collectors; import org.apache.arrow.dataset.scanner.ScanTask; import org.apache.arrow.dataset.scanner.Scanner; -import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.BufferLedger; -import org.apache.arrow.memory.NativeUnderlyingMemory; -import org.apache.arrow.memory.util.LargeMemoryUtil; -import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.SchemaUtility; @@ -81,39 +73,21 @@ public boolean hasNext() { if (peek != null) { return true; } - final NativeRecordBatchHandle handle; + final byte[] bytes; readLock.lock(); try { if (closed) { throw new NativeInstanceReleasedException(); } - handle = JniWrapper.get().nextRecordBatch(scannerId); + bytes = JniWrapper.get().nextRecordBatch(scannerId); } finally { readLock.unlock(); } - if (handle == null) { + if (bytes == null) { return false; } - final ArrayList buffers = new ArrayList<>(); - for (NativeRecordBatchHandle.Buffer buffer : handle.getBuffers()) { - final BufferAllocator allocator = context.getAllocator(); - final int size = LargeMemoryUtil.checkedCastToInt(buffer.size); - final NativeUnderlyingMemory am = NativeUnderlyingMemory.create(allocator, - size, buffer.nativeInstanceId, buffer.memoryAddress); - BufferLedger ledger = am.associate(allocator); - ArrowBuf buf = new ArrowBuf(ledger, null, size, buffer.memoryAddress); - buffers.add(buf); - } - - try { - final int numRows = 
LargeMemoryUtil.checkedCastToInt(handle.getNumRows()); - peek = new ArrowRecordBatch(numRows, handle.getFields().stream() - .map(field -> new ArrowFieldNode(field.length, field.nullCount)) - .collect(Collectors.toList()), buffers); - return true; - } finally { - buffers.forEach(buffer -> buffer.getReferenceManager().release()); - } + peek = UnsafeRecordBatchSerializer.deserializeUnsafe(context.getAllocator(), bytes); + return true; } @Override diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java new file mode 100644 index 00000000000..5495f85cf1c --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.jni; + +import java.util.Iterator; + +/** + * Iterate on flatbuffers-serialized {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch}. + *
+ * <p>{@link #next()} should be called from C++ scanner to read Java-generated Arrow data.
+ */
+public interface NativeSerializedRecordBatchIterator extends Iterator<byte[]>, AutoCloseable {
+
+  /**
+   * Returns the next serialized {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch} as a
+   * Java byte array.
+   */
+  @Override
+  byte[] next();
+}
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java
new file mode 100644
index 00000000000..f7ad378f3dd
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.jni;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.Channels;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.dataset.file.JniWrapper;
+import org.apache.arrow.flatbuf.Buffer;
+import org.apache.arrow.flatbuf.FieldNode;
+import org.apache.arrow.flatbuf.KeyValue;
+import org.apache.arrow.flatbuf.Message;
+import org.apache.arrow.flatbuf.MessageHeader;
+import org.apache.arrow.flatbuf.RecordBatch;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.BufferLedger;
+import org.apache.arrow.memory.NativeUnderlyingMemory;
+import org.apache.arrow.memory.util.LargeMemoryUtil;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.compression.NoCompressionCodec;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.ArrowBodyCompression;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowMessage;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.FBSerializable;
+import org.apache.arrow.vector.ipc.message.FBSerializables;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+
+import com.google.flatbuffers.FlatBufferBuilder;
+
+/**
+ * A set of serialization utility methods against
+ * {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch}.
+ *
+ * <p>
<p>
This utility should be used only in the JNI case, since the record batch + * being serialized must be kept alive during the life cycle of its deserialized + * native counterpart. This design achieves zero-copy of + * the buffer bodies. + */ +public class UnsafeRecordBatchSerializer { + + /** + * Invoked by the native arrow::Buffer instance's destructor after the Java-side + * {@link ArrowBuf} has been transferred to native code via JNI. The transferred + * buffer holds one reference count on the Java-side buffer, so the Java memory + * management system can tell when native code has finished using the buffer and + * the corresponding allocated memory space can be correctly collected. + * + * @see UnsafeRecordBatchSerializer#serializeUnsafe(ArrowRecordBatch) + */ + private static class TransferredReferenceCleaner implements Runnable { + private static final long NATIVE_METHOD_REF = + JniWrapper.get().newJniMethodReference("Ljava/lang/Runnable;", "run", "()V"); + private final ArrowBuf buf; + + private TransferredReferenceCleaner(ArrowBuf buf) { + this.buf = buf; + } + + @Override + public void run() { + buf.getReferenceManager().release(); + } + } + + /** + * Deserialize native serialized bytes into an {@link ArrowRecordBatch} using flatbuffers. + * The input byte array should be written by native code as a flatbuffers + * {@link Message}, + * whose custom metadata must carry one native buffer ID per buffer. + * + * @param allocator Allocator that the deserialized buffers should be associated with + * @param bytes flatbuffers byte array + * @return the deserialized record batch + * @see NativeUnderlyingMemory + */ + public static ArrowRecordBatch deserializeUnsafe( + BufferAllocator allocator, + byte[] bytes) { + final ReadChannel metaIn = new ReadChannel( + Channels.newChannel(new ByteArrayInputStream(bytes))); + + final Message metaMessage; + try { + final MessageMetadataResult result = MessageSerializer.readMessage(metaIn); + Preconditions.checkNotNull(result); + metaMessage = result.getMessage(); + } catch (IOException e) { + throw new RuntimeException("Failed to deserialize record batch metadata", e); + } + final RecordBatch batchMeta = (RecordBatch) metaMessage.header(new RecordBatch()); + Preconditions.checkNotNull(batchMeta); + if (batchMeta.buffersLength() != metaMessage.customMetadataLength()) { + throw new IllegalArgumentException("Buffer count mismatch between metadata and native managed refs"); + } + + final ArrayList<ArrowBuf> buffers = new ArrayList<>(); + for (int i = 0; i < batchMeta.buffersLength(); i++) { + final Buffer bufferMeta = batchMeta.buffers(i); + final KeyValue keyValue = metaMessage.customMetadata(i); // custom metadata containing native buffer refs + final byte[] refDecoded = Base64.getDecoder().decode(keyValue.value()); + final long nativeBufferRef = ByteBuffer.wrap(refDecoded).order(ByteOrder.LITTLE_ENDIAN).getLong(); + final int size = LargeMemoryUtil.checkedCastToInt(bufferMeta.length()); + final NativeUnderlyingMemory am = NativeUnderlyingMemory.create(allocator, + size, nativeBufferRef, bufferMeta.offset()); + BufferLedger ledger = am.associate(allocator); + ArrowBuf buf = new ArrowBuf(ledger, null, size, bufferMeta.offset()); + buffers.add(buf); + } + + try { + final int numRows = LargeMemoryUtil.checkedCastToInt(batchMeta.length()); + final List<ArrowFieldNode> nodes = new ArrayList<>(batchMeta.nodesLength()); + for (int i = 0; i < batchMeta.nodesLength(); i++) { + final FieldNode node = batchMeta.nodes(i); + nodes.add(new ArrowFieldNode(node.length(),
node.nullCount())); + } + return new ArrowRecordBatch(numRows, nodes, buffers); + } finally { + buffers.forEach(buffer -> buffer.getReferenceManager().release()); + } + } + + /** + * Serialize an {@link ArrowRecordBatch} into flatbuffers bytes for native use. A cleaner callback + * {@link TransferredReferenceCleaner} will be created for each individual serialized + * buffer. The callback should be invoked from native code once the buffer is no longer + * used there; it decreases the reference count of the Java-side {@link ArrowBuf}. + * + * @param batch input record batch + * @return serialized bytes + * @see TransferredReferenceCleaner + */ + public static byte[] serializeUnsafe(ArrowRecordBatch batch) { + final ArrowBodyCompression bodyCompression = batch.getBodyCompression(); + if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) { + throw new UnsupportedOperationException("Could not serialize compressed buffers"); + } + + final FlatBufferBuilder builder = new FlatBufferBuilder(); + List<ArrowBuf> buffers = batch.getBuffers(); + int[] metadataOffsets = new int[buffers.size() * 2]; + for (int i = 0, buffersSize = buffers.size(); i < buffersSize; i++) { + ArrowBuf buffer = buffers.get(i); + final TransferredReferenceCleaner cleaner = new TransferredReferenceCleaner(buffer); + // cleaner object ref + long objectRefValue = JniWrapper.get().newJniGlobalReference(cleaner); + byte[] objectRefBytes = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN) + .putLong(objectRefValue).array(); + metadataOffsets[i * 2] = KeyValue.createKeyValue(builder, builder.createString("JAVA_BUFFER_CO_REF_" + i), + builder.createString(Base64.getEncoder().encodeToString(objectRefBytes))); + // cleaner method ref + long methodRefValue = TransferredReferenceCleaner.NATIVE_METHOD_REF; + byte[] methodRefBytes = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN) + .putLong(methodRefValue).array(); + metadataOffsets[i * 2 + 1] = + KeyValue.createKeyValue(builder, builder.createString("JAVA_BUFFER_CM_REF_" + i), + builder.createString(Base64.getEncoder().encodeToString(methodRefBytes))); + } + final ArrowMessage unsafeRecordMessage = new UnsafeRecordBatchMetadataMessage(batch); + final int batchOffset = unsafeRecordMessage.writeTo(builder); + final int customMetadataOffset = Message.createCustomMetadataVector(builder, metadataOffsets); + Message.startMessage(builder); + Message.addHeaderType(builder, unsafeRecordMessage.getMessageType()); + Message.addHeader(builder, batchOffset); + Message.addVersion(builder, IpcOption.DEFAULT.metadataVersion.toFlatbufID()); + Message.addBodyLength(builder, unsafeRecordMessage.computeBodyLength()); + Message.addCustomMetadata(builder, customMetadataOffset); + builder.finish(Message.endMessage(builder)); + final ByteBuffer metaBuffer = builder.dataBuffer(); + final ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + MessageSerializer.writeMessageBuffer(new WriteChannel(Channels.newChannel(out)), metaBuffer.remaining(), + metaBuffer, IpcOption.DEFAULT); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize Java record batch", e); + } + return out.toByteArray(); + } + + /** + * IPC message for record batches that are based on unsafe shared virtual memory.
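+ * Unlike a regular IPC message, it carries no body: {@link #computeBodyLength()} always returns zero, and each buffer is referenced by its absolute memory address so the native side can read it in place.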
+ */ + public static class UnsafeRecordBatchMetadataMessage implements ArrowMessage { + private ArrowRecordBatch delegated; + + public UnsafeRecordBatchMetadataMessage(ArrowRecordBatch delegated) { + this.delegated = delegated; + } + + @Override + public long computeBodyLength() { + return 0L; + } + + @Override + public <T> T accepts(ArrowMessageVisitor<T> visitor) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getMessageType() { + return MessageHeader.RecordBatch; + } + + @Override + public int writeTo(FlatBufferBuilder builder) { + final List<ArrowFieldNode> nodes = delegated.getNodes(); + final List<ArrowBuf> buffers = delegated.getBuffers(); + final ArrowBodyCompression bodyCompression = delegated.getBodyCompression(); + final int length = delegated.getLength(); + RecordBatch.startNodesVector(builder, nodes.size()); + int nodesOffset = FBSerializables.writeAllStructsToVector(builder, nodes); + RecordBatch.startBuffersVector(builder, buffers.size()); + int buffersOffset = FBSerializables.writeAllStructsToVector(builder, buffers.stream() + .map(buf -> (FBSerializable) b -> Buffer.createBuffer(b, buf.memoryAddress(), + buf.getReferenceManager().getSize())) + .collect(Collectors.toList())); + int compressOffset = 0; + if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) { + compressOffset = bodyCompression.writeTo(builder); + } + RecordBatch.startRecordBatch(builder); + RecordBatch.addLength(builder, length); + RecordBatch.addNodes(builder, nodesOffset); + RecordBatch.addBuffers(builder, buffersOffset); + if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) { + RecordBatch.addCompression(builder, compressOffset); + } + return RecordBatch.endRecordBatch(builder); + } + + @Override + public void close() throws Exception { + delegated.close(); + } + } +} diff --git a/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java b/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java index 963fb617040..5f3095910cc 100644 --- a/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java +++ b/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java @@ -25,7 +25,7 @@ public class NativeUnderlyingMemory extends AllocationManager { private final int size; - private final long nativeInstanceId; + private final long nativeBufferId; private final long address; /** @@ -33,12 +33,12 @@ public class NativeUnderlyingMemory extends AllocationManager { * * @param accountingAllocator The accounting allocator instance * @param size Size of underlying memory (in bytes) - * @param nativeInstanceId ID of the native instance + * @param nativeBufferId ID of the native buffer */ - NativeUnderlyingMemory(BufferAllocator accountingAllocator, int size, long nativeInstanceId, long address) { + NativeUnderlyingMemory(BufferAllocator accountingAllocator, int size, long nativeBufferId, long address) { super(accountingAllocator); this.size = size; - this.nativeInstanceId = nativeInstanceId; + this.nativeBufferId = nativeBufferId; this.address = address; // pre-allocate bytes on accounting allocator final AllocationListener listener = accountingAllocator.getListener(); @@ -55,9 +55,9 @@ public class NativeUnderlyingMemory extends AllocationManager { /** * Alias to constructor.
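* Equivalent to calling the package-private constructor with the same arguments.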
*/ - public static NativeUnderlyingMemory create(BufferAllocator bufferAllocator, int size, long nativeInstanceId, + public static NativeUnderlyingMemory create(BufferAllocator bufferAllocator, int size, long nativeBufferId, long address) { - return new NativeUnderlyingMemory(bufferAllocator, size, nativeInstanceId, address); + return new NativeUnderlyingMemory(bufferAllocator, size, nativeBufferId, address); } public BufferLedger associate(BufferAllocator allocator) { @@ -66,7 +66,7 @@ public BufferLedger associate(BufferAllocator allocator) { @Override protected void release0() { - JniWrapper.get().releaseBuffer(nativeInstanceId); + JniWrapper.get().releaseBuffer(nativeBufferId); } @Override diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java b/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java index c6299d135a0..da48da6c0db 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java @@ -47,7 +47,7 @@ public class ParquetWriteSupport implements AutoCloseable { public ParquetWriteSupport(String schemaName, File outputFolder) throws Exception { avroSchema = readSchemaFromFile(schemaName); path = outputFolder.getPath() + File.separator + "generated.parquet"; - uri = "file://" + path; + uri = new File(path).toURI().toString(); writer = AvroParquetWriter.builder(new org.apache.hadoop.fs.Path(path)) .withSchema(avroSchema) .build(); diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java new file mode 100644 index 00000000000..7b84218dfa0 --- /dev/null +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.dataset.file; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.arrow.dataset.ParquetWriteSupport; +import org.apache.arrow.dataset.TestDataset; +import org.apache.arrow.dataset.jni.NativeMemoryPool; +import org.apache.arrow.dataset.scanner.ScanOptions; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.dataset.source.Dataset; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.commons.io.FileUtils; +import org.junit.Assert; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestDatasetFileWriter extends TestDataset { + + @ClassRule + public static final TemporaryFolder TMP = new TemporaryFolder(); + + public static final String AVRO_SCHEMA_USER = "user.avsc"; + + @Test + public void testParquetWriteSimple() throws Exception { + ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), + 1, "a", 2, "b", 3, "c", 2, "d"); + String sampleParquet = writeSupport.getOutputURI(); + FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, sampleParquet); + ScanOptions options = new ScanOptions(new String[0], 100); + final Dataset dataset = factory.finish(); + final Scanner scanner = dataset.newScan(options); + final File writtenFolder = TMP.newFolder(); + final String writtenParquet = writtenFolder.toURI().toString(); + try { + DatasetFileWriter.write(scanner, FileFormat.PARQUET, writtenParquet); + assertParquetFileEquals(sampleParquet, Objects.requireNonNull(writtenFolder.listFiles())[0].toURI().toString()); + } finally { + AutoCloseables.close(factory, scanner, dataset); + } + } + + @Test + public void testParquetWriteWithPartitions() throws Exception { + ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), + 1, "a", 2, "b", 3, "c", 2, "d"); + String sampleParquet = writeSupport.getOutputURI(); + FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, sampleParquet); + ScanOptions options = new ScanOptions(new String[0], 100); + final Dataset dataset = factory.finish(); + final Scanner scanner = dataset.newScan(options); + final File writtenFolder = TMP.newFolder(); + final String writtenParquet = writtenFolder.toURI().toString(); + try { + DatasetFileWriter.write(scanner, FileFormat.PARQUET, writtenParquet, new String[]{"id", "name"}, 100, "dat_{i}"); + final Set<String> expectedOutputFiles = new HashSet<>( + Arrays.asList("id=1/name=a/dat_0", "id=2/name=b/dat_1", "id=3/name=c/dat_2", "id=2/name=d/dat_3")); + final Set<String> outputFiles = FileUtils.listFiles(writtenFolder, null, true) + .stream() + .map(file -> { + return writtenFolder.toURI().relativize(file.toURI()).toString(); + }) + .collect(Collectors.toSet()); + Assert.assertEquals(expectedOutputFiles, outputFiles); + } finally { + AutoCloseables.close(factory, scanner, dataset); + } + } + + private void assertParquetFileEquals(String expectedURI,
String actualURI) throws Exception { + final FileSystemDatasetFactory expectedFactory = new FileSystemDatasetFactory( + rootAllocator(), NativeMemoryPool.getDefault(), FileFormat.PARQUET, expectedURI); + List<ArrowRecordBatch> expectedBatches = collectResultFromFactory(expectedFactory, + new ScanOptions(new String[0], 100)); + final FileSystemDatasetFactory actualFactory = new FileSystemDatasetFactory( + rootAllocator(), NativeMemoryPool.getDefault(), FileFormat.PARQUET, actualURI); + List<ArrowRecordBatch> actualBatches = collectResultFromFactory(actualFactory, + new ScanOptions(new String[0], 100)); + // fast-fail by comparing metadata + Assert.assertEquals(expectedBatches.toString(), actualBatches.toString()); + // compare buffers + Assert.assertEquals(serialize(expectedBatches), serialize(actualBatches)); + AutoCloseables.close(expectedBatches, actualBatches); + } + + private String serialize(List<ArrowRecordBatch> batches) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + for (ArrowRecordBatch batch : batches) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch); + } + return Arrays.toString(out.toByteArray()); + } +} diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java index 063f0955925..07d5a754f9b 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java @@ -81,6 +81,7 @@ public void testParquetRead() throws Exception { checkParquetReadResult(schema, writeSupport.getWrittenRecords(), datum); AutoCloseables.close(datum); + AutoCloseables.close(factory); } @Test @@ -107,6 +108,7 @@ public void testParquetProjector() throws Exception { .build()), datum); AutoCloseables.close(datum); + AutoCloseables.close(factory); } @Test @@ -126,6 +128,7 @@ public void testParquetBatchSize() throws Exception { checkParquetReadResult(schema, writeSupport.getWrittenRecords(), datum); AutoCloseables.close(datum); + AutoCloseables.close(factory); } @Test @@ -140,6 +143,8 @@ public void testCloseAgain() throws Exception { dataset.close(); dataset.close(); }); + + AutoCloseables.close(factory); } @Test @@ -197,6 +202,7 @@ public void testScanAfterClose1() throws Exception { NativeScanner scanner = dataset.newScan(options); scanner.close(); assertThrows(NativeInstanceReleasedException.class, scanner::scan); + AutoCloseables.close(factory); } @Test @@ -212,6 +218,7 @@ public void testScanAfterClose2() throws Exception { NativeScanTask task = tasks.get(0); task.close(); assertThrows(NativeInstanceReleasedException.class, task::execute); + AutoCloseables.close(factory); } @Test @@ -228,6 +235,7 @@ public void testScanAfterClose3() throws Exception { ScanTask.BatchIterator iterator = task.execute(); task.close(); assertThrows(NativeInstanceReleasedException.class, iterator::hasNext); + AutoCloseables.close(factory); } @Test @@ -247,6 +255,7 @@ public void testMemoryAllocation() throws Exception { long finalReservation = rootAllocator().getAllocatedMemory(); Assert.assertEquals(expected_diff, reservation - initReservation); Assert.assertEquals(-expected_diff, finalReservation - reservation); + AutoCloseables.close(factory); } private void checkParquetReadResult(Schema schema, List expected, List actual) { diff --git a/java/pom.xml b/java/pom.xml index 89be67f775f..8bd02c2acf4 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -716,19 +716,19 @@
<groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <compilerArgs> - <arg>-XDcompilePolicy=simple</arg> - <arg>-Xplugin:ErrorProne</arg> - </compilerArgs> - <annotationProcessorPaths> - <path> - <groupId>com.google.errorprone</groupId> - <artifactId>error_prone_core</artifactId> - <version>2.4.0</version> - </path> - </annotationProcessorPaths> - </configuration> + + + + + + + + + + + + +
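Note that the two serializer entry points above are not inverses of each other: serializeUnsafe covers the Java-to-native direction, while deserializeUnsafe covers native-to-Java. A minimal sketch of the intended call sites, assuming the batch, the allocator, and the native-written bytes are obtained elsewhere (only the UnsafeRecordBatchSerializer calls come from this diff):

  import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer;
  import org.apache.arrow.memory.BufferAllocator;
  import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;

  class SerializerUsageSketch {
    // Java -> native: emits flatbuffers metadata whose custom metadata carries
    // Base64-encoded cleaner refs; buffer bodies stay in Java memory and are
    // shared with C++ by address. These bytes are what the C++ scanner consumes
    // via NativeSerializedRecordBatchIterator.next().
    static byte[] toNative(ArrowRecordBatch batch) {
      return UnsafeRecordBatchSerializer.serializeUnsafe(batch);
    }

    // Native -> Java: expects bytes written by native code, carrying one native
    // buffer ID per buffer in the custom metadata; wraps the native memory in
    // ArrowBufs backed by NativeUnderlyingMemory instead of copying it.
    static ArrowRecordBatch fromNative(BufferAllocator allocator, byte[] bytesFromNative) {
      return UnsafeRecordBatchSerializer.deserializeUnsafe(allocator, bytesFromNative);
    }
  }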
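Both directions exchange 64-bit references (JNI global refs, method refs, native buffer IDs) as Base64 strings inside flatbuffers KeyValue entries, always in little-endian byte order. A self-contained sketch of that encoding convention (the helper names are hypothetical, not part of this diff):

  import java.nio.ByteBuffer;
  import java.nio.ByteOrder;
  import java.util.Base64;

  class RefCodecSketch {
    // Encode a 64-bit reference the way serializeUnsafe() does: little-endian
    // bytes, then Base64 text suitable for a flatbuffers KeyValue value.
    static String encodeRef(long ref) {
      byte[] bytes = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN)
          .putLong(ref).array();
      return Base64.getEncoder().encodeToString(bytes);
    }

    // Decode it the way deserializeUnsafe() does.
    static long decodeRef(String value) {
      return ByteBuffer.wrap(Base64.getDecoder().decode(value))
          .order(ByteOrder.LITTLE_ENDIAN).getLong();
    }
  }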