From 7729a0a2aa621527c37c05f84c1804eafbccf402 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Sat, 22 May 2021 13:14:18 +0800 Subject: [PATCH 1/8] ARROW-7272: [C++][Java] JNI bridge between RecordBatch and VectorSchemaRoot --- cpp/CMakeLists.txt | 4 + cpp/src/arrow/CMakeLists.txt | 2 + cpp/src/arrow/json/array_parser.cc | 72 +++++ cpp/src/arrow/json/array_parser.h | 50 +++ cpp/src/arrow/json/array_writer.cc | 64 ++++ cpp/src/arrow/json/array_writer.h | 46 +++ cpp/src/jni/dataset/jni_util.cc | 283 ++++++++++++++--- cpp/src/jni/dataset/jni_util.h | 36 ++- cpp/src/jni/dataset/jni_wrapper.cc | 137 ++++----- java/dataset/CMakeLists.txt | 22 +- java/dataset/pom.xml | 63 ++-- .../apache/arrow/dataset/file/FileFormat.java | 2 +- .../apache/arrow/dataset/file/JniWrapper.java | 18 ++ .../apache/arrow/dataset/jni/JniWrapper.java | 34 +- .../dataset/jni/NativeRecordBatchHandle.java | 106 ------- .../arrow/dataset/jni/NativeScanner.java | 36 +-- .../jni/UnsafeRecordBatchSerializer.java | 291 ++++++++++++++++++ .../arrow/memory/NativeUnderlyingMemory.java | 14 +- .../arrow/dataset/ParquetWriteSupport.java | 2 +- .../dataset/file/TestFileSystemDataset.java | 9 + .../jni/TestUnsafeRecordBatchSerializer.java | 130 ++++++++ 21 files changed, 1089 insertions(+), 332 deletions(-) create mode 100644 cpp/src/arrow/json/array_parser.cc create mode 100644 cpp/src/arrow/json/array_parser.h create mode 100644 cpp/src/arrow/json/array_writer.cc create mode 100644 cpp/src/arrow/json/array_writer.h delete mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java create mode 100644 java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8a358db8b95..c50a98ba2d2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -334,6 +334,10 @@ if(ARROW_BUILD_BENCHMARKS set(ARROW_TESTING ON) endif() +if(ARROW_DATASET_JNI) + set(ARROW_JSON ON) +endif() + if(ARROW_GANDIVA) set(ARROW_WITH_RE2 ON) endif() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index d2f80ce7213..e5c7e6bc2ce 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -472,6 +472,8 @@ if(ARROW_JSON) json/chunked_builder.cc json/chunker.cc json/converter.cc + json/array_parser.cc + json/array_writer.cc json/object_parser.cc json/object_writer.cc json/parser.cc diff --git a/cpp/src/arrow/json/array_parser.cc b/cpp/src/arrow/json/array_parser.cc new file mode 100644 index 00000000000..f9b627162be --- /dev/null +++ b/cpp/src/arrow/json/array_parser.cc @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+#include "arrow/json/array_parser.h"
+#include "arrow/json/rapidjson_defs.h"  // IWYU pragma: keep
+
+#include <rapidjson/document.h>
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+namespace rj = arrow::rapidjson;
+
+class ArrayParser::Impl {
+ public:
+  Status Parse(arrow::util::string_view json) {
+    document_.Parse(reinterpret_cast<const char*>(json.data()),
+                    static_cast<size_t>(json.size()));
+
+    if (document_.HasParseError()) {
+      return Status::Invalid("Json parse error (offset ", document_.GetErrorOffset(),
+                             "): ", document_.GetParseError());
+    }
+    if (!document_.IsArray()) {
+      return Status::TypeError("Not a json array");
+    }
+    return Status::OK();
+  }
+
+  Result<int64_t> GetInt64(int32_t ordinal) const {
+    if (!document_[ordinal].IsInt64()) {
+      return Status::TypeError("Value at ordinal '", ordinal, "' is not an int64");
+    }
+    return document_[ordinal].GetInt64();
+  }
+
+  Result<int32_t> Length() const { return document_.GetArray().Size(); }
+
+ private:
+  rj::Document document_;
+};
+
+ArrayParser::ArrayParser() : impl_(new ArrayParser::Impl()) {}
+
+ArrayParser::~ArrayParser() = default;
+
+Status ArrayParser::Parse(arrow::util::string_view json) { return impl_->Parse(json); }
+
+Result<int32_t> ArrayParser::Length() const { return impl_->Length(); }
+
+Result<int64_t> ArrayParser::GetInt64(int32_t ordinal) const {
+  return impl_->GetInt64(ordinal);
+}
+
+}  // namespace internal
+}  // namespace json
+}  // namespace arrow
diff --git a/cpp/src/arrow/json/array_parser.h b/cpp/src/arrow/json/array_parser.h
new file mode 100644
index 00000000000..d718ef84df5
--- /dev/null
+++ b/cpp/src/arrow/json/array_parser.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/result.h"
+#include "arrow/util/string_view.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+/// This class is a helper to parse a JSON array from a string.
+/// It uses rapidjson::Document in its implementation.
+class ARROW_EXPORT ArrayParser {
+ public:
+  ArrayParser();
+  ~ArrayParser();
+
+  Status Parse(arrow::util::string_view json);
+
+  Result<int32_t> Length() const;
+
+  Result<int64_t> GetInt64(int32_t ordinal) const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace internal
+}  // namespace json
+}  // namespace arrow
diff --git a/cpp/src/arrow/json/array_writer.cc b/cpp/src/arrow/json/array_writer.cc
new file mode 100644
index 00000000000..480a3ab8809
--- /dev/null
+++ b/cpp/src/arrow/json/array_writer.cc
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/json/array_writer.h"
+#include "arrow/json/rapidjson_defs.h"  // IWYU pragma: keep
+
+#include <rapidjson/document.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+namespace rj = arrow::rapidjson;
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+class ArrayWriter::Impl {
+ public:
+  Impl() : root_(rj::kArrayType) {}
+
+  void AppendInt64(int64_t value) {
+    rj::Document::AllocatorType& allocator = document_.GetAllocator();
+
+    root_.PushBack(value, allocator);
+  }
+
+  std::string Serialize() {
+    rj::StringBuffer buffer;
+    rj::Writer<rj::StringBuffer> writer(buffer);
+    root_.Accept(writer);
+
+    return buffer.GetString();
+  }
+
+ private:
+  rj::Document document_;
+  rj::Value root_;
+};
+
+ArrayWriter::ArrayWriter() : impl_(new ArrayWriter::Impl()) {}
+
+ArrayWriter::~ArrayWriter() = default;
+
+void ArrayWriter::AppendInt64(int64_t value) { impl_->AppendInt64(value); }
+
+std::string ArrayWriter::Serialize() { return impl_->Serialize(); }
+
+}  // namespace internal
+}  // namespace json
+}  // namespace arrow
diff --git a/cpp/src/arrow/json/array_writer.h b/cpp/src/arrow/json/array_writer.h
new file mode 100644
index 00000000000..b8d190e4d59
--- /dev/null
+++ b/cpp/src/arrow/json/array_writer.h
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace json {
+namespace internal {
+
+/// This class is a helper to serialize a JSON array to a string.
+/// It uses rapidjson in its implementation.
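+///
+/// A minimal usage sketch (illustrative only; the round trip back through
+/// ArrayParser is implied by the two classes' contracts):
+///
+///   ArrayWriter writer;
+///   writer.AppendInt64(42);
+///   std::string json = writer.Serialize();  // yields "[42]"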
+class ARROW_EXPORT ArrayWriter {
+ public:
+  ArrayWriter();
+  ~ArrayWriter();
+
+  void AppendInt64(int64_t value);
+
+  std::string Serialize();
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace internal
+}  // namespace json
+}  // namespace arrow
diff --git a/cpp/src/jni/dataset/jni_util.cc b/cpp/src/jni/dataset/jni_util.cc
index 113669a4cf6..2bd533810d1 100644
--- a/cpp/src/jni/dataset/jni_util.cc
+++ b/cpp/src/jni/dataset/jni_util.cc
@@ -17,17 +17,27 @@
 
 #include "jni/dataset/jni_util.h"
 
+#include "arrow/ipc/metadata_internal.h"
+#include "arrow/json/array_parser.h"
+#include "arrow/json/array_writer.h"
+#include "arrow/util/key_value_metadata.h"
 #include "arrow/util/logging.h"
 
+#include <memory>
 #include <mutex>
+#include <vector>
+
 namespace arrow {
+
+namespace flatbuf = org::apache::arrow::flatbuf;
+
 namespace dataset {
 namespace jni {
 
 class ReservationListenableMemoryPool::Impl {
  public:
-  explicit Impl(arrow::MemoryPool* pool, std::shared_ptr<ReservationListener> listener,
+  explicit Impl(MemoryPool* pool, std::shared_ptr<ReservationListener> listener,
                 int64_t block_size)
       : pool_(pool),
         listener_(listener),
@@ -35,17 +45,17 @@ class ReservationListenableMemoryPool::Impl {
         blocks_reserved_(0),
         bytes_reserved_(0) {}
 
-  arrow::Status Allocate(int64_t size, uint8_t** out) {
+  Status Allocate(int64_t size, uint8_t** out) {
     RETURN_NOT_OK(UpdateReservation(size));
-    arrow::Status error = pool_->Allocate(size, out);
+    Status error = pool_->Allocate(size, out);
     if (!error.ok()) {
       RETURN_NOT_OK(UpdateReservation(-size));
       return error;
     }
-    return arrow::Status::OK();
+    return Status::OK();
   }
 
-  arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+  Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
     bool reserved = false;
     int64_t diff = new_size - old_size;
     if (new_size >= old_size) {
@@ -54,7 +64,7 @@ class ReservationListenableMemoryPool::Impl {
       RETURN_NOT_OK(UpdateReservation(diff));
       reserved = true;
     }
-    arrow::Status error = pool_->Reallocate(old_size, new_size, ptr);
+    Status error = pool_->Reallocate(old_size, new_size, ptr);
     if (!error.ok()) {
       if (reserved) {
         // roll back reservations on error
@@ -66,13 +76,13 @@ class ReservationListenableMemoryPool::Impl {
       // otherwise (e.g. new_size < old_size), make updates after calling underlying pool
       RETURN_NOT_OK(UpdateReservation(diff));
     }
-    return arrow::Status::OK();
+    return Status::OK();
   }
 
   void Free(uint8_t* buffer, int64_t size) {
     pool_->Free(buffer, size);
     // FIXME: See ARROW-11143, currently method ::Free doesn't allow Status return
-    arrow::Status s = UpdateReservation(-size);
+    Status s = UpdateReservation(-size);
     if (!s.ok()) {
       ARROW_LOG(FATAL) << "Failed to update reservation while freeing bytes: "
                        << s.message();
@@ -80,17 +90,17 @@ class ReservationListenableMemoryPool::Impl {
     }
   }
 
-  arrow::Status UpdateReservation(int64_t diff) {
+  Status UpdateReservation(int64_t diff) {
     int64_t granted = Reserve(diff);
     if (granted == 0) {
-      return arrow::Status::OK();
+      return Status::OK();
     }
     if (granted < 0) {
       RETURN_NOT_OK(listener_->OnRelease(-granted));
-      return arrow::Status::OK();
+      return Status::OK();
     }
     RETURN_NOT_OK(listener_->OnReservation(granted));
-    return arrow::Status::OK();
+    return Status::OK();
   }
 
   int64_t Reserve(int64_t diff) {
@@ -117,7 +127,7 @@ class ReservationListenableMemoryPool::Impl {
   std::shared_ptr<ReservationListener> get_listener() { return listener_; }
 
  private:
-  arrow::MemoryPool* pool_;
+  MemoryPool* pool_;
   std::shared_ptr<ReservationListener> listener_;
   int64_t block_size_;
   int64_t blocks_reserved_;
@@ -125,18 +135,40 @@ class ReservationListenableMemoryPool::Impl {
   std::mutex mutex_;
 };
 
+/// \brief Buffer implementation that binds to a
+/// Java buffer reference. The Java buffer's release
+/// method will be called once when this buffer is destructed.
+class JavaAllocatedBuffer : public Buffer {
+ public:
+  JavaAllocatedBuffer(JNIEnv* env, jobject cleaner_ref, jmethodID cleaner_method_ref,
+                      uint8_t* buffer, int32_t len)
+      : Buffer(buffer, len),
+        env_(env),
+        cleaner_ref_(cleaner_ref),
+        cleaner_method_ref_(cleaner_method_ref) {}
+
+  ~JavaAllocatedBuffer() override {
+    env_->CallVoidMethod(cleaner_ref_, cleaner_method_ref_);
+    env_->DeleteGlobalRef(cleaner_ref_);
+  }
+
+ private:
+  JNIEnv* env_;
+  jobject cleaner_ref_;
+  jmethodID cleaner_method_ref_;
+};
+
 ReservationListenableMemoryPool::ReservationListenableMemoryPool(
     MemoryPool* pool, std::shared_ptr<ReservationListener> listener,
     int64_t block_size) {
   impl_.reset(new Impl(pool, listener, block_size));
 }
 
-arrow::Status ReservationListenableMemoryPool::Allocate(int64_t size, uint8_t** out) {
+Status ReservationListenableMemoryPool::Allocate(int64_t size, uint8_t** out) {
   return impl_->Allocate(size, out);
 }
 
-arrow::Status ReservationListenableMemoryPool::Reallocate(int64_t old_size,
-                                                          int64_t new_size,
-                                                          uint8_t** ptr) {
+Status ReservationListenableMemoryPool::Reallocate(int64_t old_size, int64_t new_size,
+                                                   uint8_t** ptr) {
   return impl_->Reallocate(old_size, new_size, ptr);
 }
 
@@ -162,6 +194,15 @@ std::shared_ptr<ReservationListener> ReservationListenableMemoryPool::get_listen
 
 ReservationListenableMemoryPool::~ReservationListenableMemoryPool() {}
 
+Status CheckException(JNIEnv* env) {
+  if (env->ExceptionCheck()) {
+    env->ExceptionDescribe();
+    env->ExceptionClear();
+    return Status::Invalid("Error calling Java code from native code");
+  }
+  return Status::OK();
+}
+
 jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
   jclass local_class = env->FindClass(class_name);
   jclass global_class = (jclass)env->NewGlobalRef(local_class);
@@ -169,24 +210,24 @@ jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
   return global_class;
 }
 
-arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name,
-                                     const char* sig) {
+Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name,
+                              const char* sig) {
   jmethodID ret = env->GetMethodID(this_class, name, sig);
   if (ret == nullptr) {
     std::string error_message = "Unable to find method " + std::string(name) +
                                 " with signature " + std::string(sig);
-    return arrow::Status::Invalid(error_message);
+    return Status::Invalid(error_message);
   }
   return ret;
 }
 
-arrow::Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class,
-                                           const char* name, const char* sig) {
+Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class, const char* name,
+                                    const char* sig) {
  jmethodID ret = env->GetStaticMethodID(this_class, name, sig);
   if (ret == nullptr) {
     std::string error_message = "Unable to find static method " + std::string(name) +
                                 " with signature " + std::string(sig);
-    return arrow::Status::Invalid(error_message);
+    return Status::Invalid(error_message);
   }
   return ret;
 }
@@ -211,11 +252,9 @@ std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array) {
   return vector;
 }
 
-arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env,
-                                            std::shared_ptr<arrow::Schema> schema) {
-  ARROW_ASSIGN_OR_RAISE(
-      std::shared_ptr<arrow::Buffer> buffer,
-      arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool()))
+Result<jbyteArray> ToSchemaByteArray(JNIEnv* env, std::shared_ptr<Schema> schema) {
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer,
+                        ipc::SerializeSchema(*schema, default_memory_pool()))
 
   jbyteArray out = env->NewByteArray(buffer->size());
   auto src = reinterpret_cast<const jbyte*>(buffer->data());
@@ -223,20 +262,188 @@ arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env,
   return out;
 }
 
-arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray(
-    JNIEnv* env, jbyteArray schemaBytes) {
-  arrow::ipc::DictionaryMemo in_memo;
-  int schemaBytes_len = env->GetArrayLength(schemaBytes);
-  jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr);
-  auto serialized_schema = std::make_shared<arrow::Buffer>(
+Result<std::shared_ptr<Schema>> FromSchemaByteArray(JNIEnv* env, jbyteArray schema_bytes) {
+  ipc::DictionaryMemo in_memo;
+  int schemaBytes_len = env->GetArrayLength(schema_bytes);
+  jbyte* schemaBytes_data = env->GetByteArrayElements(schema_bytes, nullptr);
+  auto serialized_schema = std::make_shared<Buffer>(
       reinterpret_cast<uint8_t*>(schemaBytes_data), schemaBytes_len);
-  arrow::io::BufferReader buf_reader(serialized_schema);
-  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Schema> schema,
-                        arrow::ipc::ReadSchema(&buf_reader, &in_memo))
-  env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT);
+  io::BufferReader buf_reader(serialized_schema);
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Schema> schema,
+                        ipc::ReadSchema(&buf_reader, &in_memo))
+  env->ReleaseByteArrayElements(schema_bytes, schemaBytes_data, JNI_ABORT);
   return schema;
 }
 
+Status SetMetadataForSingleField(std::shared_ptr<ArrayData> array_data,
+                                 std::vector<ipc::internal::FieldMetadata>& node_metas,
+                                 std::vector<ipc::internal::BufferMetadata>& buffer_metas,
+                                 arrow::json::internal::ArrayWriter& buffer_refs) {
+  node_metas.push_back({array_data->length, array_data->null_count, 0L});
+
+  for (size_t i = 0; i < array_data->buffers.size(); i++) {
+    auto buffer = array_data->buffers.at(i);
+    uint8_t* data = nullptr;
+    int64_t size = 0;
+    if (buffer != nullptr) {
+      data = (uint8_t*)buffer->data();
+      size = buffer->size();
+    }
+    ipc::internal::BufferMetadata buffer_metadata{};
+    buffer_metadata.offset = reinterpret_cast<int64_t>(data);
+    buffer_metadata.length = size;
+    buffer_metas.push_back(buffer_metadata);
+
+    // store buffer refs into custom metadata
+    jlong ref = CreateNativeRef(buffer);
+    buffer_refs.AppendInt64(ref);
+  }
+
+  auto children_data = array_data->child_data;
+  for (const auto& child_data : children_data) {
+    RETURN_NOT_OK(
+        SetMetadataForSingleField(child_data, node_metas, buffer_metas, buffer_refs));
+  }
+  return Status::OK();
+}
+
+Result<std::shared_ptr<Buffer>> SerializeMetadata(const RecordBatch& batch,
+                                                  const ipc::IpcWriteOptions& options) {
+  std::vector<ipc::internal::FieldMetadata> nodes;
+  std::vector<ipc::internal::BufferMetadata> buffers;
+  std::shared_ptr<KeyValueMetadata> custom_metadata =
+      std::make_shared<KeyValueMetadata>();
+  arrow::json::internal::ArrayWriter buffer_refs;
+  for (const auto& column : batch.columns()) {
+    auto array_data = column->data();
+    RETURN_NOT_OK(SetMetadataForSingleField(array_data, nodes, buffers, buffer_refs));
+  }
+  custom_metadata->Append("NATIVE_BUFFER_REFS", buffer_refs.Serialize());
+  std::shared_ptr<Buffer> meta_buffer;
+  RETURN_NOT_OK(ipc::internal::WriteRecordBatchMessage(
+      batch.num_rows(), 0L, custom_metadata, nodes, buffers, options, &meta_buffer));
+  // no message body is needed for JNI serialization/deserialization
+  int32_t message_length = -1;
+  ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create(1024L));
+  RETURN_NOT_OK(ipc::WriteMessage(*meta_buffer, options, stream.get(), &message_length));
+  return stream->Finish();
+}
+
+Result<jbyteArray> SerializeUnsafeFromNative(JNIEnv* env,
+                                             const std::shared_ptr<RecordBatch>& batch) {
+  ARROW_ASSIGN_OR_RAISE(auto meta_buffer,
+                        SerializeMetadata(*batch, ipc::IpcWriteOptions::Defaults()));
+
+  jbyteArray ret = env->NewByteArray(meta_buffer->size());
+  auto src = reinterpret_cast<const jbyte*>(meta_buffer->data());
+  env->SetByteArrayRegion(ret, 0, meta_buffer->size(), src);
+  return ret;
+}
+
+Result<std::shared_ptr<ArrayData>> MakeArrayData(
+    JNIEnv* env, const flatbuf::RecordBatch& batch_meta,
+    const arrow::json::internal::ArrayParser& cleaner_object_refs,
+    const arrow::json::internal::ArrayParser& cleaner_method_refs,
+    const std::shared_ptr<DataType>& type, int32_t* field_offset,
+    int32_t* buffer_offset) {
+  const org::apache::arrow::flatbuf::FieldNode* field =
+      batch_meta.nodes()->Get((*field_offset)++);
+  int32_t own_buffer_size = static_cast<int32_t>(type->layout().buffers.size());
+  std::vector<std::shared_ptr<Buffer>> buffers;
+  for (int32_t i = *buffer_offset; i < *buffer_offset + own_buffer_size; i++) {
+    const org::apache::arrow::flatbuf::Buffer* java_managed_buffer =
+        batch_meta.buffers()->Get(i);
+    ARROW_ASSIGN_OR_RAISE(const int64_t& cleaner_object_ref_int64,
+                          cleaner_object_refs.GetInt64(i))
+    ARROW_ASSIGN_OR_RAISE(const int64_t& cleaner_method_ref_int64,
+                          cleaner_method_refs.GetInt64(i))
+    auto buffer = std::make_shared<JavaAllocatedBuffer>(
+        env, reinterpret_cast<jobject>(cleaner_object_ref_int64),
+        reinterpret_cast<jmethodID>(cleaner_method_ref_int64),
+        reinterpret_cast<uint8_t*>(java_managed_buffer->offset()),
+        java_managed_buffer->length());
+    buffers.push_back(buffer);
+  }
+  (*buffer_offset) += own_buffer_size;
+  if (type->num_fields() == 0) {
+    return ArrayData::Make(type, field->length(), buffers, field->null_count());
+  }
+  std::vector<std::shared_ptr<ArrayData>> children_array_data;
+  for (const auto& child_field : type->fields()) {
+    ARROW_ASSIGN_OR_RAISE(
+        auto child_array_data,
+        MakeArrayData(env, batch_meta, cleaner_object_refs, cleaner_method_refs,
+                      child_field->type(), field_offset, buffer_offset))
+    children_array_data.push_back(child_array_data);
+  }
+  return ArrayData::Make(type, field->length(), buffers, children_array_data,
+                         field->null_count());
+}
+
+Result<std::shared_ptr<RecordBatch>> DeserializeUnsafeFromJava(
+    JNIEnv* env, std::shared_ptr<Schema> schema, jbyteArray byte_array) {
+  int bytes_len = env->GetArrayLength(byte_array);
+  jbyte* byte_data = env->GetByteArrayElements(byte_array, nullptr);
+  io::BufferReader meta_reader(reinterpret_cast<const uint8_t*>(byte_data),
+                               static_cast<int64_t>(bytes_len));
+  ARROW_ASSIGN_OR_RAISE(auto meta_message, ipc::ReadMessage(&meta_reader))
+  auto meta_buffer = meta_message->metadata();
+  auto custom_metadata = meta_message->custom_metadata();
+  const flatbuf::Message* flat_meta = nullptr;
+  RETURN_NOT_OK(
+      ipc::internal::VerifyMessage(meta_buffer->data(), meta_buffer->size(), &flat_meta));
+  auto batch_meta = flat_meta->header_as_RecordBatch();
+
+  // A record batch serialized from Java should carry two ref IDs per buffer: a cleaner
+  // object ref and a cleaner method ref. The refs are 64-bit integers stored in JSON
+  // arrays in the custom metadata.
+  if (custom_metadata->size() != 2) {
+    return Status::SerializationError("RecordBatch metadata not found");
+  }
+
+  ARROW_ASSIGN_OR_RAISE(std::string cleaner_object_refs_string,
+                        custom_metadata->Get("JAVA_BUFFER_CO_REFS"))
+  ARROW_ASSIGN_OR_RAISE(std::string cleaner_method_refs_string,
+                        custom_metadata->Get("JAVA_BUFFER_CM_REFS"))
+  arrow::json::internal::ArrayParser cleaner_object_refs;
+  RETURN_NOT_OK(cleaner_object_refs.Parse(cleaner_object_refs_string));
+  arrow::json::internal::ArrayParser cleaner_method_refs;
+  RETURN_NOT_OK(cleaner_method_refs.Parse(cleaner_method_refs_string));
+
+  if (cleaner_object_refs.Length() !=
+          static_cast<int32_t>(batch_meta->buffers()->size()) ||
+      cleaner_method_refs.Length() !=
+          static_cast<int32_t>(batch_meta->buffers()->size())) {
+    return Status::SerializationError(
+        "Buffer count mismatch between metadata and Java managed refs");
+  }
+
+  std::vector<std::shared_ptr<ArrayData>> columns_array_data;
+  int32_t field_offset = 0;
+  int32_t buffer_offset = 0;
+  for (int32_t i = 0; i < schema->num_fields(); i++) {
+    auto field = schema->field(i);
+    ARROW_ASSIGN_OR_RAISE(
+        auto column_array_data,
+        MakeArrayData(env, *batch_meta, cleaner_object_refs, cleaner_method_refs,
+                      field->type(), &field_offset, &buffer_offset))
+    columns_array_data.push_back(column_array_data);
+  }
+  if (field_offset != static_cast<int32_t>(batch_meta->nodes()->size())) {
+    return Status::SerializationError(
+        "Deserialization failed: Field count is not "
+        "as expected based on type layout");
+  }
+  if (buffer_offset != static_cast<int32_t>(batch_meta->buffers()->size())) {
+    return Status::SerializationError(
+        "Deserialization failed: Buffer count is not "
+        "as expected based on type layout");
+  }
+  int64_t length = batch_meta->length();
+  env->ReleaseByteArrayElements(byte_array, byte_data, JNI_ABORT);
+  return RecordBatch::Make(schema, length, columns_array_data);
+}
+
 }  // namespace jni
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/jni/dataset/jni_util.h b/cpp/src/jni/dataset/jni_util.h
index c76033ae633..a56d4bc152b 100644
--- a/cpp/src/jni/dataset/jni_util.h
+++ b/cpp/src/jni/dataset/jni_util.h
@@ -30,23 +30,33 @@ namespace arrow {
 namespace dataset {
 namespace jni {
 
+Status CheckException(JNIEnv* env);
+
 jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name);
 
-arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name,
-                                     const char* sig);
+Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name,
+                              const char* sig);
 
-arrow::Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class,
-                                           const char* name, const char* sig);
+Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class, const char* name,
+                                    const char* sig);
 
 std::string JStringToCString(JNIEnv* env, jstring string);
 
 std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array);
 
-arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env,
-                                            std::shared_ptr<arrow::Schema> schema);
+Result<jbyteArray> ToSchemaByteArray(JNIEnv* env, std::shared_ptr<Schema> schema);
+
+Result<std::shared_ptr<Schema>> FromSchemaByteArray(JNIEnv* env, jbyteArray schema_bytes);
+
+/// \brief Serialize arrow::RecordBatch to jbyteArray (Java byte array byte[]). To let
+/// Java code manage the lifecycles of the buffers in the input batch, shared pointer IDs
+/// pointing to the buffers are serialized into the buffer metadata.
+Result<jbyteArray> SerializeUnsafeFromNative(JNIEnv* env,
+                                             const std::shared_ptr<RecordBatch>& batch);
 
-arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray(JNIEnv* env,
-                                                                  jbyteArray schemaBytes);
+/// \brief Deserialize jbyteArray (Java byte array byte[]) to arrow::RecordBatch.
+Result<std::shared_ptr<RecordBatch>> DeserializeUnsafeFromJava(
+    JNIEnv* env, std::shared_ptr<Schema> schema, jbyteArray byte_array);
 
 /// \brief Create a new shared_ptr on heap from shared_ptr t to prevent
 /// the managed object from being garbage-collected.
@@ -86,8 +96,8 @@ class ReservationListener {
  public:
   virtual ~ReservationListener() = default;
 
-  virtual arrow::Status OnReservation(int64_t size) = 0;
-  virtual arrow::Status OnRelease(int64_t size) = 0;
+  virtual Status OnReservation(int64_t size) = 0;
+  virtual Status OnRelease(int64_t size) = 0;
 
  protected:
   ReservationListener() = default;
@@ -98,7 +108,7 @@ class ReservationListener {
/// have to be subject to another "virtual" resource manager, which just tracks or
 /// limits number of bytes of application's overall memory usage. The underlying
 /// memory pool will still be responsible for actual malloc/free operations.
-class ReservationListenableMemoryPool : public arrow::MemoryPool {
+class ReservationListenableMemoryPool : public MemoryPool {
  public:
   /// \brief Constructor.
   ///
@@ -111,9 +121,9 @@ class ReservationListenableMemoryPool : public arrow::MemoryPool {
 
   ~ReservationListenableMemoryPool();
 
-  arrow::Status Allocate(int64_t size, uint8_t** out) override;
+  Status Allocate(int64_t size, uint8_t** out) override;
 
-  arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override;
+  Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override;
 
   void Free(uint8_t* buffer, int64_t size) override;
diff --git a/cpp/src/jni/dataset/jni_wrapper.cc b/cpp/src/jni/dataset/jni_wrapper.cc
index d61fb3f964e..20af823b8e2 100644
--- a/cpp/src/jni/dataset/jni_wrapper.cc
+++ b/cpp/src/jni/dataset/jni_wrapper.cc
@@ -36,14 +36,8 @@ jclass illegal_access_exception_class;
 jclass illegal_argument_exception_class;
 jclass runtime_exception_class;
 
-jclass record_batch_handle_class;
-jclass record_batch_handle_field_class;
-jclass record_batch_handle_buffer_class;
 jclass java_reservation_listener_class;
 
-jmethodID record_batch_handle_constructor;
-jmethodID record_batch_handle_field_constructor;
-jmethodID record_batch_handle_buffer_constructor;
 jmethodID reserve_memory_method;
 jmethodID unreserve_memory_method;
 
@@ -99,11 +93,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener {
       return arrow::Status::Invalid("JNIEnv was not attached to current thread");
     }
     env->CallObjectMethod(java_reservation_listener_, reserve_memory_method, size);
-    if (env->ExceptionCheck()) {
-      env->ExceptionDescribe();
-      env->ExceptionClear();
-      return arrow::Status::Invalid("Error calling Java side reservation listener");
-    }
+    RETURN_NOT_OK(arrow::dataset::jni::CheckException(env));
     return arrow::Status::OK();
   }
 
@@ -113,11 +103,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener {
       return arrow::Status::Invalid("JNIEnv was not attached to current thread");
     }
     env->CallObjectMethod(java_reservation_listener_, unreserve_memory_method, size);
-    if (env->ExceptionCheck()) {
-      env->ExceptionDescribe();
-      env->ExceptionClear();
-      return arrow::Status::Invalid("Error calling Java side reservation listener");
-    }
+    RETURN_NOT_OK(arrow::dataset::jni::CheckException(env));
     return arrow::Status::OK();
   }
 
@@ -205,33 +191,10 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
   runtime_exception_class =
       CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;");
 
-  record_batch_handle_class =
-      CreateGlobalClassReference(env,
-                                 "Lorg/apache/arrow/"
-                                 "dataset/jni/NativeRecordBatchHandle;");
-  record_batch_handle_field_class =
-      CreateGlobalClassReference(env,
-                                 "Lorg/apache/arrow/"
-                                 "dataset/jni/NativeRecordBatchHandle$Field;");
-  record_batch_handle_buffer_class =
-      CreateGlobalClassReference(env,
-                                 "Lorg/apache/arrow/"
-                                 "dataset/jni/NativeRecordBatchHandle$Buffer;");
   java_reservation_listener_class =
       CreateGlobalClassReference(env,
                                  "Lorg/apache/arrow/"
                                  "dataset/jni/ReservationListener;");
-
-  record_batch_handle_constructor =
-      JniGetOrThrow(GetMethodID(env, record_batch_handle_class, "<init>",
-                                "(J[Lorg/apache/arrow/dataset/"
-                                "jni/NativeRecordBatchHandle$Field;"
-                                "[Lorg/apache/arrow/dataset/"
-                                "jni/NativeRecordBatchHandle$Buffer;)V"));
-  record_batch_handle_field_constructor =
-      JniGetOrThrow(GetMethodID(env, record_batch_handle_field_class, "<init>", "(JJ)V"));
-  record_batch_handle_buffer_constructor = JniGetOrThrow(
-      GetMethodID(env, record_batch_handle_buffer_class, "<init>", "(JJJJ)V"));
   reserve_memory_method =
       JniGetOrThrow(GetMethodID(env, java_reservation_listener_class, "reserve", "(J)V"));
   unreserve_memory_method = JniGetOrThrow(
@@ -249,9 +212,6 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
   env->DeleteGlobalRef(illegal_access_exception_class);
   env->DeleteGlobalRef(illegal_argument_exception_class);
   env->DeleteGlobalRef(runtime_exception_class);
-  env->DeleteGlobalRef(record_batch_handle_class);
-  env->DeleteGlobalRef(record_batch_handle_field_class);
-  env->DeleteGlobalRef(record_batch_handle_buffer_class);
   env->DeleteGlobalRef(java_reservation_listener_class);
 
   default_memory_pool_id = -1L;
@@ -458,9 +418,9 @@
 /*
  * Class:     org_apache_arrow_dataset_jni_JniWrapper
  * Method:    nextRecordBatch
- * Signature: (J)Lorg/apache/arrow/dataset/jni/NativeRecordBatchHandle;
+ * Signature: (J)[B
  */
-JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch(
+JNIEXPORT jbyteArray JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecordBatch(
     JNIEnv* env, jobject, jlong scanner_id) {
   JNI_METHOD_START
   std::shared_ptr<DisposableScannerAdaptor> scanner_adaptor =
@@ -471,46 +431,7 @@ JNIEXPORT jobject JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_nextRecor
   if (record_batch == nullptr) {
     return nullptr;  // stream ended
   }
-  std::shared_ptr<arrow::Schema> schema = record_batch->schema();
-  jobjectArray field_array =
-      env->NewObjectArray(schema->num_fields(), record_batch_handle_field_class, nullptr);
-
-  std::vector<std::shared_ptr<arrow::Buffer>> buffers;
-  for (int i = 0; i < schema->num_fields(); ++i) {
-    auto column = record_batch->column(i);
-    auto dataArray = column->data();
-    jobject field = env->NewObject(record_batch_handle_field_class,
-                                   record_batch_handle_field_constructor,
-                                   column->length(), column->null_count());
-    env->SetObjectArrayElement(field_array, i, field);
-
-    for (auto& buffer : dataArray->buffers) {
-      buffers.push_back(buffer);
-    }
-  }
-
-  jobjectArray buffer_array =
-      env->NewObjectArray(buffers.size(), record_batch_handle_buffer_class, nullptr);
-
-  for (size_t j = 0; j < buffers.size(); ++j) {
-    auto buffer = buffers[j];
-    uint8_t* data = nullptr;
-    int64_t size = 0;
-    int64_t capacity = 0;
-    if (buffer != nullptr) {
-      data = (uint8_t*)buffer->data();
-      size = buffer->size();
-      capacity = buffer->capacity();
-    }
-    jobject buffer_handle = env->NewObject(record_batch_handle_buffer_class,
-                                           record_batch_handle_buffer_constructor,
-                                           CreateNativeRef(buffer), data, size, capacity);
-    env->SetObjectArrayElement(buffer_array, j, buffer_handle);
-  }
-
-  jobject ret = env->NewObject(record_batch_handle_class, record_batch_handle_constructor,
-                               record_batch->num_rows(), field_array, buffer_array);
-  return ret;
+  return JniGetOrThrow(arrow::dataset::jni::SerializeUnsafeFromNative(env, record_batch));
   JNI_METHOD_END(nullptr)
 }
 
@@ -526,6 +447,38 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_releaseBuffe
   JNI_METHOD_END()
 }
 
+/*
+ * Class:     org_apache_arrow_dataset_file_JniWrapper
+ * Method:    newJniGlobalReference
+ * Signature: (Ljava/lang/Object;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_dataset_file_JniWrapper_newJniGlobalReference(JNIEnv* env, jobject,
+                                                                    jobject referent) {
+  JNI_METHOD_START
+  return reinterpret_cast<jlong>(env->NewGlobalRef(referent));
+  JNI_METHOD_END(-1L)
+}
+
+/*
+ * Class:     org_apache_arrow_dataset_file_JniWrapper
+ * Method:    newJniMethodReference
+ * Signature: (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_apache_arrow_dataset_file_JniWrapper_newJniMethodReference(JNIEnv* env, jobject,
+                                                                    jstring class_sig,
+                                                                    jstring method_name,
+                                                                    jstring method_sig) {
+  JNI_METHOD_START
+  jclass clazz = env->FindClass(JStringToCString(env, class_sig).data());
+  jmethodID jmethod_id =
+      env->GetMethodID(clazz, JStringToCString(env, method_name).data(),
+                       JStringToCString(env, method_sig).data());
+  return reinterpret_cast<jlong>(jmethod_id);
+  JNI_METHOD_END(-1L)
+}
+
 /*
  * Class:     org_apache_arrow_dataset_file_JniWrapper
  * Method:    makeFileSystemDatasetFactory
@@ -544,3 +497,19 @@
 Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
   return CreateNativeRef(d);
   JNI_METHOD_END(-1L)
 }
+
+/*
+ * Class:     org_apache_arrow_dataset_jni_JniWrapper
+ * Method:    reexportUnsafeSerializedBatch
+ * Signature: ([B[B)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_apache_arrow_dataset_jni_JniWrapper_reexportUnsafeSerializedBatch(
+    JNIEnv* env, jobject, jbyteArray schema_bytes, jbyteArray batch_bytes) {
+  JNI_METHOD_START
+  auto schema = JniGetOrThrow(FromSchemaByteArray(env, schema_bytes));
+  auto batch = JniGetOrThrow(
+      arrow::dataset::jni::DeserializeUnsafeFromJava(env, schema, batch_bytes));
+  return JniGetOrThrow(arrow::dataset::jni::SerializeUnsafeFromNative(env, batch));
+  JNI_METHOD_END(nullptr)
+}
diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt
index 07e2d0ae8fc..75331adee63 100644
--- a/java/dataset/CMakeLists.txt
+++ b/java/dataset/CMakeLists.txt
@@ -30,14 +30,14 @@ include(FindJNI)
 
 message("generating headers to ${JNI_HEADERS_DIR}")
 
-add_jar(arrow_dataset_java
-        src/main/java/org/apache/arrow/dataset/jni/JniLoader.java
-        src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
-        src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java
-        src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
-        src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java
-        src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java
-        GENERATE_NATIVE_HEADERS
-        arrow_dataset_java-native
-        DESTINATION
-        ${JNI_HEADERS_DIR})
+add_jar(
+  arrow_dataset_java
+  src/main/java/org/apache/arrow/dataset/jni/JniLoader.java
+  src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
+  src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
+  src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java
+  src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java
+  GENERATE_NATIVE_HEADERS
+  arrow_dataset_java-native
+  DESTINATION
+  ${JNI_HEADERS_DIR})
diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml
index d4fea9f0efe..6553d41ec89 100644
--- a/java/dataset/pom.xml
+++ b/java/dataset/pom.xml
@@ -27,7 +27,7 @@
     <arrow.cpp.build.dir>../../../cpp/release-build/</arrow.cpp.build.dir>
     <protobuf.version>2.5.0</protobuf.version>
     <parquet.version>1.11.0</parquet.version>
-    <avro.version>1.8.2</avro.version>
+    <avro.version>1.9.1</avro.version>
   </properties>
@@ -38,12 +38,34 @@
       <scope>compile</scope>
       <classifier>${arrow.vector.classifier}</classifier>
     </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-format</artifactId>
+      <version>${project.version}</version>
+      <scope>compile</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.arrow</groupId>
       <artifactId>arrow-memory-core</artifactId>
       <version>${project.version}</version>
       <scope>compile</scope>
     </dependency>
+    <dependency>
+      <groupId>com.google.flatbuffers</groupId>
+      <artifactId>flatbuffers-java</artifactId>
+      <version>${dep.fbs.version}</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-core</artifactId>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-databind</artifactId>
+      <scope>compile</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.arrow</groupId>
       <artifactId>arrow-memory-netty</artifactId>
@@ -56,6 +78,12 @@
       <version>${parquet.version}</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.avro</groupId>
+      <artifactId>avro</artifactId>
+      <version>${avro.version}</version>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.parquet</groupId>
       <artifactId>parquet-hadoop</artifactId>
@@ -86,18 +114,6 @@
       </exclusions>
     </dependency>
-    <dependency>
-      <groupId>org.apache.avro</groupId>
-      <artifactId>avro</artifactId>
-      <version>${avro.version}</version>
-      <scope>test</scope>
-    </dependency>
-    <dependency>
-      <groupId>com.google.guava</groupId>
-      <artifactId>guava</artifactId>
-      <version>${dep.guava.version}</version>
-      <scope>test</scope>
-    </dependency>
   </dependencies>
@@ -108,27 +124,6 @@
       </plugin>
-      <plugin>
-        <groupId>org.xolstice.maven.plugins</groupId>
-        <artifactId>protobuf-maven-plugin</artifactId>
-        <version>0.5.1</version>
-        <configuration>
-          <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
-          <protoSourceRoot>../../cpp/src/jni/dataset/proto</protoSourceRoot>
-        </configuration>
-        <executions>
-          <execution>
-            <goals>
-              <goal>compile</goal>
-              <goal>test-compile</goal>
-            </goals>
-          </execution>
-        </executions>
-      </plugin>
     </plugins>
   </build>
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java
index e341d46beac..107fc2f71d2 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java
@@ -24,7 +24,7 @@ public enum FileFormat {
   PARQUET(0),
   NONE(-1);
 
-  private int id;
+  private final int id;
 
   FileFormat(int id) {
     this.id = id;
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
index 1af307aac38..cc61f994ab0 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
@@ -34,9 +34,27 @@ private JniWrapper() {
     JniLoader.get().ensureLoaded();
   }
 
+  /**
+   * Create a JNI global reference for the given object.
+   * @param object the input object
+   * @return the native pointer of the global reference object.
+   */
+  public native long newJniGlobalReference(Object object);
+
+  /**
+   * Create a JNI method reference.
+   * @param classSignature signature of the class defining the target method
+   * @param methodName method name
+   * @param methodSignature signature of the target method
+   * @return the native pointer of the method reference object.
+   */
+  public native long newJniMethodReference(String classSignature, String methodName,
+      String methodSignature);
+
   /**
    * Create FileSystemDatasetFactory and return its native pointer. The pointer points to an
    * intermediate shared_ptr of the factory instance.
+   *
    * @param uri file uri to read
    * @param fileFormat file format ID
   * @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
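
The two JNI reference helpers above work in tandem: a global reference pins a callback
object so native code can keep invoking it, and a method reference identifies the method
to invoke on that object. A minimal illustrative Java sketch (not part of the patch's
sources; it mirrors how TransferredReferenceCleaner uses these helpers later in this
patch):

    // Pin a cleaner callback so native code can call back into Java when it
    // finishes using a transferred buffer.
    Runnable cleaner = () -> System.out.println("buffer released");
    long objectRef = JniWrapper.get().newJniGlobalReference(cleaner);
    long methodRef = JniWrapper.get().newJniMethodReference(
        "Ljava/lang/Runnable;", "run", "()V");
    // Native code may later invoke:
    //   env->CallVoidMethod((jobject) objectRef, (jmethodID) methodRef);
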
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
index 8460841eeee..4d55c3ec22e 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
@@ -52,7 +52,7 @@ private JniWrapper() {
    * Create Dataset from a DatasetFactory and get the native pointer of the Dataset.
    *
    * @param datasetFactoryId the native pointer of the arrow::dataset::DatasetFactory instance.
-   * @param schema the predefined schema of the resulting Dataset.
+   * @param schema           the predefined schema of the resulting Dataset.
    * @return the native pointer of the arrow::dataset::Dataset instance.
    */
   public native long createDataset(long datasetFactoryId, byte[] schema);
@@ -66,9 +66,11 @@ private JniWrapper() {
 
   /**
    * Create Scanner from a Dataset and get the native pointer of the Scanner.
-   * @param datasetId the native pointer of the arrow::dataset::Dataset instance.
-   * @param columns desired column names. Columns not in this list will not be emitted when performing scan operation.
-   * @param batchSize batch size of scanned record batches.
+   *
+   * @param datasetId  the native pointer of the arrow::dataset::Dataset instance.
+   * @param columns    desired column names. Columns not in this list will not be emitted when performing scan
+   *                   operation.
+   * @param batchSize  batch size of scanned record batches.
    * @param memoryPool identifier of memory pool used in the native scanner.
    * @return the native pointer of the arrow::dataset::Scanner instance.
    */
@@ -85,21 +87,41 @@ private JniWrapper() {
 
   /**
    * Release the Scanner by destroying its reference held by JNI wrapper.
+   *
    * @param scannerId the native pointer of the arrow::dataset::Scanner instance.
    */
   public native void closeScanner(long scannerId);
 
   /**
    * Read next record batch from the specified scanner.
+   *
    * @param scannerId the native pointer of the arrow::dataset::Scanner instance.
-   * @return an instance of {@link NativeRecordBatchHandle} describing the overall layout of the native record batch.
+   * @return a flatbuffers-serialized
+   *         {@link org.apache.arrow.flatbuf.Message} describing
+   *         the overall layout of the native record batch.
    */
-  public native NativeRecordBatchHandle nextRecordBatch(long scannerId);
+  public native byte[] nextRecordBatch(long scannerId);
 
   /**
    * Release the Buffer by destroying its reference held by JNI wrapper.
+   *
    * @param bufferId the native pointer of the arrow::Buffer instance.
    */
  public native void releaseBuffer(long bufferId);
 
+  /**
+   * Visible for testing: Read an input record batch that was serialized unsafely using
+   * {@link UnsafeRecordBatchSerializer}, then re-export it into bytes using the C++
+   * serialization facility.
+   *
+   * <p>Used only in round-trip testing for {@link UnsafeRecordBatchSerializer}.
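+   *
+   * <p>Illustratively, {@code reexportUnsafeSerializedBatch(schemaBytes,
+   * UnsafeRecordBatchSerializer.serializeUnsafe(batch))} should return bytes that
+   * deserialize to a record batch with the same contents as the original batch.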
+   *
+   * @param schema      serialized record batch schema
+   * @param recordBatch serialized record batch buffer data
+   *
+   * @return output record batch data which was re-serialized from C++ code
+   *
+   * @see UnsafeRecordBatchSerializer
+   */
+  public native byte[] reexportUnsafeSerializedBatch(byte[] schema, byte[] recordBatch);
 }
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java
deleted file mode 100644
index dd90fd1c1dd..00000000000
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeRecordBatchHandle.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.arrow.dataset.jni;
-
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * Holds pointers to an Arrow C++ RecordBatch.
- */
-public class NativeRecordBatchHandle {
-
-  private final long numRows;
-  private final List<Field> fields;
-  private final List<Buffer> buffers;
-
-  /**
-   * Constructor.
-   *
-   * @param numRows Total row number of the associated RecordBatch
-   * @param fields Metadata of fields
-   * @param buffers Retained Arrow buffers
-   */
-  public NativeRecordBatchHandle(long numRows, Field[] fields, Buffer[] buffers) {
-    this.numRows = numRows;
-    this.fields = Arrays.asList(fields);
-    this.buffers = Arrays.asList(buffers);
-  }
-
-  /**
-   * Returns the total row number of the associated RecordBatch.
-   * @return Total row number of the associated RecordBatch.
-   */
-  public long getNumRows() {
-    return numRows;
-  }
-
-  /**
-   * Returns the metadata of fields.
-   * @return Metadata of fields.
-   */
-  public List<Field> getFields() {
-    return fields;
-  }
-
-  /**
-   * Returns the buffers.
-   * @return Retained Arrow buffers.
-   */
-  public List<Buffer> getBuffers() {
-    return buffers;
-  }
-
-  /**
-   * Field metadata.
-   */
-  public static class Field {
-    public final long length;
-    public final long nullCount;
-
-    public Field(long length, long nullCount) {
-      this.length = length;
-      this.nullCount = nullCount;
-    }
-  }
-
-  /**
-   * Pointers and metadata of the targeted Arrow buffer.
-   */
-  public static class Buffer {
-    public final long nativeInstanceId;
-    public final long memoryAddress;
-    public final long size;
-    public final long capacity;
-
-    /**
-     * Constructor.
-     *
-     * @param nativeInstanceId Native instance's id
-     * @param memoryAddress Memory address of the first byte
-     * @param size Size (in bytes)
-     * @param capacity Capacity (in bytes)
-     */
-    public Buffer(long nativeInstanceId, long memoryAddress, long size, long capacity) {
-      this.nativeInstanceId = nativeInstanceId;
-      this.memoryAddress = memoryAddress;
-      this.size = size;
-      this.capacity = capacity;
-    }
-  }
-}
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
index 24c298067af..ea2c9edf4ec 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
@@ -18,23 +18,15 @@
 package org.apache.arrow.dataset.jni;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.NoSuchElementException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
-import java.util.stream.Collectors;
 
 import org.apache.arrow.dataset.scanner.ScanTask;
 import org.apache.arrow.dataset.scanner.Scanner;
-import org.apache.arrow.memory.ArrowBuf;
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.BufferLedger;
-import org.apache.arrow.memory.NativeUnderlyingMemory;
-import org.apache.arrow.memory.util.LargeMemoryUtil;
-import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.arrow.vector.util.SchemaUtility;
@@ -81,39 +73,21 @@ public boolean hasNext() {
       if (peek != null) {
         return true;
       }
-      final NativeRecordBatchHandle handle;
+      final byte[] bytes;
       readLock.lock();
       try {
         if (closed) {
           throw new NativeInstanceReleasedException();
         }
-        handle = JniWrapper.get().nextRecordBatch(scannerId);
+        bytes = JniWrapper.get().nextRecordBatch(scannerId);
       } finally {
         readLock.unlock();
       }
-      if (handle == null) {
+      if (bytes == null) {
         return false;
       }
-      final ArrayList<ArrowBuf> buffers = new ArrayList<>();
-      for (NativeRecordBatchHandle.Buffer buffer : handle.getBuffers()) {
-        final BufferAllocator allocator = context.getAllocator();
-        final int size = LargeMemoryUtil.checkedCastToInt(buffer.size);
-        final NativeUnderlyingMemory am = NativeUnderlyingMemory.create(allocator,
-            size, buffer.nativeInstanceId, buffer.memoryAddress);
-        BufferLedger ledger = am.associate(allocator);
-        ArrowBuf buf = new ArrowBuf(ledger, null, size, buffer.memoryAddress);
-        buffers.add(buf);
-      }
-
-      try {
-        final int numRows = LargeMemoryUtil.checkedCastToInt(handle.getNumRows());
-        peek = new ArrowRecordBatch(numRows, handle.getFields().stream()
-            .map(field -> new ArrowFieldNode(field.length, field.nullCount))
-            .collect(Collectors.toList()), buffers);
-        return true;
-      } finally {
-        buffers.forEach(buffer -> buffer.getReferenceManager().release());
-      }
+      peek = UnsafeRecordBatchSerializer.deserializeUnsafe(context.getAllocator(), bytes);
+      return true;
     }
 
     @Override
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java
new file mode 100644
index 00000000000..069905ac7de
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.jni;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.Channels;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.dataset.file.JniWrapper;
+import org.apache.arrow.flatbuf.Buffer;
+import org.apache.arrow.flatbuf.FieldNode;
+import org.apache.arrow.flatbuf.KeyValue;
+import org.apache.arrow.flatbuf.Message;
+import org.apache.arrow.flatbuf.MessageHeader;
+import org.apache.arrow.flatbuf.RecordBatch;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.BufferLedger;
+import org.apache.arrow.memory.NativeUnderlyingMemory;
+import org.apache.arrow.memory.util.LargeMemoryUtil;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.compression.NoCompressionCodec;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.ArrowBodyCompression;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowMessage;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.FBSerializable;
+import org.apache.arrow.vector.ipc.message.FBSerializables;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.google.flatbuffers.FlatBufferBuilder;
+
+/**
+ * A set of serialization utility methods for {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch}.
+ *
+ * <p>The utility should be used only in the JNI case, since the record batch
+ * being serialized must be kept alive during the life cycle of its deserialized
+ * native record batch. This design achieves zero-copy of the buffer bodies.
+ */
+public class UnsafeRecordBatchSerializer {
+
+  private static final ObjectMapper JSON = new ObjectMapper();
+
+  /**
+   * This runs in response to the native arrow::Buffer instance's destructor, after the
+   * Java side {@link ArrowBuf} has been transferred to native code via JNI. A transferred
+   * C++ buffer holds one reference count on the Java side buffer, so the Java memory
+   * management system can know when native code has finished using the buffer. This way
+   * the corresponding allocated memory space can be correctly collected.
+   *
+   * @see UnsafeRecordBatchSerializer#serializeUnsafe(ArrowRecordBatch)
+   */
+  private static class TransferredReferenceCleaner implements Runnable {
+    private static final long NATIVE_METHOD_REF =
+        JniWrapper.get().newJniMethodReference("Ljava/lang/Runnable;", "run", "()V");
+    private final ArrowBuf buf;
+
+    private TransferredReferenceCleaner(ArrowBuf buf) {
+      this.buf = buf;
+    }
+
+    @Override
+    public void run() {
+      buf.getReferenceManager().release();
+    }
+  }
+
+  /**
+   * Deserialize native serialized bytes into an {@link ArrowRecordBatch} using flatbuffers.
+   * The input byte array should have been written from native code and be of type
+   * {@link Message}, whose custom metadata must carry the native buffer IDs.
+   *
+   * @param allocator Allocator that the deserialized buffer should be associated with
+   * @param bytes flatbuffers byte array
+   * @return the deserialized record batch
+   * @see NativeUnderlyingMemory
+   */
+  public static ArrowRecordBatch deserializeUnsafe(
+      BufferAllocator allocator,
+      byte[] bytes) {
+    final ReadChannel metaIn = new ReadChannel(
+        Channels.newChannel(new ByteArrayInputStream(bytes)));
+
+    final Message metaMessage;
+    try {
+      final MessageMetadataResult result = MessageSerializer.readMessage(metaIn);
+      Preconditions.checkNotNull(result);
+      metaMessage = result.getMessage();
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to deserialize record batch metadata", e);
+    }
+    final RecordBatch batchMeta = (RecordBatch) metaMessage.header(new RecordBatch());
+    Preconditions.checkNotNull(batchMeta);
+
+    if (metaMessage.customMetadataLength() != 1) {
+      throw new IllegalArgumentException("RecordBatch metadata not found");
+    }
+    final String nativeBufferRefsJson = metaMessage.customMetadata(0).value();
+    final ArrayNode nativeBufferRefs;
+    try {
+      // JSON array containing native buffer refs
+      nativeBufferRefs = (ArrayNode) JSON.readTree(nativeBufferRefsJson);
+    } catch (JsonProcessingException e) {
+      throw new RuntimeException("Malformed JSON array: " + nativeBufferRefsJson, e);
+    }
+
+    if (batchMeta.buffersLength() != nativeBufferRefs.size()) {
+      throw new IllegalArgumentException("Buffer count mismatch between metadata and native managed refs");
+    }
+
+    final ArrayList<ArrowBuf> buffers = new ArrayList<>();
+    for (int i = 0; i < batchMeta.buffersLength(); i++) {
+      final Buffer bufferMeta = batchMeta.buffers(i);
+      final JsonNode jsonNode = nativeBufferRefs.get(i);
+      if (!jsonNode.isLong()) {
+        throw new RuntimeException("Not a JSON long value: " + jsonNode);
+      }
+      final long nativeBufferRef = jsonNode.asLong();
+      final int size = LargeMemoryUtil.checkedCastToInt(bufferMeta.length());
+      final NativeUnderlyingMemory am = NativeUnderlyingMemory.create(allocator,
+          size, nativeBufferRef, bufferMeta.offset());
+      BufferLedger ledger = am.associate(allocator);
+      ArrowBuf buf = new ArrowBuf(ledger, null, size, bufferMeta.offset());
+      buffers.add(buf);
+    }
+
+    try {
+      final int numRows = LargeMemoryUtil.checkedCastToInt(batchMeta.length());
+      final List<ArrowFieldNode> nodes = new ArrayList<>(batchMeta.nodesLength());
+      for (int i = 0; i < batchMeta.nodesLength(); i++) {
+        final FieldNode node = batchMeta.nodes(i);
+        nodes.add(new ArrowFieldNode(node.length(), node.nullCount()));
+      }
+      return new ArrowRecordBatch(numRows, nodes, buffers);
+    } finally {
+      buffers.forEach(buffer -> buffer.getReferenceManager().release());
+    }
+  }
+
+  /**
+   * Serialize an {@link ArrowRecordBatch} to flatbuffers bytes for native use. A cleaner callback
+   * {@link TransferredReferenceCleaner} will be created for each individual serialized
+   * buffer. The callback should be invoked once the buffer is collected from native code.
+   * We use the callback to decrease the reference count of the Java side {@link ArrowBuf}.
+   *
+   * @param batch input record batch
+   * @return serialized bytes
+   * @see TransferredReferenceCleaner
+   */
+  public static byte[] serializeUnsafe(ArrowRecordBatch batch) {
+    final ArrowBodyCompression bodyCompression = batch.getBodyCompression();
+    if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) {
+      throw new UnsupportedOperationException("Could not serialize compressed buffers");
+    }
+
+    final FlatBufferBuilder builder = new FlatBufferBuilder();
+    List<ArrowBuf> buffers = batch.getBuffers();
+    // here we have 2 types of refs: cleaner object refs (CO_REFS) and cleaner method refs (CM_REFS)
+    final List<Long> coRefs = new ArrayList<>();
+    final List<Long> cmRefs = new ArrayList<>();
+    for (ArrowBuf buffer : buffers) {
+      final TransferredReferenceCleaner cleaner = new TransferredReferenceCleaner(buffer);
+      // cleaner object ref
+      long objectRefValue = JniWrapper.get().newJniGlobalReference(cleaner);
+      coRefs.add(objectRefValue);
+      // cleaner method ref
+      long methodRefValue = TransferredReferenceCleaner.NATIVE_METHOD_REF;
+      cmRefs.add(methodRefValue);
+    }
+    int[] metadataOffsets = new int[2];
+    try {
+      metadataOffsets[0] = KeyValue.createKeyValue(builder, builder.createString("JAVA_BUFFER_CO_REFS"),
+          builder.createString(JSON.writeValueAsString(coRefs)));
+      metadataOffsets[1] = KeyValue.createKeyValue(builder, builder.createString("JAVA_BUFFER_CM_REFS"),
+          builder.createString(JSON.writeValueAsString(cmRefs)));
+    } catch (JsonProcessingException e) {
+      throw new RuntimeException(e);
+    }
+
+    final ArrowMessage unsafeRecordMessage = new UnsafeRecordBatchMetadataMessage(batch);
+    final int batchOffset = unsafeRecordMessage.writeTo(builder);
+    final int customMetadataOffset = Message.createCustomMetadataVector(builder, metadataOffsets);
+    Message.startMessage(builder);
+    Message.addHeaderType(builder, unsafeRecordMessage.getMessageType());
+    Message.addHeader(builder, batchOffset);
+    Message.addVersion(builder, IpcOption.DEFAULT.metadataVersion.toFlatbufID());
+    Message.addBodyLength(builder, unsafeRecordMessage.computeBodyLength());
+    Message.addCustomMetadata(builder, customMetadataOffset);
+    builder.finish(Message.endMessage(builder));
+    final ByteBuffer metaBuffer = builder.dataBuffer();
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    try {
+      MessageSerializer.writeMessageBuffer(new WriteChannel(Channels.newChannel(out)), metaBuffer.remaining(),
+          metaBuffer, IpcOption.DEFAULT);
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to serialize Java record batch", e);
batch", e); + } + return out.toByteArray(); + } + + /** + * IPC message for record batches that are based on unsafe shared virtual memory. + */ + public static class UnsafeRecordBatchMetadataMessage implements ArrowMessage { + private ArrowRecordBatch delegated; + + public UnsafeRecordBatchMetadataMessage(ArrowRecordBatch delegated) { + this.delegated = delegated; + } + + @Override + public long computeBodyLength() { + return 0L; + } + + @Override + public T accepts(ArrowMessageVisitor visitor) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getMessageType() { + return MessageHeader.RecordBatch; + } + + @Override + public int writeTo(FlatBufferBuilder builder) { + final List nodes = delegated.getNodes(); + final List buffers = delegated.getBuffers(); + final ArrowBodyCompression bodyCompression = delegated.getBodyCompression(); + final int length = delegated.getLength(); + RecordBatch.startNodesVector(builder, nodes.size()); + int nodesOffset = FBSerializables.writeAllStructsToVector(builder, nodes); + RecordBatch.startBuffersVector(builder, buffers.size()); + int buffersOffset = FBSerializables.writeAllStructsToVector(builder, buffers.stream() + .map(buf -> (FBSerializable) b -> Buffer.createBuffer(b, buf.memoryAddress(), + buf.getReferenceManager().getSize())) + .collect(Collectors.toList())); + int compressOffset = 0; + if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) { + compressOffset = bodyCompression.writeTo(builder); + } + RecordBatch.startRecordBatch(builder); + RecordBatch.addLength(builder, length); + RecordBatch.addNodes(builder, nodesOffset); + RecordBatch.addBuffers(builder, buffersOffset); + if (bodyCompression.getCodec() != NoCompressionCodec.COMPRESSION_TYPE) { + RecordBatch.addCompression(builder, compressOffset); + } + return RecordBatch.endRecordBatch(builder); + } + + @Override + public void close() throws Exception { + delegated.close(); + } + } +} diff --git a/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java b/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java index 963fb617040..5f3095910cc 100644 --- a/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java +++ b/java/dataset/src/main/java/org/apache/arrow/memory/NativeUnderlyingMemory.java @@ -25,7 +25,7 @@ public class NativeUnderlyingMemory extends AllocationManager { private final int size; - private final long nativeInstanceId; + private final long nativeBufferId; private final long address; /** @@ -33,12 +33,12 @@ public class NativeUnderlyingMemory extends AllocationManager { * * @param accountingAllocator The accounting allocator instance * @param size Size of underlying memory (in bytes) - * @param nativeInstanceId ID of the native instance + * @param nativeBufferId ID of the native instance */ - NativeUnderlyingMemory(BufferAllocator accountingAllocator, int size, long nativeInstanceId, long address) { + NativeUnderlyingMemory(BufferAllocator accountingAllocator, int size, long nativeBufferId, long address) { super(accountingAllocator); this.size = size; - this.nativeInstanceId = nativeInstanceId; + this.nativeBufferId = nativeBufferId; this.address = address; // pre-allocate bytes on accounting allocator final AllocationListener listener = accountingAllocator.getListener(); @@ -55,9 +55,9 @@ public class NativeUnderlyingMemory extends AllocationManager { /** * Alias to constructor. 
*/ - public static NativeUnderlyingMemory create(BufferAllocator bufferAllocator, int size, long nativeInstanceId, + public static NativeUnderlyingMemory create(BufferAllocator bufferAllocator, int size, long nativeBufferId, long address) { - return new NativeUnderlyingMemory(bufferAllocator, size, nativeInstanceId, address); + return new NativeUnderlyingMemory(bufferAllocator, size, nativeBufferId, address); } public BufferLedger associate(BufferAllocator allocator) { @@ -66,7 +66,7 @@ public BufferLedger associate(BufferAllocator allocator) { @Override protected void release0() { - JniWrapper.get().releaseBuffer(nativeInstanceId); + JniWrapper.get().releaseBuffer(nativeBufferId); } @Override diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java b/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java index c6299d135a0..da48da6c0db 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/ParquetWriteSupport.java @@ -47,7 +47,7 @@ public class ParquetWriteSupport implements AutoCloseable { public ParquetWriteSupport(String schemaName, File outputFolder) throws Exception { avroSchema = readSchemaFromFile(schemaName); path = outputFolder.getPath() + File.separator + "generated.parquet"; - uri = "file://" + path; + uri = new File(path).toURI().toString(); writer = AvroParquetWriter.builder(new org.apache.hadoop.fs.Path(path)) .withSchema(avroSchema) .build(); diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java index 063f0955925..07d5a754f9b 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java @@ -81,6 +81,7 @@ public void testParquetRead() throws Exception { checkParquetReadResult(schema, writeSupport.getWrittenRecords(), datum); AutoCloseables.close(datum); + AutoCloseables.close(factory); } @Test @@ -107,6 +108,7 @@ public void testParquetProjector() throws Exception { .build()), datum); AutoCloseables.close(datum); + AutoCloseables.close(factory); } @Test @@ -126,6 +128,7 @@ public void testParquetBatchSize() throws Exception { checkParquetReadResult(schema, writeSupport.getWrittenRecords(), datum); AutoCloseables.close(datum); + AutoCloseables.close(factory); } @Test @@ -140,6 +143,8 @@ public void testCloseAgain() throws Exception { dataset.close(); dataset.close(); }); + + AutoCloseables.close(factory); } @Test @@ -197,6 +202,7 @@ public void testScanAfterClose1() throws Exception { NativeScanner scanner = dataset.newScan(options); scanner.close(); assertThrows(NativeInstanceReleasedException.class, scanner::scan); + AutoCloseables.close(factory); } @Test @@ -212,6 +218,7 @@ public void testScanAfterClose2() throws Exception { NativeScanTask task = tasks.get(0); task.close(); assertThrows(NativeInstanceReleasedException.class, task::execute); + AutoCloseables.close(factory); } @Test @@ -228,6 +235,7 @@ public void testScanAfterClose3() throws Exception { ScanTask.BatchIterator iterator = task.execute(); task.close(); assertThrows(NativeInstanceReleasedException.class, iterator::hasNext); + AutoCloseables.close(factory); } @Test @@ -247,6 +255,7 @@ public void testMemoryAllocation() throws Exception { long finalReservation = rootAllocator().getAllocatedMemory(); 
Assert.assertEquals(expected_diff, reservation - initReservation); Assert.assertEquals(-expected_diff, finalReservation - reservation); + AutoCloseables.close(factory); } private void checkParquetReadResult(Schema schema, List expected, List actual) { diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java new file mode 100644 index 00000000000..2613d0ea4e1 --- /dev/null +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.jni; + +import static java.util.Arrays.asList; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.SchemaUtility; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestUnsafeRecordBatchSerializer { + + private RootAllocator allocator = null; + + @Before + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void tearDown() { + allocator.close(); + } + + private ArrowBuf buf(byte[] bytes) { + ArrowBuf buffer = allocator.buffer(bytes.length); + buffer.writeBytes(bytes); + return buffer; + } + + private ArrowBuf intBuf(int[] ints) { + ArrowBuf buffer = allocator.buffer(ints.length * 4L); + for (int i : ints) { + buffer.writeInt(i); + } + return buffer; + } + + private static Field field(String name, boolean nullable, ArrowType type, Field... 
children) {
+    return new Field(name, new FieldType(nullable, type, null, null), asList(children));
+  }
+
+  @Test
+  public void testRoundTrip() throws IOException {
+    final byte[] aValidity = new byte[]{(byte) 255, 0};
+    final byte[] bValidity = new byte[]{(byte) 255, 0};
+    // second half is "undefined"
+    final int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+    final int[] bValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
+    final Integer[] aExpected =
+        new Integer[]{1, 2, 3, 4, 5, 6, 7, 8, null, null, null, null, null, null, null, null};
+    final Integer[] bExpected =
+        new Integer[]{16, 15, 14, 13, 12, 11, 10, 9, null, null, null, null, null, null, null, null};
+    ArrowBuf validitya = buf(aValidity);
+    ArrowBuf valuesa = intBuf(aValues);
+    ArrowBuf validityb = buf(bValidity);
+    ArrowBuf valuesb = intBuf(bValues);
+    final ArrowRecordBatch batch = new ArrowRecordBatch(
+        16, Arrays.asList(new ArrowFieldNode(16, 8), new ArrowFieldNode(16, 8)),
+        Arrays.asList(validitya, valuesa, validityb, valuesb));
+    final Schema schema = new org.apache.arrow.vector.types.pojo.Schema(asList(
+        field("a", true, new ArrowType.Int(32, true)),
+        field("b", true, new ArrowType.Int(32, true)))
+    );
+
+    byte[] reexported = JniWrapper.get().reexportUnsafeSerializedBatch(SchemaUtility.serialize(schema),
+        UnsafeRecordBatchSerializer.serializeUnsafe(batch));
+    final ArrowRecordBatch reexportedBatch = UnsafeRecordBatchSerializer.deserializeUnsafe(allocator, reexported);
+
+    VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+    VectorLoader loader = new VectorLoader(root);
+    loader.load(reexportedBatch);
+
+    final List<FieldVector> fieldVectors = root.getFieldVectors();
+
+    // check data
+    Assert.assertEquals(2, fieldVectors.size());
+    FieldVector aVector = fieldVectors.get(0);
+    FieldVector bVector = fieldVectors.get(1);
+    Assert.assertEquals(Types.MinorType.INT, aVector.getMinorType());
+    Assert.assertEquals(Types.MinorType.INT, bVector.getMinorType());
+    IntVector aIntVector = (IntVector) aVector;
+    IntVector bIntVector = (IntVector) bVector;
+    Assert.assertEquals(8, aIntVector.getNullCount());
+    Assert.assertEquals(8, bIntVector.getNullCount());
+    for (int i = 0; i < 16; i++) {
+      Assert.assertEquals(aExpected[i], aIntVector.getObject(i));
+      Assert.assertEquals(bExpected[i], bIntVector.getObject(i));
+    }
+    // check memory release
+    root.close();
+    reexportedBatch.close();
+    batch.close();
+  }
+}
From 28742ab00d163c5c0cfb999c5fe7a8c9bfa17b68 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Mon, 28 Jun 2021 11:10:48 +0800
Subject: [PATCH 2/8] style

---
 cpp/src/jni/dataset/jni_util.cc               |  3 ++-
 java/dataset/CMakeLists.txt                   | 21 +++++++++----------
 .../jni/UnsafeRecordBatchSerializer.java      |  2 --
 .../jni/TestUnsafeRecordBatchSerializer.java  |  1 -
 4 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/cpp/src/jni/dataset/jni_util.cc b/cpp/src/jni/dataset/jni_util.cc
index 2bd533810d1..7d74009ff30 100644
--- a/cpp/src/jni/dataset/jni_util.cc
+++ b/cpp/src/jni/dataset/jni_util.cc
@@ -262,7 +262,8 @@ Result<jbyteArray> ToSchemaByteArray(JNIEnv* env, std::shared_ptr<Schema> schema
   return out;
 }
 
-Result<std::shared_ptr<Schema>> FromSchemaByteArray(JNIEnv* env, jbyteArray schema_bytes) {
+Result<std::shared_ptr<Schema>> FromSchemaByteArray(JNIEnv* env,
+                                                    jbyteArray schema_bytes) {
   ipc::DictionaryMemo in_memo;
   int schemaBytes_len = env->GetArrayLength(schema_bytes);
   jbyte* schemaBytes_data = env->GetByteArrayElements(schema_bytes, nullptr);
diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt
index 
75331adee63..5b6e4a9ce24 100644 --- a/java/dataset/CMakeLists.txt +++ b/java/dataset/CMakeLists.txt @@ -30,14 +30,13 @@ include(FindJNI) message("generating headers to ${JNI_HEADERS_DIR}") -add_jar( - arrow_dataset_java - src/main/java/org/apache/arrow/dataset/jni/JniLoader.java - src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java - src/main/java/org/apache/arrow/dataset/file/JniWrapper.java - src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java - src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java - GENERATE_NATIVE_HEADERS - arrow_dataset_java-native - DESTINATION - ${JNI_HEADERS_DIR}) +add_jar(arrow_dataset_java + src/main/java/org/apache/arrow/dataset/jni/JniLoader.java + src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java + src/main/java/org/apache/arrow/dataset/file/JniWrapper.java + src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java + src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java + GENERATE_NATIVE_HEADERS + arrow_dataset_java-native + DESTINATION + ${JNI_HEADERS_DIR}) diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java index 069905ac7de..079801b0de5 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/UnsafeRecordBatchSerializer.java @@ -21,10 +21,8 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.nio.channels.Channels; import java.util.ArrayList; -import java.util.Base64; import java.util.List; import java.util.stream.Collectors; diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java index 2613d0ea4e1..8b1cda684bf 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestUnsafeRecordBatchSerializer.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Collections; import java.util.List; import org.apache.arrow.memory.ArrowBuf; From 9ac17b33a02ca12d046bc42a9976edc125ff5f8b Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 7 Apr 2021 18:15:51 +0800 Subject: [PATCH 3/8] ARROW-11776: [Java][Dataset] Support writing to files within dataset scanner via JNI --- cpp/src/jni/dataset/jni_wrapper.cc | 171 ++++++++++++++++++ java/dataset/CMakeLists.txt | 2 + .../arrow/dataset/file/DatasetFileWriter.java | 74 ++++++++ .../apache/arrow/dataset/file/JniWrapper.java | 16 ++ .../dataset/file/NativeScannerAdaptor.java | 35 ++++ .../file/NativeScannerAdaptorImpl.java | 121 +++++++++++++ .../NativeSerializedRecordBatchIterator.java | 35 ++++ .../dataset/file/TestDatasetFileWriter.java | 125 +++++++++++++ 8 files changed, 579 insertions(+) create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java create mode 100644 
java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java

diff --git a/cpp/src/jni/dataset/jni_wrapper.cc b/cpp/src/jni/dataset/jni_wrapper.cc
index 20af823b8e2..9525439a804 100644
--- a/cpp/src/jni/dataset/jni_wrapper.cc
+++ b/cpp/src/jni/dataset/jni_wrapper.cc
@@ -19,6 +19,7 @@
 
 #include "arrow/array.h"
 #include "arrow/dataset/api.h"
+#include "arrow/dataset/dataset_internal.h"
 #include "arrow/dataset/file_base.h"
 #include "arrow/filesystem/localfs.h"
 #include "arrow/ipc/api.h"
@@ -36,8 +37,11 @@ jclass illegal_access_exception_class;
 jclass illegal_argument_exception_class;
 jclass runtime_exception_class;
 
+jclass serialized_record_batch_iterator_class;
 jclass java_reservation_listener_class;
 
+jmethodID serialized_record_batch_iterator_hasNext;
+jmethodID serialized_record_batch_iterator_next;
 jmethodID reserve_memory_method;
 jmethodID unreserve_memory_method;
 
@@ -152,6 +156,126 @@ class DisposableScannerAdaptor {
   }
 };
 
+/// \brief Simple scan task implementation that is constructed directly
+/// from a record batch iterator (and its corresponding fragment).
+class SimpleIteratorTask : public arrow::dataset::ScanTask {
+ public:
+  SimpleIteratorTask(std::shared_ptr<arrow::dataset::ScanOptions> options,
+                     std::shared_ptr<arrow::dataset::Fragment> fragment,
+                     arrow::RecordBatchIterator itr)
+      : ScanTask(options, fragment) {
+    this->itr_ = std::move(itr);
+  }
+
+  static arrow::Result<std::shared_ptr<SimpleIteratorTask>> Make(
+      arrow::RecordBatchIterator itr,
+      std::shared_ptr<arrow::dataset::ScanOptions> options,
+      std::shared_ptr<arrow::dataset::Fragment> fragment) {
+    return std::make_shared<SimpleIteratorTask>(options, fragment, std::move(itr));
+  }
+
+  arrow::Result<arrow::RecordBatchIterator> Execute() override {
+    if (used_) {
+      return arrow::Status::Invalid(
+          "SimpleIteratorTask is disposable and "
+          "already scanned");
+    }
+    used_ = true;
+    return std::move(itr_);
+  }
+
+ private:
+  arrow::RecordBatchIterator itr_;
+  bool used_ = false;
+};
+
+/// \brief Simple fragment implementation that is constructed directly
+/// from a record batch iterator.
+class SimpleIteratorFragment : public arrow::dataset::Fragment {
+ public:
+  explicit SimpleIteratorFragment(arrow::RecordBatchIterator itr)
+      : arrow::dataset::Fragment() {
+    itr_ = std::move(itr);
+  }
+
+  static arrow::Result<std::shared_ptr<SimpleIteratorFragment>> Make(
+      arrow::RecordBatchIterator itr) {
+    return std::make_shared<SimpleIteratorFragment>(std::move(itr));
+  }
+
+  arrow::Result<arrow::dataset::RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<arrow::dataset::ScanOptions>& options) override {
+    return arrow::Status::NotImplemented("Async scan not supported");
+  }
+
+  arrow::Result<arrow::dataset::ScanTaskIterator> Scan(
+      std::shared_ptr<arrow::dataset::ScanOptions> options) override {
+    if (used_) {
+      return arrow::Status::Invalid(
+          "SimpleIteratorFragment is disposable and "
+          "already scanned");
+    }
+    used_ = true;
+    ARROW_ASSIGN_OR_RAISE(
+        auto task, SimpleIteratorTask::Make(std::move(itr_), options, shared_from_this()))
+    return arrow::MakeVectorIterator<std::shared_ptr<arrow::dataset::ScanTask>>({task});
+  }
+
+  std::string type_name() const override { return "simple_iterator"; }
+
+  arrow::Result<std::shared_ptr<arrow::Schema>> ReadPhysicalSchemaImpl() override {
+    return arrow::Status::NotImplemented("No physical schema is readable");
+  }
+
+ private:
+  arrow::RecordBatchIterator itr_;
+  bool used_ = false;
+};
+
+arrow::Result<std::shared_ptr<arrow::RecordBatch>> FromBytes(
+    JNIEnv* env, std::shared_ptr<arrow::Schema> schema, jbyteArray bytes) {
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<arrow::RecordBatch> batch,
+      arrow::dataset::jni::DeserializeUnsafeFromJava(env, schema, bytes))
+  return batch;
+}
+
+/// \brief Create scanner that scans over Java dataset API's components.
+///
+/// Currently we use a NativeSerializedRecordBatchIterator as the underlying
+/// Java object to do scanning, which means only a single task will
+/// be produced from C++ code.
+arrow::Result<std::shared_ptr<arrow::dataset::Scanner>> MakeJavaDatasetScanner(
+    JavaVM* vm, jobject java_serialized_record_batch_iterator,
+    std::shared_ptr<arrow::Schema> schema) {
+  arrow::RecordBatchIterator itr = arrow::MakeFunctionIterator(
+      [vm, java_serialized_record_batch_iterator,
+       schema]() -> arrow::Result<std::shared_ptr<arrow::RecordBatch>> {
+        JNIEnv* env;
+        if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+          return arrow::Status::Invalid("JNIEnv was not attached to current thread");
+        }
+        if (!env->CallBooleanMethod(java_serialized_record_batch_iterator,
+                                    serialized_record_batch_iterator_hasNext)) {
+          return nullptr;  // stream ended
+        }
+        auto bytes = (jbyteArray)env->CallObjectMethod(
+            java_serialized_record_batch_iterator, serialized_record_batch_iterator_next);
+        RETURN_NOT_OK(arrow::dataset::jni::CheckException(env));
+        ARROW_ASSIGN_OR_RAISE(auto batch, FromBytes(env, schema, bytes));
+        return batch;
+      });
+
+  ARROW_ASSIGN_OR_RAISE(auto fragment, SimpleIteratorFragment::Make(std::move(itr)))
+
+  arrow::dataset::ScannerBuilder scanner_builder(
+      std::move(schema), fragment, std::make_shared<arrow::dataset::ScanOptions>());
+  // Using the default memory pool is enough, as native allocation is ideally
+  // not invoked while scanning Java-based fragments.
+  RETURN_NOT_OK(scanner_builder.Pool(arrow::default_memory_pool()));
+  return scanner_builder.Finish();
+}
+
 }  // namespace
 
 using arrow::dataset::jni::CreateGlobalClassReference;
@@ -191,10 +315,19 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
   runtime_exception_class =
       CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;");
 
+  serialized_record_batch_iterator_class =
+      CreateGlobalClassReference(env,
+                                 "Lorg/apache/arrow/"
+                                 "dataset/jni/NativeSerializedRecordBatchIterator;");
   java_reservation_listener_class =
      CreateGlobalClassReference(env,
                                 "Lorg/apache/arrow/"
                                 "dataset/jni/ReservationListener;");
+
+  serialized_record_batch_iterator_hasNext = JniGetOrThrow(
+      GetMethodID(env, serialized_record_batch_iterator_class, "hasNext", "()Z"));
+  serialized_record_batch_iterator_next = JniGetOrThrow(
+      GetMethodID(env, serialized_record_batch_iterator_class, "next", "()[B"));
   reserve_memory_method =
       JniGetOrThrow(GetMethodID(env, java_reservation_listener_class, "reserve", "(J)V"));
   unreserve_memory_method = JniGetOrThrow(
@@ -212,6 +345,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
   env->DeleteGlobalRef(illegal_access_exception_class);
   env->DeleteGlobalRef(illegal_argument_exception_class);
   env->DeleteGlobalRef(runtime_exception_class);
+  env->DeleteGlobalRef(serialized_record_batch_iterator_class);
   env->DeleteGlobalRef(java_reservation_listener_class);
 
   default_memory_pool_id = -1L;
@@ -513,3 +647,40 @@ Java_org_apache_arrow_dataset_jni_JniWrapper_reexportUnsafeSerializedBatch(
   return JniGetOrThrow(arrow::dataset::jni::SerializeUnsafeFromNative(env, batch));
   JNI_METHOD_END(nullptr)
 }
+
+/*
+ * Class:     org_apache_arrow_dataset_file_JniWrapper
+ * Method:    writeFromScannerToFile
+ * Signature:
+ * (Lorg/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator;[BILjava/lang/String;[Ljava/lang/String;ILjava/lang/String;)V
+ */
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_dataset_file_JniWrapper_writeFromScannerToFile(
+    JNIEnv* env, jobject, jobject itr, jbyteArray schema_bytes, jint file_format_id,
+    jstring uri, jobjectArray partition_columns, jint max_partitions,
+    jstring base_name_template) {
+  JNI_METHOD_START
+    JavaVM* vm;
+    if (env->GetJavaVM(&vm) != JNI_OK) {
+      JniThrow("Unable to get JavaVM instance");
+    }
+    auto schema = JniGetOrThrow(FromSchemaByteArray(env, schema_bytes));
+    auto scanner = JniGetOrThrow(MakeJavaDatasetScanner(vm, itr, schema));
+    std::shared_ptr<arrow::dataset::FileFormat> file_format =
+        JniGetOrThrow(GetFileFormat(file_format_id));
+    arrow::dataset::FileSystemDatasetWriteOptions options;
+    std::string output_path;
+    auto filesystem = JniGetOrThrow(
+        arrow::fs::FileSystemFromUri(JStringToCString(env, uri), &output_path));
+    std::vector<std::string> partition_column_vector =
+        ToStringVector(env, partition_columns);
+    options.file_write_options = file_format->DefaultWriteOptions();
+    options.filesystem = filesystem;
+    options.base_dir = output_path;
+    options.basename_template = JStringToCString(env, base_name_template);
+    options.partitioning = std::make_shared<arrow::dataset::HivePartitioning>(
+        arrow::dataset::SchemaFromColumnNames(schema, partition_column_vector));
+    options.max_partitions = max_partitions;
+    JniAssertOkOrThrow(arrow::dataset::FileSystemDataset::Write(options, scanner));
+  JNI_METHOD_END()
+}
diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt
index 5b6e4a9ce24..b6d1ac7a604 100644
--- a/java/dataset/CMakeLists.txt
+++ b/java/dataset/CMakeLists.txt
@@ -36,6 +36,8 @@ add_jar(arrow_dataset_java
         src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
         src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java
         src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java
+        src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java
+        src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java
         GENERATE_NATIVE_HEADERS
         arrow_dataset_java-native
         DESTINATION
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java
new file mode 100644
index 00000000000..ea85b5d0efb
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.file;
+
+import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator;
+import org.apache.arrow.dataset.scanner.Scanner;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.util.SchemaUtility;
+
+/**
+ * JNI-based utility to write datasets into files. It internally depends on the C++ static method
+ * FileSystemDataset::Write.
+ */
+public class DatasetFileWriter {
+
+  /**
+   * Scan over an input {@link Scanner} then write all record batches to file.
+   *
+   * @param scanner the source scanner for writing
+   * @param format target file format
+   * @param uri target file uri
+   * @param partitionColumns columns used to partition output files. Empty to disable partitioning
+   * @param maxPartitions maximum partitions to be included in written files
+   * @param baseNameTemplate file name template used to make partitions, e.g. "dat_{i}", where
+   *                         {i} is replaced with the partition ID across all written files
+   */
+  public static void write(Scanner scanner, FileFormat format, String uri,
+                           String[] partitionColumns, int maxPartitions, String baseNameTemplate) {
+    final NativeScannerAdaptorImpl adaptor = new NativeScannerAdaptorImpl(scanner);
+    final NativeSerializedRecordBatchIterator itr = adaptor.scan();
+    RuntimeException throwableWrapper = null;
+    try {
+      JniWrapper.get().writeFromScannerToFile(itr, SchemaUtility.serialize(scanner.schema()),
+          format.id(), uri, partitionColumns, maxPartitions, baseNameTemplate);
+    } catch (Throwable t) {
+      throwableWrapper = new RuntimeException(t);
+      throw throwableWrapper;
+    } finally {
+      try {
+        AutoCloseables.close(itr);
+      } catch (Exception e) {
+        if (throwableWrapper != null) {
+          throwableWrapper.addSuppressed(e);
+        } else {
+          // don't swallow close failures when the write itself succeeded
+          throw new RuntimeException(e);
+        }
+      }
+    }
+  }
+
+  /**
+   * Scan over an input {@link Scanner} then write all record batches to file, with default partitioning settings.
+   *
+   * @param scanner the source scanner for writing
+   * @param format target file format
+   * @param uri target file uri
+   */
+  public static void write(Scanner scanner, FileFormat format, String uri) {
+    write(scanner, format, uri, new String[0], 1024, "dat_{i}");
+  }
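+
+  // A usage sketch under assumed setup: "scanner" comes from Dataset#newScan(),
+  // and the target URI and partition columns are illustrative only:
+  //
+  //   DatasetFileWriter.write(scanner, FileFormat.PARQUET, "file:///tmp/out",
+  //       new String[] {"id"}, 100, "dat_{i}");
+  //
+  // or, with the default partitioning settings:
+  //
+  //   DatasetFileWriter.write(scanner, FileFormat.PARQUET, "file:///tmp/out");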
+}
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
index cc61f994ab0..955b8344f1e 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
@@ -62,4 +62,20 @@ public native long newJniMethodReference(String classSignature, String methodNam
    */
   public native long makeFileSystemDatasetFactory(String uri, int fileFormat);
 
+  /**
+   * Write all record batches in a {@link NativeSerializedRecordBatchIterator} into files. This internally
+   * depends on the C++ write API: FileSystemDataset::Write.
+   *
+   * @param itr iterator to be used for writing
+   * @param schema serialized schema of output files
+   * @param fileFormat target file format (ID)
+   * @param uri target file uri
+   * @param partitionColumns columns used to partition output files
+   * @param maxPartitions maximum partitions to be included in written files
+   * @param baseNameTemplate file name template used to make partitions, e.g. "dat_{i}", where
+   *                         {i} is replaced with the partition ID across all written files
+   */
+  public native void writeFromScannerToFile(NativeSerializedRecordBatchIterator itr, byte[] schema,
+      int fileFormat, String uri, String[] partitionColumns, int maxPartitions, String baseNameTemplate);
+
 }
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java
new file mode 100644
index 00000000000..b30ab2cafbf
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptor.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.file; + +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; + +/** + * A short path comparing to {@link org.apache.arrow.dataset.scanner.Scanner} for being called from C++ side + * via JNI, to minimize JNI call overhead. + */ +public interface NativeScannerAdaptor { + + /** + * Scan with the delegated scanner. + * + * @return a iterator outputting JNI-friendly flatbuffers-serialized + * {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch} instances. + */ + NativeSerializedRecordBatchIterator scan(); +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java new file mode 100644 index 00000000000..8fe9421f7e7 --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.file; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator; +import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer; +import org.apache.arrow.dataset.scanner.ScanTask; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; + +/** + * Default implementation of {@link NativeScannerAdaptor}. + */ +public class NativeScannerAdaptorImpl implements NativeScannerAdaptor, AutoCloseable { + + private final Scanner scanner; + + /** + * Constructor. + * + * @param scanner the delegated scanner. 
+ */ + public NativeScannerAdaptorImpl(Scanner scanner) { + this.scanner = scanner; + } + + @Override + public NativeSerializedRecordBatchIterator scan() { + final Iterable tasks = scanner.scan(); + return new IteratorImpl(tasks); + } + + @Override + public void close() throws Exception { + scanner.close(); + } + + private static class IteratorImpl implements NativeSerializedRecordBatchIterator { + + private final Iterator taskIterator; + + private ScanTask currentTask = null; + private ScanTask.BatchIterator currentBatchIterator = null; + + public IteratorImpl(Iterable tasks) { + this.taskIterator = tasks.iterator(); + } + + @Override + public void close() throws Exception { + closeCurrent(); + } + + private void closeCurrent() throws Exception { + if (currentTask == null) { + return; + } + currentTask.close(); + currentBatchIterator.close(); + } + + private boolean advance() { + if (!taskIterator.hasNext()) { + return false; + } + try { + closeCurrent(); + } catch (Exception e) { + throw new RuntimeException(e); + } + currentTask = taskIterator.next(); + currentBatchIterator = currentTask.execute(); + return true; + } + + @Override + public boolean hasNext() { + if (currentTask == null) { + if (!advance()) { + return false; + } + } + if (!currentBatchIterator.hasNext()) { + if (!advance()) { + return false; + } + } + return currentBatchIterator.hasNext(); + } + + @Override + public byte[] next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return serialize(currentBatchIterator.next()); + } + + private byte[] serialize(ArrowRecordBatch batch) { + return UnsafeRecordBatchSerializer.serializeUnsafe(batch); + } + } +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java new file mode 100644 index 00000000000..5495f85cf1c --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeSerializedRecordBatchIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.jni; + +import java.util.Iterator; + +/** + * Iterate on flatbuffers-serialized {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch}. + *

+ * {@link #next()} should be called from the C++ scanner to read Java-generated Arrow data.
+ */
+public interface NativeSerializedRecordBatchIterator extends Iterator<byte[]>, AutoCloseable {
+
+  /**
+   * Return next serialized {@link org.apache.arrow.vector.ipc.message.ArrowRecordBatch} Java
+   * byte array.
+   */
+  @Override
+  byte[] next();
+}
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java
new file mode 100644
index 00000000000..7b84218dfa0
--- /dev/null
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.file;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.dataset.ParquetWriteSupport;
+import org.apache.arrow.dataset.TestDataset;
+import org.apache.arrow.dataset.jni.NativeMemoryPool;
+import org.apache.arrow.dataset.scanner.ScanOptions;
+import org.apache.arrow.dataset.scanner.Scanner;
+import org.apache.arrow.dataset.source.Dataset;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.commons.io.FileUtils;
+import org.junit.Assert;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class TestDatasetFileWriter extends TestDataset {
+
+  @ClassRule
+  public static final TemporaryFolder TMP = new TemporaryFolder();
+
+  public static final String AVRO_SCHEMA_USER = "user.avsc";
+
+  @Test
+  public void testParquetWriteSimple() throws Exception {
+    ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(),
+        1, "a", 2, "b", 3, "c", 2, "d");
+    String sampleParquet = writeSupport.getOutputURI();
+    FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(),
+        FileFormat.PARQUET, sampleParquet);
+    ScanOptions options = new ScanOptions(new String[0], 100);
+    final Dataset dataset = factory.finish();
+    final Scanner scanner = dataset.newScan(options);
+    final File writtenFolder = TMP.newFolder();
+    final String writtenParquet = writtenFolder.toURI().toString();
+    try {
+      DatasetFileWriter.write(scanner, FileFormat.PARQUET, writtenParquet);
+      assertParquetFileEquals(sampleParquet,
Objects.requireNonNull(writtenFolder.listFiles())[0].toURI().toString());
+    } finally {
+      AutoCloseables.close(factory, scanner, dataset);
+    }
+  }
+
+  @Test
+  public void testParquetWriteWithPartitions() throws Exception {
+    ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(),
+        1, "a", 2, "b", 3, "c", 2, "d");
+    String sampleParquet = writeSupport.getOutputURI();
+    FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(),
+        FileFormat.PARQUET, sampleParquet);
+    ScanOptions options = new ScanOptions(new String[0], 100);
+    final Dataset dataset = factory.finish();
+    final Scanner scanner = dataset.newScan(options);
+    final File writtenFolder = TMP.newFolder();
+    final String writtenParquet = writtenFolder.toURI().toString();
+    try {
+      DatasetFileWriter.write(scanner, FileFormat.PARQUET, writtenParquet, new String[]{"id", "name"}, 100, "dat_{i}");
+      final Set<String> expectedOutputFiles = new HashSet<>(
+          Arrays.asList("id=1/name=a/dat_0", "id=2/name=b/dat_1", "id=3/name=c/dat_2", "id=2/name=d/dat_3"));
+      final Set<String> outputFiles = FileUtils.listFiles(writtenFolder, null, true)
+          .stream()
+          .map(file -> {
+            return writtenFolder.toURI().relativize(file.toURI()).toString();
+          })
+          .collect(Collectors.toSet());
+      Assert.assertEquals(expectedOutputFiles, outputFiles);
+    } finally {
+      AutoCloseables.close(factory, scanner, dataset);
+    }
+  }
+
+  private void assertParquetFileEquals(String expectedURI, String actualURI) throws Exception {
+    final FileSystemDatasetFactory expectedFactory = new FileSystemDatasetFactory(
+        rootAllocator(), NativeMemoryPool.getDefault(), FileFormat.PARQUET, expectedURI);
+    List<ArrowRecordBatch> expectedBatches = collectResultFromFactory(expectedFactory,
+        new ScanOptions(new String[0], 100));
+    final FileSystemDatasetFactory actualFactory = new FileSystemDatasetFactory(
+        rootAllocator(), NativeMemoryPool.getDefault(), FileFormat.PARQUET, actualURI);
+    List<ArrowRecordBatch> actualBatches = collectResultFromFactory(actualFactory,
+        new ScanOptions(new String[0], 100));
+    // fast-fail by comparing metadata
+    Assert.assertEquals(expectedBatches.toString(), actualBatches.toString());
+    // compare buffers
+    Assert.assertEquals(serialize(expectedBatches), serialize(actualBatches));
+    AutoCloseables.close(expectedBatches, actualBatches);
+  }
+
+  private String serialize(List<ArrowRecordBatch> batches) throws IOException {
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    for (ArrowRecordBatch batch : batches) {
+      MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch);
+    }
+    return Arrays.toString(out.toByteArray());
+  }
+}
From 7bce23a67257a17a9f803bdebd0519ff140aadf6 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Sat, 22 May 2021 14:03:04 +0800
Subject: [PATCH 4/8] Use ScannerBuilder::FromRecordBatchReader

---
 cpp/src/arrow/record_batch.cc      |  5 ++
 cpp/src/arrow/record_batch.h       |  7 +++
 cpp/src/jni/dataset/jni_wrapper.cc | 88 ++----------------------------
 3 files changed, 18 insertions(+), 82 deletions(-)

diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc
index 66f9e932b58..1b227cbc9c1 100644
--- a/cpp/src/arrow/record_batch.cc
+++ b/cpp/src/arrow/record_batch.cc
@@ -364,4 +364,9 @@ Result<std::shared_ptr<RecordBatchReader>> RecordBatchReader::Make(
   return std::make_shared<SimpleRecordBatchReader>(std::move(batches), schema);
 }
 
+Result<std::shared_ptr<RecordBatchReader>> RecordBatchReader::Make(
+    Iterator<std::shared_ptr<RecordBatch>> itr, std::shared_ptr<Schema> schema) {
+  return std::make_shared<SimpleRecordBatchReader>(std::move(itr), schema);
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/record_batch.h 
b/cpp/src/arrow/record_batch.h
index 3dc1f54a083..dc8bedee11f 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -233,6 +233,13 @@ class ARROW_EXPORT RecordBatchReader {
   /// element if not provided.
   static Result<std::shared_ptr<RecordBatchReader>> Make(
       RecordBatchVector batches, std::shared_ptr<Schema> schema = NULLPTR);
+
+  /// \brief Create a RecordBatchReader from a RecordBatchIterator.
+  ///
+  /// \param[in] itr the iterator of RecordBatch to read from
+  /// \param[in] schema schema to conform to
+  static Result<std::shared_ptr<RecordBatchReader>> Make(RecordBatchIterator itr,
+                                                         std::shared_ptr<Schema> schema);
 };
 
 }  // namespace arrow
diff --git a/cpp/src/jni/dataset/jni_wrapper.cc b/cpp/src/jni/dataset/jni_wrapper.cc
index 9525439a804..b96753ed966 100644
--- a/cpp/src/jni/dataset/jni_wrapper.cc
+++ b/cpp/src/jni/dataset/jni_wrapper.cc
@@ -156,82 +156,6 @@ class DisposableScannerAdaptor {
   }
 };
 
-/// \brief Simple scan task implementation that is constructed directly
-/// from a record batch iterator (and its corresponding fragment).
-class SimpleIteratorTask : public arrow::dataset::ScanTask {
- public:
-  SimpleIteratorTask(std::shared_ptr<arrow::dataset::ScanOptions> options,
-                     std::shared_ptr<arrow::dataset::Fragment> fragment,
-                     arrow::RecordBatchIterator itr)
-      : ScanTask(options, fragment) {
-    this->itr_ = std::move(itr);
-  }
-
-  static arrow::Result<std::shared_ptr<SimpleIteratorTask>> Make(
-      arrow::RecordBatchIterator itr,
-      std::shared_ptr<arrow::dataset::ScanOptions> options,
-      std::shared_ptr<arrow::dataset::Fragment> fragment) {
-    return std::make_shared<SimpleIteratorTask>(options, fragment, std::move(itr));
-  }
-
-  arrow::Result<arrow::RecordBatchIterator> Execute() override {
-    if (used_) {
-      return arrow::Status::Invalid(
-          "SimpleIteratorTask is disposable and "
-          "already scanned");
-    }
-    used_ = true;
-    return std::move(itr_);
-  }
-
- private:
-  arrow::RecordBatchIterator itr_;
-  bool used_ = false;
-};
-
-/// \brief Simple fragment implementation that is constructed directly
-/// from a record batch iterator.
-class SimpleIteratorFragment : public arrow::dataset::Fragment {
- public:
-  explicit SimpleIteratorFragment(arrow::RecordBatchIterator itr)
-      : arrow::dataset::Fragment() {
-    itr_ = std::move(itr);
-  }
-
-  static arrow::Result<std::shared_ptr<SimpleIteratorFragment>> Make(
-      arrow::RecordBatchIterator itr) {
-    return std::make_shared<SimpleIteratorFragment>(std::move(itr));
-  }
-
-  arrow::Result<arrow::dataset::RecordBatchGenerator> ScanBatchesAsync(
-      const std::shared_ptr<arrow::dataset::ScanOptions>& options) override {
-    return arrow::Status::NotImplemented("Async scan not supported");
-  }
-
-  arrow::Result<arrow::dataset::ScanTaskIterator> Scan(
-      std::shared_ptr<arrow::dataset::ScanOptions> options) override {
-    if (used_) {
-      return arrow::Status::Invalid(
-          "SimpleIteratorFragment is disposable and "
-          "already scanned");
-    }
-    used_ = true;
-    ARROW_ASSIGN_OR_RAISE(
-        auto task, SimpleIteratorTask::Make(std::move(itr_), options, shared_from_this()))
-    return arrow::MakeVectorIterator<std::shared_ptr<arrow::dataset::ScanTask>>({task});
-  }
-
-  std::string type_name() const override { return "simple_iterator"; }
-
-  arrow::Result<std::shared_ptr<arrow::Schema>> ReadPhysicalSchemaImpl() override {
-    return arrow::Status::NotImplemented("No physical schema is readable");
-  }
-
- private:
-  arrow::RecordBatchIterator itr_;
-  bool used_ = false;
-};
-
 arrow::Result<std::shared_ptr<arrow::RecordBatch>> FromBytes(
     JNIEnv* env, std::shared_ptr<arrow::Schema> schema, jbyteArray bytes) {
   ARROW_ASSIGN_OR_RAISE(
@@ -266,14 +190,14 @@ arrow::Result<std::shared_ptr<arrow::dataset::Scanner>> MakeJavaDatasetScanner(
     return batch;
   });
 
-  ARROW_ASSIGN_OR_RAISE(auto fragment, SimpleIteratorFragment::Make(std::move(itr)))
-
-  arrow::dataset::ScannerBuilder scanner_builder(
-      std::move(schema), fragment, std::make_shared<arrow::dataset::ScanOptions>());
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::RecordBatchReader> reader,
+                        arrow::RecordBatchReader::Make(std::move(itr), schema))
+  std::shared_ptr<arrow::dataset::ScannerBuilder> scanner_builder =
+      arrow::dataset::ScannerBuilder::FromRecordBatchReader(reader);
   // Using the default memory pool is enough, as native allocation is ideally
   // not invoked while scanning Java-based fragments.
-  RETURN_NOT_OK(scanner_builder.Pool(arrow::default_memory_pool()));
-  return scanner_builder.Finish();
+  RETURN_NOT_OK(scanner_builder->Pool(arrow::default_memory_pool()));
+  return scanner_builder->Finish();
 }
 
 }  // namespace
From 17c18ed0a7c858324e3981e3e6a2c55001b8a972 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Sat, 22 May 2021 16:36:03 +0800
Subject: [PATCH 5/8] JNI error handling

---
 cpp/src/jni/dataset/jni_util.cc               | 25 +++++++++----
 cpp/src/jni/dataset/jni_util.h                | 11 ++++--
 cpp/src/jni/dataset/jni_wrapper.cc            | 26 ++++++++++++--
 .../file/NativeScannerAdaptorImpl.java        |  3 ++
 .../dataset/file/TestDatasetFileWriter.java   | 35 +++++++++++++++++++
 5 files changed, 88 insertions(+), 12 deletions(-)

diff --git a/cpp/src/jni/dataset/jni_util.cc b/cpp/src/jni/dataset/jni_util.cc
index 7d74009ff30..afad086fac3 100644
--- a/cpp/src/jni/dataset/jni_util.cc
+++ b/cpp/src/jni/dataset/jni_util.cc
@@ -194,13 +194,24 @@ std::shared_ptr<ReservationListener> ReservationListenableMemoryPool::get_listen
 
 ReservationListenableMemoryPool::~ReservationListenableMemoryPool() {}
 
-Status CheckException(JNIEnv* env) {
-  if (env->ExceptionCheck()) {
-    env->ExceptionDescribe();
-    env->ExceptionClear();
-    return Status::Invalid("Error during calling Java code from native code");
-  }
-  return Status::OK();
+class JNIErrorDetail : public StatusDetail {
+ public:
+  explicit JNIErrorDetail(jthrowable t, std::string message)
+      : t_(t), message_(std::move(message)) {}
+
+  const char* type_id() const override { return "arrow::dataset::jni::JNIErrorDetail"; }
+
+  std::string ToString() const override { return message_; }
+
+  jthrowable cause() const { return t_; }
+
+ protected:
+  jthrowable t_;
+  std::string message_;
+};
+
+std::shared_ptr<StatusDetail> MakeJNIErrorDetail(jthrowable t,
+                                                 const std::string& message) {
+  return std::make_shared<JNIErrorDetail>(t, message);
 }
 
 jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
diff --git a/cpp/src/jni/dataset/jni_util.h b/cpp/src/jni/dataset/jni_util.h
index a56d4bc152b..26267a5ec70 100644
--- a/cpp/src/jni/dataset/jni_util.h
+++ b/cpp/src/jni/dataset/jni_util.h
@@ -30,8 +30,6 @@ namespace arrow {
 namespace dataset {
 namespace jni {
 
-Status CheckException(JNIEnv* env);
-
 jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name);
 
 Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name,
@@ -48,6 +46,15 @@ Result<jbyteArray> ToSchemaByteArray(JNIEnv* env, std::shared_ptr<Schema> schema
 Result<std::shared_ptr<Schema>> FromSchemaByteArray(JNIEnv* env,
                                                     jbyteArray schema_bytes);
 
+std::shared_ptr<StatusDetail> MakeJNIErrorDetail(jthrowable t,
+                                                 const std::string& message);
+
+template <typename... Args>
+Status JNIError(jthrowable t, const std::string& describe) {
+  return Status::FromDetailAndArgs(StatusCode::Invalid, MakeJNIErrorDetail(t, describe),
+                                   describe);
+}
+
 /// \brief Serialize arrow::RecordBatch to jbyteArray (Java byte array byte[]). For
 /// letting Java code manage lifecycles of buffers in the input batch, shared pointer IDs
 /// pointing to the buffers are serialized into buffer metadata.
diff --git a/cpp/src/jni/dataset/jni_wrapper.cc b/cpp/src/jni/dataset/jni_wrapper.cc index b96753ed966..ef6a3c22aef 100644 --- a/cpp/src/jni/dataset/jni_wrapper.cc +++ b/cpp/src/jni/dataset/jni_wrapper.cc @@ -36,6 +36,7 @@ namespace { jclass illegal_access_exception_class; jclass illegal_argument_exception_class; jclass runtime_exception_class; +jclass throwable_class; jclass serialized_record_batch_iterator_class; jclass java_reservation_listener_class; @@ -44,6 +45,8 @@ jmethodID serialized_record_batch_iterator_hasNext; jmethodID serialized_record_batch_iterator_next; jmethodID reserve_memory_method; jmethodID unreserve_memory_method; +jmethodID throwable_getMessage; +jmethodID throwable_toString; jlong default_memory_pool_id = -1L; @@ -74,6 +77,17 @@ void JniAssertOkOrThrow(arrow::Status status) { void JniThrow(std::string message) { ThrowPendingException(message); } +arrow::Status CheckException(JNIEnv* env) { + if (env->ExceptionCheck()) { + jthrowable t = env->ExceptionOccurred(); + env->ExceptionClear(); + auto jdescribe = (jstring)env->CallObjectMethod(t, throwable_toString); + std::string describe = arrow::dataset::jni::JStringToCString(env, jdescribe); + return arrow::dataset::jni::JNIError(t, describe); + } + return arrow::Status::OK(); +} + arrow::Result> GetFileFormat( jint file_format_id) { switch (file_format_id) { @@ -97,7 +111,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener { return arrow::Status::Invalid("JNIEnv was not attached to current thread"); } env->CallObjectMethod(java_reservation_listener_, reserve_memory_method, size); - RETURN_NOT_OK(arrow::dataset::jni::CheckException(env)); + RETURN_NOT_OK(CheckException(env)); return arrow::Status::OK(); } @@ -107,7 +121,7 @@ class ReserveFromJava : public arrow::dataset::jni::ReservationListener { return arrow::Status::Invalid("JNIEnv was not attached to current thread"); } env->CallObjectMethod(java_reservation_listener_, unreserve_memory_method, size); - RETURN_NOT_OK(arrow::dataset::jni::CheckException(env)); + RETURN_NOT_OK(CheckException(env)); return arrow::Status::OK(); } @@ -185,7 +199,7 @@ arrow::Result> MakeJavaDatasetScanner( } auto bytes = (jbyteArray)env->CallObjectMethod( java_serialized_record_batch_iterator, serialized_record_batch_iterator_next); - RETURN_NOT_OK(arrow::dataset::jni::CheckException(env)); + RETURN_NOT_OK(CheckException(env)); ARROW_ASSIGN_OR_RAISE(auto batch, FromBytes(env, schema, bytes)); return batch; }); @@ -238,6 +252,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;"); runtime_exception_class = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); + throwable_class = CreateGlobalClassReference(env, "Ljava/lang/Throwable;"); serialized_record_batch_iterator_class = CreateGlobalClassReference(env, @@ -256,6 +271,10 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { JniGetOrThrow(GetMethodID(env, java_reservation_listener_class, "reserve", "(J)V")); unreserve_memory_method = JniGetOrThrow( GetMethodID(env, java_reservation_listener_class, "unreserve", "(J)V")); + throwable_getMessage = JniGetOrThrow( + GetMethodID(env, throwable_class, "getMessage", "()Ljava/lang/String;")); + throwable_toString = JniGetOrThrow( + GetMethodID(env, throwable_class, "toString", "()Ljava/lang/String;")); default_memory_pool_id = reinterpret_cast(arrow::default_memory_pool()); @@ -269,6 +288,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { env->DeleteGlobalRef(illegal_access_exception_class); 
env->DeleteGlobalRef(illegal_argument_exception_class); env->DeleteGlobalRef(runtime_exception_class); + env->DeleteGlobalRef(throwable_class); env->DeleteGlobalRef(serialized_record_batch_iterator_class); env->DeleteGlobalRef(java_reservation_listener_class); diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java index 8fe9421f7e7..9526515080c 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/NativeScannerAdaptorImpl.java @@ -74,6 +74,9 @@ private void closeCurrent() throws Exception { return; } currentTask.close(); + if (currentBatchIterator == null) { + return; + } currentBatchIterator.close(); } diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java index 7b84218dfa0..036a0c28354 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestDatasetFileWriter.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.nio.channels.Channels; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Objects; @@ -32,12 +33,14 @@ import org.apache.arrow.dataset.TestDataset; import org.apache.arrow.dataset.jni.NativeMemoryPool; import org.apache.arrow.dataset.scanner.ScanOptions; +import org.apache.arrow.dataset.scanner.ScanTask; import org.apache.arrow.dataset.scanner.Scanner; import org.apache.arrow.dataset.source.Dataset; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.ipc.WriteChannel; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.io.FileUtils; import org.junit.Assert; import org.junit.ClassRule; @@ -99,6 +102,38 @@ public void testParquetWriteWithPartitions() throws Exception { } } + @Test(expected = java.lang.RuntimeException.class) + public void testScanErrorHandling() throws Exception { + DatasetFileWriter.write(new Scanner() { + @Override + public Iterable scan() { + return Collections.singletonList(new ScanTask() { + @Override + public BatchIterator execute() { + // this error is supposed to be firstly investigated in native code, then thrown back to Java. 
+            throw new RuntimeException("ERROR");
+          }
+
+          @Override
+          public void close() throws Exception {
+            // do nothing
+          }
+        });
+      }
+
+      @Override
+      public Schema schema() {
+        return new Schema(Collections.emptyList());
+      }
+
+      @Override
+      public void close() throws Exception {
+        // do nothing
+      }
+
+    }, FileFormat.PARQUET, "file:/DUMMY/");
+  }
+
   private void assertParquetFileEquals(String expectedURI, String actualURI) throws Exception {
     final FileSystemDatasetFactory expectedFactory = new FileSystemDatasetFactory(
         rootAllocator(), NativeMemoryPool.getDefault(), FileFormat.PARQUET, expectedURI);

From 445ffdf94c513c7d928b58b715415e40f1a44ef7 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Thu, 5 Aug 2021 13:30:43 +0800
Subject: [PATCH 6/8] style

---
 cpp/src/jni/dataset/jni_wrapper.cc | 44 +++++++++++++++----------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/cpp/src/jni/dataset/jni_wrapper.cc b/cpp/src/jni/dataset/jni_wrapper.cc
index ef6a3c22aef..b90d95636c8 100644
--- a/cpp/src/jni/dataset/jni_wrapper.cc
+++ b/cpp/src/jni/dataset/jni_wrapper.cc
@@ -604,27 +604,27 @@ Java_org_apache_arrow_dataset_file_JniWrapper_writeFromScannerToFile(
     jstring uri, jobjectArray partition_columns, jint max_partitions,
     jstring base_name_template) {
   JNI_METHOD_START
-    JavaVM* vm;
-    if (env->GetJavaVM(&vm) != JNI_OK) {
-      JniThrow("Unable to get JavaVM instance");
-    }
-    auto schema = JniGetOrThrow(FromSchemaByteArray(env, schema_bytes));
-    auto scanner = JniGetOrThrow(MakeJavaDatasetScanner(vm, itr, schema));
-    std::shared_ptr<arrow::dataset::FileFormat> file_format =
-        JniGetOrThrow(GetFileFormat(file_format_id));
-    arrow::dataset::FileSystemDatasetWriteOptions options;
-    std::string output_path;
-    auto filesystem = JniGetOrThrow(
-        arrow::fs::FileSystemFromUri(JStringToCString(env, uri), &output_path));
-    std::vector<std::string> partition_column_vector =
-        ToStringVector(env, partition_columns);
-    options.file_write_options = file_format->DefaultWriteOptions();
-    options.filesystem = filesystem;
-    options.base_dir = output_path;
-    options.basename_template = JStringToCString(env, base_name_template);
-    options.partitioning = std::make_shared<arrow::dataset::HivePartitioning>(
-        arrow::dataset::SchemaFromColumnNames(schema, partition_column_vector));
-    options.max_partitions = max_partitions;
-    JniAssertOkOrThrow(arrow::dataset::FileSystemDataset::Write(options, scanner));
+  JavaVM* vm;
+  if (env->GetJavaVM(&vm) != JNI_OK) {
+    JniThrow("Unable to get JavaVM instance");
+  }
+  auto schema = JniGetOrThrow(FromSchemaByteArray(env, schema_bytes));
+  auto scanner = JniGetOrThrow(MakeJavaDatasetScanner(vm, itr, schema));
+  std::shared_ptr<arrow::dataset::FileFormat> file_format =
+      JniGetOrThrow(GetFileFormat(file_format_id));
+  arrow::dataset::FileSystemDatasetWriteOptions options;
+  std::string output_path;
+  auto filesystem = JniGetOrThrow(
+      arrow::fs::FileSystemFromUri(JStringToCString(env, uri), &output_path));
+  std::vector<std::string> partition_column_vector =
+      ToStringVector(env, partition_columns);
+  options.file_write_options = file_format->DefaultWriteOptions();
+  options.filesystem = filesystem;
+  options.base_dir = output_path;
+  options.basename_template = JStringToCString(env, base_name_template);
+  options.partitioning = std::make_shared<arrow::dataset::HivePartitioning>(
+      arrow::dataset::SchemaFromColumnNames(schema, partition_column_vector));
+  options.max_partitions = max_partitions;
+  JniAssertOkOrThrow(arrow::dataset::FileSystemDataset::Write(options, scanner));
   JNI_METHOD_END()
 }

From cc5ea4ffc33f4148ff51c21c398e3ea60f69f357 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Thu, 5 Aug 2021 13:37:15 +0800
Subject: [PATCH 7/8] fixup
---
 .../src/main/java/org/apache/arrow/dataset/file/JniWrapper.java | 1 +
 1 file changed, 1 insertion(+)

diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
index 955b8344f1e..a503ca31f7e 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
@@ -18,6 +18,7 @@
 package org.apache.arrow.dataset.file;
 
 import org.apache.arrow.dataset.jni.JniLoader;
+import org.apache.arrow.dataset.jni.NativeSerializedRecordBatchIterator;
 
 /**
  * JniWrapper for filesystem based {@link org.apache.arrow.dataset.source.Dataset} implementations.

From 39606967846df877c207c866b7be5badec7a66b2 Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Thu, 5 Aug 2021 17:42:08 +0800
Subject: [PATCH 8/8] dependency

---
 java/dataset/pom.xml | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml
index 6553d41ec89..b67c9c48e02 100644
--- a/java/dataset/pom.xml
+++ b/java/dataset/pom.xml
@@ -28,6 +28,7 @@
     2.5.0
     1.11.0
     <dep.fbs.version>1.9.1</dep.fbs.version>
+    <commons.io.version>2.4</commons.io.version>
   </properties>
 
   <dependencies>
@@ -56,6 +57,12 @@
       <version>${dep.fbs.version}</version>
       <scope>compile</scope>
     </dependency>
+    <dependency>
+      <groupId>commons-io</groupId>
+      <artifactId>commons-io</artifactId>
+      <version>${commons.io.version}</version>
+      <scope>compile</scope>
+    </dependency>
     <dependency>
       <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-core</artifactId>
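
Note on patch 5/8: the local CheckException helper is the piece that lets an error raised inside a Java callback, such as the ScanTask in testScanErrorHandling, cross the JNI boundary. After each CallObjectMethod it checks for a pending Java exception, clears it, renders the throwable through Throwable.toString(), and wraps both the description and the jthrowable itself in a JNIError status so the bridge can rethrow the original exception to the Java caller. A minimal self-contained sketch of the same pattern, assuming only a throwable_toString jmethodID cached at load time and returning a plain arrow::Status rather than the patch's JNIError detail:

    #include <jni.h>

    #include <string>

    #include "arrow/status.h"

    // Sketch: convert a pending Java exception into an arrow::Status.
    // `throwable_toString` is assumed to be a jmethodID for
    // java/lang/Throwable#toString cached in JNI_OnLoad.
    arrow::Status CheckJavaException(JNIEnv* env, jmethodID throwable_toString) {
      if (!env->ExceptionCheck()) {
        return arrow::Status::OK();
      }
      jthrowable t = env->ExceptionOccurred();
      // Clear the pending exception first: most JNI calls, including the
      // CallObjectMethod below, are illegal while an exception is pending.
      env->ExceptionClear();
      auto jdescribe = (jstring)env->CallObjectMethod(t, throwable_toString);
      if (env->ExceptionCheck()) {  // toString() itself may throw
        env->ExceptionClear();
        return arrow::Status::UnknownError("Java exception (no description)");
      }
      const char* utf = env->GetStringUTFChars(jdescribe, nullptr);
      std::string describe = (utf == nullptr) ? "" : utf;
      if (utf != nullptr) {
        env->ReleaseStringUTFChars(jdescribe, utf);
      }
      return arrow::Status::UnknownError("Error during JNI call: ", describe);
    }

testScanErrorHandling exercises exactly this round trip: the RuntimeException("ERROR") thrown by the Java ScanTask is observed by the native scanner via CheckException, travels back through FileSystemDataset::Write as a failed Status, and is rethrown to the Java caller, which is why the test only asserts expected = RuntimeException.class.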
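
Patch 6/8 is whitespace-only, but the body it re-indents is the C++ side of DatasetFileWriter.write and is worth reading on its own. Stripped of the JNI plumbing, it reduces to an ordinary use of the Dataset write API of this era; the sketch below shows that shape. The function name, Parquet format choice, basename template, and the HivePartitioning/SchemaFromColumnNames pairing follow the reconstructed hunk above but are illustrative, not part of the patch:

    #include <memory>
    #include <string>
    #include <vector>

    #include "arrow/dataset/api.h"
    #include "arrow/filesystem/api.h"

    // Sketch: what writeFromScannerToFile assembles, minus JNI.
    arrow::Status WriteScannerToFiles(
        std::shared_ptr<arrow::dataset::Scanner> scanner,
        std::shared_ptr<arrow::Schema> schema, const std::string& uri,
        std::vector<std::string> partition_columns, int max_partitions) {
      std::string base_dir;
      ARROW_ASSIGN_OR_RAISE(auto fs, arrow::fs::FileSystemFromUri(uri, &base_dir));
      auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
      arrow::dataset::FileSystemDatasetWriteOptions options;
      options.file_write_options = format->DefaultWriteOptions();
      options.filesystem = std::move(fs);
      options.base_dir = base_dir;
      options.basename_template = "part-{i}.parquet";  // "{i}" is required
      // Partition columns are split out of the schema and written as
      // hive-style key=value directories under base_dir.
      options.partitioning = std::make_shared<arrow::dataset::HivePartitioning>(
          arrow::dataset::SchemaFromColumnNames(schema, partition_columns));
      options.max_partitions = max_partitions;
      return arrow::dataset::FileSystemDataset::Write(options, std::move(scanner));
    }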
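
Finally, the JNI_OnLoad/JNI_OnUnload edits in patch 5/8 follow the usual JNI caching discipline: look classes up once, pin them with NewGlobalRef so they stay valid across threads and calls, resolve method IDs once (they remain valid while the class stays pinned), and drop the global refs on unload. The patch routes this through Arrow's CreateGlobalClassReference and GetMethodID helpers; the condensed sketch below inlines the same steps for the Throwable case:

    #include <jni.h>

    static jclass throwable_class = nullptr;        // global ref, shared across threads
    static jmethodID throwable_toString = nullptr;  // valid while the class is pinned

    JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* /*reserved*/) {
      JNIEnv* env;
      if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
        return JNI_ERR;
      }
      // FindClass returns a local reference; promote it to a global one so it
      // outlives this call, then release the local.
      jclass local = env->FindClass("java/lang/Throwable");
      if (local == nullptr) {
        return JNI_ERR;
      }
      throwable_class = (jclass)env->NewGlobalRef(local);
      env->DeleteLocalRef(local);
      throwable_toString =
          env->GetMethodID(throwable_class, "toString", "()Ljava/lang/String;");
      if (throwable_toString == nullptr) {
        return JNI_ERR;
      }
      return JNI_VERSION_1_6;
    }

    JNIEXPORT void JNICALL JNI_OnUnload(JavaVM* vm, void* /*reserved*/) {
      JNIEnv* env;
      if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) == JNI_OK) {
        env->DeleteGlobalRef(throwable_class);  // mirrors the patch's cleanup
      }
    }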